| @@ -135,11 +135,13 @@ The GNMT network script and code result are as follows: | |||
| │ ├──lr_scheduler.py // Learning rate scheduler. | |||
| │ ├──optimizer.py // Optimizer. | |||
| ├── scripts | |||
| │ ├──run_distributed_train_ascend.sh // shell script for distributed train on ascend. | |||
| │ ├──run_standalone_eval_ascend.sh // shell script for standalone eval on ascend. | |||
| │ ├──run_standalone_train_ascend.sh // shell script for standalone eval on ascend. | |||
| ├── create_dataset.py // dataset preparation. | |||
| │ ├──run_distributed_train_ascend.sh // Shell script for distributed train on ascend. | |||
| │ ├──run_standalone_eval_ascend.sh // Shell script for standalone eval on ascend. | |||
| │ ├──run_standalone_train_ascend.sh // Shell script for standalone eval on ascend. | |||
| ├── create_dataset.py // Dataset preparation. | |||
| ├── eval.py // Infer API entry. | |||
| ├── export.py // Export checkpoint file into air models. | |||
| ├── mindspore_hub_conf.py // Hub config. | |||
| ├── requirements.txt // Requirements of third party package. | |||
| ├── train.py // Train API entry. | |||
| ``` | |||
| @@ -65,33 +65,28 @@ if __name__ == '__main__': | |||
| for param in params: | |||
| value = param.data | |||
| name = param.name | |||
| if name not in weights: | |||
| raise ValueError(f"{name} is not found in weights.") | |||
| with open("weight_after_deal.txt", "a+") as f: | |||
| weights_name = name | |||
| f.write(weights_name) | |||
| f.write("\n") | |||
| if isinstance(value, Tensor): | |||
| print(name, value.asnumpy().shape) | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| param.set_data(weights[weights_name]) | |||
| weights_name = param.name | |||
| if weights_name not in weights: | |||
| raise ValueError(f"{weights_name} is not found in weights.") | |||
| if isinstance(value, Tensor): | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| f.close() | |||
| param.set_data(weights[weights_name]) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| print(" | Load weights successfully.") | |||
| tfm_infer = GNMTInferCell(tfm_model) | |||
| tfm_infer.set_train(False) | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config.""" | |||
| import mindspore.common.dtype as mstype | |||
| from config import GNMTConfig | |||
| from src.gnmt_model import GNMTNetworkWithLoss, GNMT | |||
| def get_config(config): | |||
| config = GNMTConfig.from_json_file(config) | |||
| config.compute_type = mstype.float16 | |||
| config.dtype = mstype.float32 | |||
| return config | |||
| def create_network(name, *args, **kwargs): | |||
| """create gnmt network.""" | |||
| if name == "gnmt": | |||
| if "config" in kwargs: | |||
| config = get_config(kwargs["config"]) | |||
| else: | |||
| raise NotImplementedError(f"Please make sure the configuration file path is correct") | |||
| is_training = kwargs.get("is_training", False) | |||
| if is_training: | |||
| return GNMTNetworkWithLoss(config, is_training=is_training, *args) | |||
| return GNMT(config, *args) | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||
| @@ -172,6 +172,7 @@ class BeamSearchDecoder(nn.Cell): | |||
| max_decode_length=64, | |||
| sos_id=2, | |||
| eos_id=3, | |||
| is_using_while=False, | |||
| compute_type=mstype.float32): | |||
| super(BeamSearchDecoder, self).__init__() | |||
| @@ -185,6 +186,7 @@ class BeamSearchDecoder(nn.Cell): | |||
| self.cov_penalty_factor = cov_penalty_factor | |||
| self.max_decode_length = max_decode_length | |||
| self.decoder = decoder | |||
| self.is_using_while = is_using_while | |||
| self.add = P.TensorAdd() | |||
| self.expand = P.ExpandDims() | |||
| @@ -215,7 +217,12 @@ class BeamSearchDecoder(nn.Cell): | |||
| self.gather_nd = P.GatherNd() | |||
| self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) | |||
| self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) | |||
| if self.is_using_while: | |||
| self.start = Tensor(0, dtype=mstype.int32) | |||
| self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), | |||
| mstype.int32) | |||
| else: | |||
| self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) | |||
| init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1]) | |||
| self.init_scores = Tensor(init_scores, mstype.float32) | |||
| @@ -259,7 +266,7 @@ class BeamSearchDecoder(nn.Cell): | |||
| self.sub = P.Sub() | |||
| def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs, | |||
| state_seq, state_length, decoder_hidden_state=None, accu_attn_scores=None, | |||
| state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None, | |||
| state_finished=None): | |||
| """ | |||
| Beam search one_step output. | |||
| @@ -359,7 +366,13 @@ class BeamSearchDecoder(nn.Cell): | |||
| self.hidden_size)) | |||
| # update state_seq | |||
| state_seq = self.concat((seq, self.expand(word_indices, -1))) | |||
| if self.is_using_while: | |||
| state_seq_new = self.cast(seq, mstype.float32) | |||
| word_indices_fp32 = self.cast(word_indices, mstype.float32) | |||
| state_seq_new[:, :, idx] = word_indices_fp32 | |||
| state_seq = self.cast(state_seq_new, mstype.int32) | |||
| else: | |||
| state_seq = self.concat((seq, self.expand(word_indices, -1))) | |||
| cur_input_ids = self.reshape(word_indices, (-1, 1)) | |||
| state_log_probs = topk_scores | |||
| @@ -388,11 +401,22 @@ class BeamSearchDecoder(nn.Cell): | |||
| decoder_hidden_state = self.decoder_hidden_state | |||
| accu_attn_scores = self.accu_attn_scores | |||
| for _ in range(self.max_decode_length + 1): | |||
| cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ | |||
| state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, | |||
| state_seq, state_length, decoder_hidden_state, accu_attn_scores, | |||
| state_finished) | |||
| if not self.is_using_while: | |||
| for _ in range(self.max_decode_length + 1): | |||
| cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ | |||
| state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, | |||
| state_seq, state_length, None, decoder_hidden_state, accu_attn_scores, | |||
| state_finished) | |||
| else: | |||
| idx = self.start + 1 | |||
| ends = self.start + self.max_decode_length + 1 | |||
| while idx < ends: | |||
| cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ | |||
| state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, | |||
| state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores, | |||
| state_finished) | |||
| idx = idx + 1 | |||
| # add length penalty scores | |||
| penalty_len = self.length_penalty(state_length) | |||
| # return penalty_len | |||
| @@ -408,6 +432,9 @@ class BeamSearchDecoder(nn.Cell): | |||
| gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1))) | |||
| # sort sequence and attention scores | |||
| predicted_ids = self.gather_nd(state_seq, gather_indices) | |||
| predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)] | |||
| if not self.is_using_while: | |||
| predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)] | |||
| else: | |||
| predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length] | |||
| return predicted_ids | |||
| @@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore import context, Parameter | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_checkpoint | |||
| from src.dataset import load_dataset | |||
| from .gnmt import GNMT | |||
| @@ -37,18 +36,6 @@ context.set_context( | |||
| reserve_class_name_in_scope=False) | |||
| def get_weight_and_variable(model_path, params): | |||
| print("model path is {}".format(model_path)) | |||
| ms_ckpt = load_checkpoint(model_path) | |||
| with open("variable.txt", "w") as f: | |||
| for msname in ms_ckpt: | |||
| f.write(msname + "\n") | |||
| with open("weights.txt", "w") as f: | |||
| for param in params: | |||
| name = param.name | |||
| f.write(name + "\n") | |||
| class GNMTInferCell(nn.Cell): | |||
| """ | |||
| Encapsulation class of GNMT network infer. | |||
| @@ -92,38 +79,31 @@ def gnmt_infer(config, dataset): | |||
| use_one_hot_embeddings=False) | |||
| params = tfm_model.trainable_params() | |||
| get_weight_and_variable(config.existed_ckpt, params) | |||
| weights = load_infer_weights(config) | |||
| for param in params: | |||
| value = param.data | |||
| name = param.name | |||
| if name not in weights: | |||
| raise ValueError(f"{name} is not found in weights.") | |||
| with open("weight_after_deal.txt", "a+") as f: | |||
| weights_name = name | |||
| f.write(weights_name) | |||
| f.write("\n") | |||
| if isinstance(value, Tensor): | |||
| print(name, value.asnumpy().shape) | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| param.set_data(weights[weights_name]) | |||
| weights_name = param.name | |||
| if weights_name not in weights: | |||
| raise ValueError(f"{weights_name} is not found in weights.") | |||
| if isinstance(value, Tensor): | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| f.close() | |||
| param.set_data(weights[weights_name]) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| print(" | Load weights successfully.") | |||
| tfm_infer = GNMTInferCell(tfm_model) | |||
| model = Model(tfm_infer) | |||
| @@ -37,36 +37,26 @@ def load_infer_weights(config): | |||
| ms_ckpt = load_checkpoint(model_path) | |||
| is_npz = False | |||
| weights = {} | |||
| with open("variable_after_deal.txt", "w") as f: | |||
| for param_name in ms_ckpt: | |||
| infer_name = param_name.replace("gnmt.gnmt.", "") | |||
| if infer_name.startswith("embedding_lookup."): | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| f.write(infer_name) | |||
| f.write("\n") | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| f.write(infer_name) | |||
| f.write("\n") | |||
| continue | |||
| elif not infer_name.startswith("gnmt_encoder"): | |||
| if infer_name.startswith("gnmt_decoder."): | |||
| infer_name = infer_name.replace("gnmt_decoder.", "decoder.") | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| for param_name in ms_ckpt: | |||
| infer_name = param_name.replace("gnmt.gnmt.", "") | |||
| if infer_name.startswith("embedding_lookup."): | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| f.write(infer_name) | |||
| f.write("\n") | |||
| f.close() | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| continue | |||
| elif not infer_name.startswith("gnmt_encoder"): | |||
| if infer_name.startswith("gnmt_decoder."): | |||
| infer_name = infer_name.replace("gnmt_decoder.", "decoder.") | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| return weights | |||