From: @gaojing22 Reviewed-by: @liangchenghui, @wuxuejian Signed-off-by: @wuxuejian tags/v1.2.0-rc1
| @@ -65,33 +65,28 @@ if __name__ == '__main__': | |||||
| for param in params: | for param in params: | ||||
| value = param.data | value = param.data | ||||
| name = param.name | |||||
| if name not in weights: | |||||
| raise ValueError(f"{name} is not found in weights.") | |||||
| with open("weight_after_deal.txt", "a+") as f: | |||||
| weights_name = name | |||||
| f.write(weights_name) | |||||
| f.write("\n") | |||||
| if isinstance(value, Tensor): | |||||
| print(name, value.asnumpy().shape) | |||||
| if weights_name in weights: | |||||
| assert weights_name in weights | |||||
| if isinstance(weights[weights_name], Parameter): | |||||
| if param.data.dtype == "Float32": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||||
| elif param.data.dtype == "Float16": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||||
| elif isinstance(weights[weights_name], Tensor): | |||||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||||
| elif isinstance(weights[weights_name], np.ndarray): | |||||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||||
| else: | |||||
| param.set_data(weights[weights_name]) | |||||
| weights_name = param.name | |||||
| if weights_name not in weights: | |||||
| raise ValueError(f"{weights_name} is not found in weights.") | |||||
| if isinstance(value, Tensor): | |||||
| if weights_name in weights: | |||||
| assert weights_name in weights | |||||
| if isinstance(weights[weights_name], Parameter): | |||||
| if param.data.dtype == "Float32": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||||
| elif param.data.dtype == "Float16": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||||
| elif isinstance(weights[weights_name], Tensor): | |||||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||||
| elif isinstance(weights[weights_name], np.ndarray): | |||||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||||
| else: | else: | ||||
| print("weight not found in checkpoint: " + weights_name) | |||||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||||
| f.close() | |||||
| param.set_data(weights[weights_name]) | |||||
| else: | |||||
| print("weight not found in checkpoint: " + weights_name) | |||||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||||
| print(" | Load weights successfully.") | print(" | Load weights successfully.") | ||||
| tfm_infer = GNMTInferCell(tfm_model) | tfm_infer = GNMTInferCell(tfm_model) | ||||
| tfm_infer.set_train(False) | tfm_infer.set_train(False) | ||||
| @@ -23,7 +23,6 @@ from mindspore.common.tensor import Tensor | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore import context, Parameter | from mindspore import context, Parameter | ||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.serialization import load_checkpoint | |||||
| from src.dataset import load_dataset | from src.dataset import load_dataset | ||||
| from .gnmt import GNMT | from .gnmt import GNMT | ||||
| @@ -37,18 +36,6 @@ context.set_context( | |||||
| reserve_class_name_in_scope=False) | reserve_class_name_in_scope=False) | ||||
| def get_weight_and_variable(model_path, params): | |||||
| print("model path is {}".format(model_path)) | |||||
| ms_ckpt = load_checkpoint(model_path) | |||||
| with open("variable.txt", "w") as f: | |||||
| for msname in ms_ckpt: | |||||
| f.write(msname + "\n") | |||||
| with open("weights.txt", "w") as f: | |||||
| for param in params: | |||||
| name = param.name | |||||
| f.write(name + "\n") | |||||
| class GNMTInferCell(nn.Cell): | class GNMTInferCell(nn.Cell): | ||||
| """ | """ | ||||
| Encapsulation class of GNMT network infer. | Encapsulation class of GNMT network infer. | ||||
| @@ -92,38 +79,31 @@ def gnmt_infer(config, dataset): | |||||
| use_one_hot_embeddings=False) | use_one_hot_embeddings=False) | ||||
| params = tfm_model.trainable_params() | params = tfm_model.trainable_params() | ||||
| get_weight_and_variable(config.existed_ckpt, params) | |||||
| weights = load_infer_weights(config) | weights = load_infer_weights(config) | ||||
| for param in params: | for param in params: | ||||
| value = param.data | value = param.data | ||||
| name = param.name | |||||
| if name not in weights: | |||||
| raise ValueError(f"{name} is not found in weights.") | |||||
| with open("weight_after_deal.txt", "a+") as f: | |||||
| weights_name = name | |||||
| f.write(weights_name) | |||||
| f.write("\n") | |||||
| if isinstance(value, Tensor): | |||||
| print(name, value.asnumpy().shape) | |||||
| if weights_name in weights: | |||||
| assert weights_name in weights | |||||
| if isinstance(weights[weights_name], Parameter): | |||||
| if param.data.dtype == "Float32": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||||
| elif param.data.dtype == "Float16": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||||
| elif isinstance(weights[weights_name], Tensor): | |||||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||||
| elif isinstance(weights[weights_name], np.ndarray): | |||||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||||
| else: | |||||
| param.set_data(weights[weights_name]) | |||||
| weights_name = param.name | |||||
| if weights_name not in weights: | |||||
| raise ValueError(f"{weights_name} is not found in weights.") | |||||
| if isinstance(value, Tensor): | |||||
| if weights_name in weights: | |||||
| assert weights_name in weights | |||||
| if isinstance(weights[weights_name], Parameter): | |||||
| if param.data.dtype == "Float32": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32)) | |||||
| elif param.data.dtype == "Float16": | |||||
| param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16)) | |||||
| elif isinstance(weights[weights_name], Tensor): | |||||
| param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype)) | |||||
| elif isinstance(weights[weights_name], np.ndarray): | |||||
| param.set_data(Tensor(weights[weights_name], config.dtype)) | |||||
| else: | else: | ||||
| print("weight not found in checkpoint: " + weights_name) | |||||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||||
| f.close() | |||||
| param.set_data(weights[weights_name]) | |||||
| else: | |||||
| print("weight not found in checkpoint: " + weights_name) | |||||
| param.set_data(zero_weight(value.asnumpy().shape)) | |||||
| print(" | Load weights successfully.") | print(" | Load weights successfully.") | ||||
| tfm_infer = GNMTInferCell(tfm_model) | tfm_infer = GNMTInferCell(tfm_model) | ||||
| model = Model(tfm_infer) | model = Model(tfm_infer) | ||||
| @@ -37,36 +37,26 @@ def load_infer_weights(config): | |||||
| ms_ckpt = load_checkpoint(model_path) | ms_ckpt = load_checkpoint(model_path) | ||||
| is_npz = False | is_npz = False | ||||
| weights = {} | weights = {} | ||||
| with open("variable_after_deal.txt", "w") as f: | |||||
| for param_name in ms_ckpt: | |||||
| infer_name = param_name.replace("gnmt.gnmt.", "") | |||||
| if infer_name.startswith("embedding_lookup."): | |||||
| if is_npz: | |||||
| weights[infer_name] = ms_ckpt[param_name] | |||||
| else: | |||||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||||
| f.write(infer_name) | |||||
| f.write("\n") | |||||
| infer_name = "beam_decoder.decoder." + infer_name | |||||
| if is_npz: | |||||
| weights[infer_name] = ms_ckpt[param_name] | |||||
| else: | |||||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||||
| f.write(infer_name) | |||||
| f.write("\n") | |||||
| continue | |||||
| elif not infer_name.startswith("gnmt_encoder"): | |||||
| if infer_name.startswith("gnmt_decoder."): | |||||
| infer_name = infer_name.replace("gnmt_decoder.", "decoder.") | |||||
| infer_name = "beam_decoder.decoder." + infer_name | |||||
| for param_name in ms_ckpt: | |||||
| infer_name = param_name.replace("gnmt.gnmt.", "") | |||||
| if infer_name.startswith("embedding_lookup."): | |||||
| if is_npz: | if is_npz: | ||||
| weights[infer_name] = ms_ckpt[param_name] | weights[infer_name] = ms_ckpt[param_name] | ||||
| else: | else: | ||||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | ||||
| f.write(infer_name) | |||||
| f.write("\n") | |||||
| f.close() | |||||
| infer_name = "beam_decoder.decoder." + infer_name | |||||
| if is_npz: | |||||
| weights[infer_name] = ms_ckpt[param_name] | |||||
| else: | |||||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||||
| continue | |||||
| elif not infer_name.startswith("gnmt_encoder"): | |||||
| if infer_name.startswith("gnmt_decoder."): | |||||
| infer_name = infer_name.replace("gnmt_decoder.", "decoder.") | |||||
| infer_name = "beam_decoder.decoder." + infer_name | |||||
| if is_npz: | |||||
| weights[infer_name] = ms_ckpt[param_name] | |||||
| else: | |||||
| weights[infer_name] = ms_ckpt[param_name].data.asnumpy() | |||||
| return weights | return weights | ||||