| @@ -553,7 +553,7 @@ class _CellGraphExecutor: | |||
| """compile graph in auto parallel mode.""" | |||
| if not auto_parallel_mode: | |||
| replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) | |||
| self._updata_param_node_default_input(phase, replace) | |||
| self._update_param_node_default_input(phase, replace) | |||
| return | |||
| obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase) | |||
| @@ -564,13 +564,13 @@ class _CellGraphExecutor: | |||
| if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"): | |||
| obj.load_parameter_slice(None) | |||
| self._updata_param_node_default_input(phase, replace) | |||
| self._update_param_node_default_input(phase, replace) | |||
| # set parallel inputs in sink mode | |||
| if is_sink_mode: | |||
| obj.set_parallel_input_with_inputs(*args) | |||
| def _updata_param_node_default_input(self, phase, replace): | |||
| def _update_param_node_default_input(self, phase, replace): | |||
| new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])} | |||
| return self._graph_executor.updata_param_node_default_input(phase, new_param) | |||
| @@ -332,15 +332,14 @@ class LeakyReLU(Cell): | |||
| validator.check_value_type('alpha', alpha, [float, int], self.cls_name) | |||
| self.greater_equal = P.GreaterEqual() | |||
| self.mul = P.Mul() | |||
| self.maximum = P.Maximum() | |||
| self.alpha = alpha | |||
| self.select_op = P.Maximum() | |||
| if self.alpha > 1: | |||
| self.select_op = P.Minimum() | |||
| def construct(self, x): | |||
| alpha_array = P.Cast()(F.scalar_to_array(self.alpha), P.DType()(x)) | |||
| if self.alpha <= 1: | |||
| out = self.maximum(alpha_array * x, x) | |||
| else: | |||
| out = self.maximum(alpha_array * x, x) | |||
| out = self.select_op(alpha_array * x, x) | |||
| return out | |||
| @@ -149,9 +149,9 @@ def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"): | |||
| if not isinstance(data_size, (int, float)): | |||
| raise TypeError('data_size in data_size_list is invalid') | |||
| c_array_sizeList = _c_array(ctypes.c_float, data_size_list) | |||
| c_array_size_list = _c_array(ctypes.c_float, data_size_list) | |||
| c_size_num = ctypes.c_uint(len(data_size_list)) | |||
| c_group = _c_str(group) | |||
| ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList) | |||
| ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_size_list) | |||
| if ret != 0: | |||
| raise RuntimeError('Allreduce split error') | |||
| @@ -30,9 +30,11 @@ def _get_parallel_mode(): | |||
| """Get parallel mode.""" | |||
| return auto_parallel_context().get_parallel_mode() | |||
| def _is_in_auto_parallel_mode(): | |||
| return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL] | |||
| def _get_full_batch(): | |||
| """Get whether to use full_batch.""" | |||
| return auto_parallel_context().get_full_batch() | |||
| @@ -51,12 +53,8 @@ def _check_task_sink_envs(): | |||
| """ | |||
| import os | |||
| task_sink = os.getenv("GRAPH_OP_RUN") | |||
| if task_sink: | |||
| try: | |||
| if int(task_sink) == 1: | |||
| return False | |||
| except ValueError: | |||
| return True | |||
| if task_sink and task_sink.isdigit() and int(task_sink) == 1: | |||
| return False | |||
| return True | |||
| @@ -18,9 +18,9 @@ NOTE: | |||
| This is an experimental interface that is subject to change and/or deletion. | |||
| """ | |||
| from .transformer import * | |||
| from .layers import * | |||
| from .loss import * | |||
| from .op_parallel_config import * | |||
| from .layers import FixedSparseAttention | |||
| from .loss import CrossEntropyLoss | |||
| from .op_parallel_config import OpParallelConfig | |||
| __all__ = [] | |||
| __all__.extend(transformer.__all__) | |||
| @@ -73,6 +73,7 @@ def _valid_type_checks(types, class_name): | |||
| # as the input of _args_type_validator_check is fixed, so we need to manually change the input order | |||
| partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name) | |||
| return partial_check(name, type(value)) | |||
| return validator_check_func | |||
| @@ -83,6 +84,7 @@ def _valid_value_checks(types, class_name): | |||
| # as the input of _args_type_validator_check is fixed, so we need to manually change the input order | |||
| partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name) | |||
| return partial_check(name, value) | |||
| return validator_check_func | |||
| @@ -334,7 +336,7 @@ class _Linear(Cell): | |||
| if self.activation_flag: | |||
| # some operations has many primitives, need to manually set the shard | |||
| if self.act_name.lower() == "leakyrelu": | |||
| self.activation.maximum.shard((strategy_activation[0], strategy_activation[0])) | |||
| self.activation.select_op.shard((strategy_activation[0], strategy_activation[0])) | |||
| elif self.act_name.lower() == "logsigmoid": | |||
| self.activation.mul.shard((strategy_activation[0], ())) | |||
| self.activation.exp.shard(strategy_activation) | |||
| @@ -402,6 +404,7 @@ class FixedSparseAttention(nn.Cell): | |||
| >>> print(output.shape) | |||
| (2, 1024, 512) | |||
| """ | |||
| @_args_type_validator_check(batch_size=Validator.check_positive_int, | |||
| num_heads=Validator.check_positive_int, | |||
| size_per_head=Validator.check_positive_int, | |||
| @@ -437,7 +440,7 @@ class FixedSparseAttention(nn.Cell): | |||
| self.transpose = P.Transpose().shard(((dp, 1, mp, 1),)) | |||
| self.batch_matmul = P.BatchMatMul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) | |||
| self.multiply = P.Mul().shard(((dp, 1, 1, 1), (1, 1, 1))) | |||
| self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32) | |||
| self.multiply_data = Tensor([-10000.0], dtype=mstype.float32) | |||
| self.parallel_config = parallel_config | |||
| size_per_head_list = [64, 128] | |||
| if self.seq_length != 1024: | |||
| @@ -460,7 +463,7 @@ class FixedSparseAttention(nn.Cell): | |||
| global_mask_original = -10000 * global_mask_original | |||
| global_mask_fx = global_mask_original.reshape((self.seq_length // 16, 16, self.global_size // 16, 16)) | |||
| global_mask = np.transpose(global_mask_fx, (2, 0, 1, 3)) | |||
| global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :,], self.batch_size, axis=0) | |||
| global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :], self.batch_size, axis=0) | |||
| global_mask = global_mask.reshape((self.batch_size * self.global_size // 16, self.seq_length // 16, 16, 16)) | |||
| self.global_mask = Tensor(global_mask, mstype.float32) | |||
| self.local_mask_triangle = Tensor(np.tril(local_ones), mstype.float32) | |||
| @@ -578,6 +581,7 @@ class _CumSum(Cell): | |||
| Outputs: | |||
| Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`. | |||
| """ | |||
| def __init__(self, config): | |||
| super(_CumSum, self).__init__() | |||
| dp = config.data_parallel | |||
| @@ -598,14 +602,13 @@ class _CumSum(Cell): | |||
| self.delta = Tensor(1, mstype.int32) | |||
| self.add = P.TensorAdd().shard(((1,), ())) | |||
| def construct(self, expert_mask): | |||
| # origin_shape: (self.expert_parallel, tokens_per_device, self.expert_dim) | |||
| # origin_shape: (expert_parallel, tokens_per_device, self.expert_dim) | |||
| origin_shape = self.shape(expert_mask) | |||
| tokens_per_device = origin_shape[1] | |||
| # expert_mask_trans's shape: (self.expert_parallel, self.expert_dim, tokens_per_device) | |||
| # expert_mask_trans's shape: (expert_parallel, self.expert_dim, tokens_per_device) | |||
| expert_mask_trans = self.transpose(expert_mask, (0, 2, 1)) | |||
| # expert_mask_reshaped's shape: (self.expert_parallel*self.expert_dim, tokens_per_device) | |||
| # expert_mask_reshaped's shape: (expert_parallel*self.expert_dim, tokens_per_device) | |||
| expert_mask_reshaped = self.reshape(expert_mask_trans, (-1, tokens_per_device)) | |||
| one_dim = self.expand(self.range(self.start, self.add(self.limit, tokens_per_device), self.delta), 0) | |||
| @@ -614,11 +617,11 @@ class _CumSum(Cell): | |||
| up_tri_matrix = self.greater(one_dim, other_dim) | |||
| up_tri_matrix = self.cast(up_tri_matrix, mstype.float32) | |||
| # cum_sum's shape: (self.expert_parallel*self.expert_dim, tokens_per_device) | |||
| # cum_sum's shape: (expert_parallel*self.expert_dim, tokens_per_device) | |||
| cum_sum = self.matmul(expert_mask_reshaped, up_tri_matrix) | |||
| # cum_sum's shape: (self.expert_parallel, self.expert_dim, tokens_per_device) | |||
| # cum_sum's shape: (expert_parallel, self.expert_dim, tokens_per_device) | |||
| cum_sum = self.reshape(cum_sum, (origin_shape[0], origin_shape[2], tokens_per_device)) | |||
| # cum_sum's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) | |||
| # cum_sum's shape: (expert_parallel, tokens_per_device, self.expert_dim) | |||
| cum_sum = self.transpose3(cum_sum, (0, 2, 1)) | |||
| return cum_sum | |||
| @@ -646,6 +649,7 @@ class Router(Cell): | |||
| Outputs: | |||
| Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`. | |||
| """ | |||
| def __init__(self, | |||
| d_model, | |||
| moe_config, | |||
| @@ -704,6 +708,7 @@ class SwitchRouter(Cell): | |||
| Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`, | |||
| Tensor of shape :math:`(1)`. | |||
| """ | |||
| def __init__(self, | |||
| d_model, | |||
| moe_config, | |||
| @@ -752,9 +757,9 @@ class SwitchRouter(Cell): | |||
| """ | |||
| Computing the load balance loss. | |||
| """ | |||
| # density_1's shape: (self.expert_parallel, self.expert_dim) | |||
| # density_1's shape: (expert_parallel, self.expert_dim) | |||
| density_1 = self.reduce_mean(expert_mask, 1) | |||
| # density_1_proxy's shape: (self.expert_parallel, self.expert_dim) | |||
| # density_1_proxy's shape: (expert_parallel, self.expert_dim) | |||
| density_1_proxy = self.reduce_mean2(router_prob, 1) | |||
| loss = self.mul(density_1, density_1_proxy) | |||
| loss = self.reduce_mean3(loss) | |||
| @@ -766,20 +771,19 @@ class SwitchRouter(Cell): | |||
| Keeping only the tokens that fit within expert_capacity. | |||
| """ | |||
| cumsum = self.cumsum(expert_mask) | |||
| # position_in_expert's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) | |||
| # position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim) | |||
| position_in_expert = self.mul4(cumsum, expert_mask) | |||
| less_result = self.less(position_in_expert, expert_capacity) | |||
| # expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) | |||
| # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim) | |||
| expert_mask = self.mul5(less_result, expert_mask) | |||
| # expert_mask_flat's shape: (self.expert_parallel, tokens_per_device) | |||
| # expert_mask_flat's shape: (expert_parallel, tokens_per_device) | |||
| expert_mask_flat = self.reduce_sum(expert_mask, -1) | |||
| # Mask out the experts that have overflowed the expert_capacity. | |||
| # expert_gate's shape: (self.expert_parallel, tokens_per_device) | |||
| # expert_gate's shape: (expert_parallel, tokens_per_device) | |||
| expert_gate = self.mul6(expert_gate, expert_mask_flat) | |||
| return expert_gate, expert_mask_flat, position_in_expert | |||
| def construct(self, router_logits): | |||
| router_logits_shape = self.shape(router_logits) | |||
| router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1])) | |||
| @@ -791,9 +795,9 @@ class SwitchRouter(Cell): | |||
| # Probabilities for each token of what expert is should be sent to | |||
| router_prob = self.softmax(router_logits) | |||
| # shape: (self.expert_parallel, tokens_per_device) | |||
| # shape is : (expert_parallel, tokens_per_device) | |||
| expert_index, expert_gate = self.argmax(router_prob) | |||
| # expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) | |||
| # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim) | |||
| expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value) | |||
| # Computing the load balance loss: | |||
| @@ -802,12 +806,12 @@ class SwitchRouter(Cell): | |||
| expert_gate, expert_mask_flat, position_in_expert = \ | |||
| self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate) | |||
| # combine_tensor's shape: (self.expert_parallel, tokens_per_device) | |||
| # combine_tensor's shape: (expert_parallel, tokens_per_device) | |||
| combine_tensor = self.mul7(expert_gate, expert_mask_flat) | |||
| # combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) | |||
| # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim) | |||
| combine_tensor = self.mul8(self.expand(combine_tensor, -1), | |||
| self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value)) | |||
| # combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity) | |||
| # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity) | |||
| combine_tensor = self.mul9(self.expand2(combine_tensor, -1), | |||
| self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity, | |||
| self.on_value, self.off_value)) | |||
| @@ -93,7 +93,7 @@ class CrossEntropyLoss(Cell): | |||
| """ | |||
| self._check_input(logits, label, input_mask) | |||
| # [bs*seq_length, vocab_size] | |||
| # the shape is [bs*seq_length, vocab_size] | |||
| logits = F.cast(logits, mstype.float32) | |||
| # LogSoftmax for logits over last dimension | |||
| _, logit_max = self.max(logits) | |||
| @@ -143,7 +143,6 @@ def _check_config(config): | |||
| device_num = D.get_group_size() | |||
| optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer") | |||
| # dp * pp * pipeline_stage <= device_num | |||
| if config.data_parallel * config.model_parallel * pipeline_stage > device_num: | |||
| raise ValueError(f"The product of the data parallel {config.data_parallel}, " | |||
| f"model parallel {config.model_parallel} " | |||
| @@ -155,10 +154,3 @@ def _check_config(config): | |||
| logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the" | |||
| f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the " | |||
| f"optimizer_shard to make them consistent.") | |||
| # pipeline_stage <= micro_batch_num | |||
| if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\ | |||
| and config.pipeline_stage < config.micro_batch_num: | |||
| raise ValueError( | |||
| f"The pipeline stage {config.pipeline_stage} should be greater than the micro_batch_num " | |||
| f"{config.micro_batch_num}.") | |||
| @@ -396,13 +396,14 @@ class FeedForward(Cell): | |||
| _check_input_shape(F.shape(x), "x", self.cls_name, 3) | |||
| _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) | |||
| x = self.cast(x, mstype.float16) | |||
| # [bs, seq_length, ffn_hidden_size] | |||
| # returned shape is [bs, seq_length, ffn_hidden_size] | |||
| hidden = self.mapping(x) | |||
| output = self.projection(hidden) | |||
| # [bs, seq_length, hidden_size] | |||
| # returned shape is [bs, seq_length, hidden_size] | |||
| output = self.dropout(output) | |||
| return output | |||
| @constexpr | |||
| def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim): | |||
| return math.ceil(k * tokens_per_device * capacity_factor / expert_dim) | |||
| @@ -588,7 +589,7 @@ class AttentionMask(Cell): | |||
| mask_right = self.reshape(input_mask, shape_right) | |||
| attention_mask = self.mul(mask_left, mask_right) | |||
| lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0) | |||
| # [bs, seq_length, seq_length] | |||
| # the returned shape is [bs, seq_length, seq_length] | |||
| attention_mask = self.multiply( | |||
| attention_mask, lower_traiangle) | |||
| return attention_mask | |||
| @@ -889,18 +890,18 @@ class MultiHeadAttention(Cell): | |||
| query = self.dense1(query_tensor) | |||
| key = self.dense2(key_tensor) | |||
| value = self.dense3(value_tensor) | |||
| # [bs, num_heads, seq_length, size_per_head] | |||
| # the returned shape is [bs, num_heads, seq_length, size_per_head] | |||
| query = self.transpose( | |||
| F.reshape( | |||
| query, | |||
| (-1, query_tensor_original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 1, 3)) | |||
| # [bs, num_heads, size_per_head, seq_length] | |||
| # the returned shape is [bs, num_heads, size_per_head, seq_length] | |||
| key = self.transpose( | |||
| F.reshape( | |||
| key, (-1, key_tensor_original_shape[1], self.n_head, self.size_per_head)), | |||
| (0, 2, 3, 1)) | |||
| # [bs, num_heads, seq_length, size_per_head] | |||
| # the returned shape is [bs, num_heads, seq_length, size_per_head] | |||
| value = self.transpose( | |||
| F.reshape( | |||
| value, | |||
| @@ -949,7 +950,7 @@ class MultiHeadAttention(Cell): | |||
| layer_present = (key_present, value_present) | |||
| # multi head attention considering attention mask | |||
| # [bs, seq_length, hidden_size] | |||
| # the return shape is [bs, seq_length, hidden_size] | |||
| attention = self._attn(query, key, value, attention_mask) | |||
| # Output | |||
| output = self.projection(attention) | |||
| @@ -1019,8 +1020,8 @@ class MultiHeadAttention(Cell): | |||
| ori_dtype = P.DType()(score) | |||
| score = P.Cast()(score, self.softmax_dtype) | |||
| # for input size of (bs, 1) namely the second graph, the shape of attention_mask matrix should be | |||
| # (bs, 1, 1, seq_length) | |||
| # for input size of (bs, 1) namely the second graph, | |||
| # the shape of attention_mask matrix should be (bs, 1, 1, seq_length) | |||
| if self.use_past and not self.is_first_iteration: | |||
| # Calculate the current total token | |||
| current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0), | |||
| @@ -1508,7 +1509,7 @@ class TransformerDecoderLayer(Cell): | |||
| memory_mask=None, | |||
| init_reset=True, batch_valid_length=None): | |||
| self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length) | |||
| # [bs, seq_length, embedding_size] | |||
| # the returned shape is [bs, seq_length, embedding_size] | |||
| input_x = self.layernorm1(hidden_stats) | |||
| input_x = F.cast(input_x, self.dtype) | |||
| @@ -33,20 +33,19 @@ class EmbeddingLayer(nn.Cell): | |||
| def __init__(self, config): | |||
| super(EmbeddingLayer, self).__init__() | |||
| # Only for the pipeline mode, the embedding needs to be row sliced. | |||
| copied_parallel_config = copy.deepcopy(config.parallel_config) | |||
| if copied_parallel_config.pipeline_stage > 1: | |||
| copied_parallel_config.vocab_emb_dp = False | |||
| self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size, | |||
| embedding_size=config.hidden_size, | |||
| param_init=initializer("normal", [config.vocab_size, config.hidden_size], | |||
| dtype=config.param_init_type), | |||
| parallel_config=copied_parallel_config.embedding_dp_mp_config) | |||
| parallel_config=config.parallel_config.embedding_dp_mp_config) | |||
| copied_parallel_config = copy.deepcopy(config.parallel_config) | |||
| copied_parallel_config.vocab_emb_dp = True | |||
| self.position_embedding = VocabEmbedding(vocab_size=config.seq_length, | |||
| embedding_size=config.hidden_size, | |||
| param_init=initializer("normal", | |||
| [config.seq_length, config.hidden_size], | |||
| dtype=config.param_init_type), | |||
| parallel_config=config.parallel_config.embedding_dp_mp_config) | |||
| parallel_config=copied_parallel_config.embedding_dp_mp_config) | |||
| self.add = P.Add().shard( | |||
| ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1))) | |||
| self.dropout = nn.Dropout(1 - config.dropout_rate) | |||
| @@ -249,13 +248,14 @@ class PanguAlpha_Model(Cell): | |||
| param_init_type=config.param_init_type, | |||
| use_past=config.use_past, | |||
| parallel_config=config.parallel_config).blocks | |||
| copied_parallel_config = copy.deepcopy(config.parallel_config) | |||
| copied_parallel_config.vocab_emb_dp = True | |||
| self.top_query_embedding = VocabEmbedding(vocab_size=config.seq_length, | |||
| embedding_size=config.hidden_size, | |||
| param_init=initializer("normal", | |||
| [config.seq_length, config.hidden_size], | |||
| dtype=config.param_init_type), | |||
| parallel_config=config.parallel_config.embedding_dp_mp_config) | |||
| parallel_config=copied_parallel_config.embedding_dp_mp_config) | |||
| self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1 | |||
| if config.parallel_config.pipeline_stage > 1: | |||
| self.top_query_embedding.set_comm_fusion(2) | |||
| @@ -106,6 +106,7 @@ def run_train(args_opt): | |||
| pipeline_stage=args_opt.stage_num, | |||
| micro_batch_num=args_opt.micro_size, | |||
| optimizer_shard=bool(args_opt.optimizer_shard), | |||
| vocab_emb_dp=bool(args_opt.word_emb_dp), | |||
| recompute=True) | |||
| config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads, | |||
| hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length, | |||
| @@ -221,6 +222,7 @@ def run_train_pipeline(args_opt): | |||
| pipeline_stage=args_opt.stage_num, | |||
| micro_batch_num=args_opt.micro_size, | |||
| optimizer_shard=bool(args_opt.optimizer_shard), | |||
| vocab_emb_dp=bool(args_opt.word_emb_dp), | |||
| recompute=True) | |||
| config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num, | |||
| num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size, | |||