From: @gaojing22 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @wuxuejiantags/v1.2.0-rc1
| @@ -65,33 +65,28 @@ if __name__ == '__main__': | |||
| for param in params: | |||
| value = param.data | |||
| name = param.name | |||
| if name not in weights: | |||
| raise ValueError(f"{name} is not found in weights.") | |||
| with open("weight_after_deal.txt", "a+") as f: | |||
| weights_name = name | |||
| f.write(weights_name) | |||
| f.write("\n") | |||
| if isinstance(value, Tensor): | |||
| print(name, value.asnumpy().shape) | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| param.set_data(weights[weights_name]) | |||
| weights_name = param.name | |||
| if weights_name not in weights: | |||
| raise ValueError(f"{weights_name} is not found in weights.") | |||
| if isinstance(value, Tensor): | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| f.close() | |||
| param.set_data(weights[weights_name]) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| print(" | Load weights successfully.") | |||
| tfm_infer = GNMTInferCell(tfm_model) | |||
| tfm_infer.set_train(False) | |||
| @@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore import context, Parameter | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_checkpoint | |||
| from src.dataset import load_dataset | |||
| from .gnmt import GNMT | |||
| @@ -37,18 +36,6 @@ context.set_context( | |||
| reserve_class_name_in_scope=False) | |||
| def get_weight_and_variable(model_path, params): | |||
| print("model path is {}".format(model_path)) | |||
| ms_ckpt = load_checkpoint(model_path) | |||
| with open("variable.txt", "w") as f: | |||
| for msname in ms_ckpt: | |||
| f.write(msname + "\n") | |||
| with open("weights.txt", "w") as f: | |||
| for param in params: | |||
| name = param.name | |||
| f.write(name + "\n") | |||
| class GNMTInferCell(nn.Cell): | |||
| """ | |||
| Encapsulation class of GNMT network infer. | |||
| @@ -92,38 +79,31 @@ def gnmt_infer(config, dataset): | |||
| use_one_hot_embeddings=False) | |||
| params = tfm_model.trainable_params() | |||
| get_weight_and_variable(config.existed_ckpt, params) | |||
| weights = load_infer_weights(config) | |||
| for param in params: | |||
| value = param.data | |||
| name = param.name | |||
| if name not in weights: | |||
| raise ValueError(f"{name} is not found in weights.") | |||
| with open("weight_after_deal.txt", "a+") as f: | |||
| weights_name = name | |||
| f.write(weights_name) | |||
| f.write("\n") | |||
| if isinstance(value, Tensor): | |||
| print(name, value.asnumpy().shape) | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| param.set_data(weights[weights_name]) | |||
| weights_name = param.name | |||
| if weights_name not in weights: | |||
| raise ValueError(f"{weights_name} is not found in weights.") | |||
| if isinstance(value, Tensor): | |||
| if weights_name in weights: | |||
| assert weights_name in weights | |||
| if isinstance(weights[weights_name], Parameter): | |||
| if param.data.dtype == "Float32": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||
| elif param.data.dtype == "Float16": | |||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||
| elif isinstance(weights[weights_name], Tensor): | |||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||
| elif isinstance(weights[weights_name], np.ndarray): | |||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| f.close() | |||
| param.set_data(weights[weights_name]) | |||
| else: | |||
| print("weight not found in checkpoint: " + weights_name) | |||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||
| print(" | Load weights successfully.") | |||
| tfm_infer = GNMTInferCell(tfm_model) | |||
| model = Model(tfm_infer) | |||
| @@ -37,36 +37,26 @@ def load_infer_weights(config): | |||
| ms_ckpt = load_checkpoint(model_path) | |||
| is_npz = False | |||
| weights = {} | |||
| with open("variable_after_deal.txt", "w") as f: | |||
| for param_name in ms_ckpt: | |||
| infer_name = param_name.replace("gnmt.gnmt.", "") | |||
| if infer_name.startswith("embedding_lookup."): | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| f.write(infer_name) | |||
| f.write("\n") | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| f.write(infer_name) | |||
| f.write("\n") | |||
| continue | |||
| elif not infer_name.startswith("gnmt_encoder"): | |||
| if infer_name.startswith("gnmt_decoder."): | |||
| infer_name = infer_name.replace("gnmt_decoder.", "decoder.") | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| for param_name in ms_ckpt: | |||
| infer_name = param_name.replace("gnmt.gnmt.", "") | |||
| if infer_name.startswith("embedding_lookup."): | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| f.write(infer_name) | |||
| f.write("\n") | |||
| f.close() | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| continue | |||
| elif not infer_name.startswith("gnmt_encoder"): | |||
| if infer_name.startswith("gnmt_decoder."): | |||
| infer_name = infer_name.replace("gnmt_decoder.", "decoder.") | |||
| infer_name = "beam_decoder.decoder." + infer_name | |||
| if is_npz: | |||
| weights[infer_name] = ms_ckpt[param_name] | |||
| else: | |||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||
| return weights | |||