diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index 7dac5a27b4..12827981d1 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -250,3 +250,35 @@ def without_fold_batchnorm(weight, cell_quant): weight = weight * _gamma / _sigma bias = beta - gamma * mean / sigma return weight, bias + + +def load_nonquant_param_into_quant_net(quant_model, params_dict): + """ + load fp32 model parameters to quantization model. + + Args: + quant_model: quantization model + params_dict: f32 param + + Returns: + None + """ + iterable_dict = { + 'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]), + 'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]), + 'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]), + 'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]), + 'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]), + 'moving_variance': iter( + [item for item in params_dict.items() if item[0].endswith('moving_variance')]), + 'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]), + 'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')]) + } + for name, param in quant_model.parameters_and_names(): + key_name = name.split(".")[-1] + if key_name not in iterable_dict.keys(): + raise ValueError(f"Can't find match parameter in ckpt,param name = {name}") + value_param = next(iterable_dict[key_name], None) + if value_param is not None: + param.set_parameter_data(value_param[1].data) + print(f'init model param {name} with checkpoint param {value_param[0]}') diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 30795431b5..91f976cb0b 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -308,6 +308,7 @@ def load_param_into_net(net, parameter_dict): logger.debug("%s", param_name) logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load))) + return param_not_load def _load_dismatch_prefix_params(net, parameter_dict, param_not_load): diff --git a/model_zoo/official/cv/lenet_quant/Readme.md b/model_zoo/official/cv/lenet_quant/Readme.md index 2fd3e129a2..9e5e64b48c 100644 --- a/model_zoo/official/cv/lenet_quant/Readme.md +++ b/model_zoo/official/cv/lenet_quant/Readme.md @@ -93,65 +93,6 @@ Get the MNIST from scratch dataset. ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) step_size = ds_train.get_dataset_size() -``` - -### Train model - -Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`. - -```Python -# Define the network -network = LeNet5Fusion(cfg.num_classes) -# Define the loss -net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") -# Define optimization -net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - -# Define model using loss and optimization. -time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) -config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) -ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) -model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) -``` - -Now we can start training. - -```Python -model.train(cfg['epoch_size'], ds_train, - callbacks=[time_cb, ckpoint_cb, LossMonitor()], - dataset_sink_mode=args.dataset_sink_mode) -``` - -After all the following we will get the loss value of each step as following: - -```bash ->>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234] ->>> ... ->>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] ->>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] ->>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] -``` - -Also, you can just run this command instead. - -```python -python train.py --data_path MNIST_Data --device_target Ascend -``` - -### Evaluate fusion model - -After training epoch stop. We can get the fusion model checkpoint file like `checkpoint_lenet.ckpt`. Meanwhile, we can evaluate this fusion model. - -```python -python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt -``` - -The top1 accuracy would display on shell. - -```bash ->>> Accuracy: 98.53. -``` ## Train quantization aware model diff --git a/model_zoo/official/cv/lenet_quant/eval.py b/model_zoo/official/cv/lenet_quant/eval.py deleted file mode 100644 index df9a5b123b..0000000000 --- a/model_zoo/official/cv/lenet_quant/eval.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -######################## eval lenet example ######################## -eval lenet according to model file: -python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt -""" - -import os -import argparse -import mindspore.nn as nn -from mindspore import context -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train import Model -from mindspore.nn.metrics import Accuracy -from src.dataset import create_dataset -from src.config import mnist_cfg as cfg -from src.lenet_fusion import LeNet5 as LeNet5Fusion - -parser = argparse.ArgumentParser(description='MindSpore MNIST Example') -parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU'], - help='device where the code will be implemented (default: Ascend)') -parser.add_argument('--data_path', type=str, default="./MNIST_Data", - help='path where the dataset is saved') -parser.add_argument('--ckpt_path', type=str, default="", - help='if mode is test, must provide path where the trained ckpt file') -parser.add_argument('--dataset_sink_mode', type=bool, default=True, - help='dataset_sink_mode is False or True') -args = parser.parse_args() - -if __name__ == "__main__": - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) - step_size = ds_eval.get_dataset_size() - - # define fusion network - network = LeNet5Fusion(cfg.num_classes) - # define loss - net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - # define network optimization - net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - - # call back and monitor - model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - - # load check point into network - param_dict = load_checkpoint(args.ckpt_path) - load_param_into_net(network, param_dict) - - print("============== Starting Testing ==============") - acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) - print("============== {} ==============".format(acc)) diff --git a/model_zoo/official/cv/lenet_quant/eval_quant.py b/model_zoo/official/cv/lenet_quant/eval_quant.py index 9f586d9e19..4849d01daf 100644 --- a/model_zoo/official/cv/lenet_quant/eval_quant.py +++ b/model_zoo/official/cv/lenet_quant/eval_quant.py @@ -63,7 +63,9 @@ if __name__ == "__main__": # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) - load_param_into_net(network, param_dict) + not_load_param = load_param_into_net(network, param_dict) + if not_load_param: + raise ValueError("Load param into net fail!") print("============== Starting Testing ==============") acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) diff --git a/model_zoo/official/cv/lenet_quant/src/lenet.py b/model_zoo/official/cv/lenet_quant/src/lenet.py deleted file mode 100644 index 18d310c2c7..0000000000 --- a/model_zoo/official/cv/lenet_quant/src/lenet.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""LeNet.""" -import mindspore.nn as nn - - -class LeNet5(nn.Cell): - """ - Lenet network - - Args: - num_class (int): Num classes. Default: 10. - - Returns: - Tensor, output tensor - Examples: - >>> LeNet(num_class=10) - - """ - - def __init__(self, num_class=10, channel=1): - super(LeNet5, self).__init__() - self.num_class = num_class - - self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') - self.bn1 = nn.BatchNorm2d(6) - self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') - self.bn2 = nn.BatchNorm2d(16) - self.fc1 = nn.Dense(16 * 5 * 5, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, self.num_class) - - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - - def construct(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.bn2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.flatten(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x diff --git a/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py b/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py index 88b5685218..88b3593502 100644 --- a/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py +++ b/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py @@ -36,8 +36,8 @@ class LeNet5(nn.Cell): self.num_class = num_class # change `nn.Conv2d` to `nn.Conv2dBnAct` - self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', has_bn=True, activation='relu') - self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', has_bn=True, activation='relu') + self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') # change `nn.Dense` to `nn.DenseBnAct` self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu') diff --git a/model_zoo/official/cv/lenet_quant/train.py b/model_zoo/official/cv/lenet_quant/train.py deleted file mode 100644 index 66546b15c0..0000000000 --- a/model_zoo/official/cv/lenet_quant/train.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -######################## train lenet example ######################## -train lenet and get network model files(.ckpt) : -python train.py --data_path /YourDataPath -""" - -import os -import argparse -import mindspore.nn as nn -from mindspore import context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig -from mindspore.train import Model -from mindspore.nn.metrics import Accuracy -from src.dataset import create_dataset -from src.config import mnist_cfg as cfg -from src.lenet_fusion import LeNet5 as LeNet5Fusion -from src.loss_monitor import LossMonitor - -parser = argparse.ArgumentParser(description='MindSpore MNIST Example') -parser.add_argument('--device_target', type=str, default="Ascend", - choices=['Ascend', 'GPU'], - help='device where the code will be implemented (default: Ascend)') -parser.add_argument('--data_path', type=str, default="./MNIST_Data", - help='path where the dataset is saved') -parser.add_argument('--ckpt_path', type=str, default="", - help='if mode is test, must provide path where the trained ckpt file') -parser.add_argument('--dataset_sink_mode', type=bool, default=True, - help='dataset_sink_mode is False or True') -args = parser.parse_args() - -if __name__ == "__main__": - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1) - step_size = ds_train.get_dataset_size() - - # define fusion network - network = LeNet5Fusion(cfg.num_classes) - # define network loss - net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - # define network optimization - net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - - # call back and monitor - config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) - - # define model - model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - - print("============== Starting Training ==============") - model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()], - dataset_sink_mode=args.dataset_sink_mode) - print("============== End Training ==============") diff --git a/model_zoo/official/cv/lenet_quant/train_quant.py b/model_zoo/official/cv/lenet_quant/train_quant.py index dd6b59a9c8..51d37cc1bf 100644 --- a/model_zoo/official/cv/lenet_quant/train_quant.py +++ b/model_zoo/official/cv/lenet_quant/train_quant.py @@ -22,11 +22,12 @@ import os import argparse import mindspore.nn as nn from mindspore import context -from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.serialization import load_checkpoint from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.train.quant import quant +from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net from src.dataset import create_dataset from src.config import mnist_cfg as cfg from src.lenet_fusion import LeNet5 as LeNet5Fusion @@ -54,10 +55,11 @@ if __name__ == "__main__": # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) - load_param_into_net(network, param_dict) + load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=900, per_channel=[True, False], symmetric=[False, False]) + network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False], + symmetric=[False, False]) # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") diff --git a/model_zoo/official/cv/mobilenetv2_quant/eval.py b/model_zoo/official/cv/mobilenetv2_quant/eval.py index cfa873a98d..e6b0875c75 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/eval.py +++ b/model_zoo/official/cv/mobilenetv2_quant/eval.py @@ -68,7 +68,9 @@ if __name__ == '__main__': # load checkpoint if args_opt.checkpoint_path: param_dict = load_checkpoint(args_opt.checkpoint_path) - load_param_into_net(network, param_dict) + not_load_param = load_param_into_net(network, param_dict) + if not_load_param: + raise ValueError("Load param into net fail!") network.set_train(False) # define model diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/utils.py b/model_zoo/official/cv/mobilenetv2_quant/src/utils.py index 33fec74de3..a00ac53349 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/src/utils.py +++ b/model_zoo/official/cv/mobilenetv2_quant/src/utils.py @@ -25,39 +25,6 @@ from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.common import dtype as mstype - -def _load_param_into_net(model, params_dict): - """ - load fp32 model parameters to quantization model. - - Args: - model: quantization model - params_dict: f32 param - - Returns: - None - """ - iterable_dict = { - 'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]), - 'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]), - 'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]), - 'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]), - 'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]), - 'moving_variance': iter( - [item for item in params_dict.items() if item[0].endswith('moving_variance')]), - 'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]), - 'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')]) - } - for name, param in model.parameters_and_names(): - key_name = name.split(".")[-1] - if key_name not in iterable_dict.keys(): - raise ValueError(f"Can't find match parameter in ckpt,param name = {name}") - value_param = next(iterable_dict[key_name], None) - if value_param is not None: - param.set_parameter_data(value_param[1].data) - print(f'init model param {name} with checkpoint param {value_param[0]}') - - class Monitor(Callback): """ Monitor loss and time. diff --git a/model_zoo/official/cv/mobilenetv2_quant/train.py b/model_zoo/official/cv/mobilenetv2_quant/train.py index a941d370fc..ebe60996cf 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/train.py +++ b/model_zoo/official/cv/mobilenetv2_quant/train.py @@ -28,6 +28,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.serialization import load_checkpoint from mindspore.communication.management import init, get_group_size, get_rank from mindspore.train.quant import quant +from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net import mindspore.dataset.engine as de from src.dataset import create_dataset @@ -35,7 +36,6 @@ from src.lr_generator import get_lr from src.utils import Monitor, CrossEntropyWithLabelSmooth from src.config import config_ascend_quant, config_gpu_quant from src.mobilenetV2 import mobilenetV2 -from src.utils import _load_param_into_net random.seed(1) np.random.seed(1) @@ -101,7 +101,7 @@ def train_on_ascend(): # load pre trained ckpt if args_opt.pre_trained: param_dict = load_checkpoint(args_opt.pre_trained) - _load_param_into_net(network, param_dict) + load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network network = quant.convert_quant_network(network, bn_fold=True, @@ -163,7 +163,7 @@ def train_on_gpu(): # resume if args_opt.pre_trained: param_dict = load_checkpoint(args_opt.pre_trained) - _load_param_into_net(network, param_dict) + load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network network = quant.convert_quant_network(network, diff --git a/model_zoo/official/cv/resnet50_quant/eval.py b/model_zoo/official/cv/resnet50_quant/eval.py index 481e4bb853..0395e38b60 100755 --- a/model_zoo/official/cv/resnet50_quant/eval.py +++ b/model_zoo/official/cv/resnet50_quant/eval.py @@ -20,12 +20,11 @@ import argparse from src.config import quant_set, config_quant, config_noquant from src.dataset import create_dataset from src.crossentropy import CrossEntropy -from src.utils import _load_param_into_net from models.resnet_quant import resnet50_quant from mindspore import context from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint +from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.quant import quant parser = argparse.ArgumentParser(description='Image classification') @@ -66,7 +65,9 @@ if __name__ == '__main__': # load checkpoint if args_opt.checkpoint_path: param_dict = load_checkpoint(args_opt.checkpoint_path) - _load_param_into_net(net, param_dict) + not_load_param = load_param_into_net(net, param_dict) + if not_load_param: + raise ValueError("Load param into net fail!") net.set_train(False) # define model diff --git a/model_zoo/official/cv/resnet50_quant/src/utils.py b/model_zoo/official/cv/resnet50_quant/src/utils.py deleted file mode 100644 index 846fd7b894..0000000000 --- a/model_zoo/official/cv/resnet50_quant/src/utils.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""utils script""" - -def _load_param_into_net(model, params_dict): - """ - load fp32 model parameters to quantization model. - - Args: - model: quantization model - params_dict: f32 param - - Returns: - None - """ - iterable_dict = { - 'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]), - 'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]), - 'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]), - 'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]), - 'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]), - 'moving_variance': iter( - [item for item in params_dict.items() if item[0].endswith('moving_variance')]), - 'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]), - 'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')]) - } - for name, param in model.parameters_and_names(): - key_name = name.split(".")[-1] - if key_name not in iterable_dict.keys(): - continue - value_param = next(iterable_dict[key_name], None) - if value_param is not None: - param.set_parameter_data(value_param[1].data) - print(f'init model param {name} with checkpoint param {value_param[0]}') diff --git a/model_zoo/official/cv/resnet50_quant/train.py b/model_zoo/official/cv/resnet50_quant/train.py index be609bd543..2e13ec37ff 100755 --- a/model_zoo/official/cv/resnet50_quant/train.py +++ b/model_zoo/official/cv/resnet50_quant/train.py @@ -26,6 +26,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint from mindspore.train.quant import quant +from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.communication.management import init import mindspore.nn as nn import mindspore.common.initializer as weight_init @@ -35,7 +36,6 @@ from src.dataset import create_dataset from src.lr_generator import get_lr from src.config import config_quant from src.crossentropy import CrossEntropy -from src.utils import _load_param_into_net parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') @@ -85,7 +85,7 @@ if __name__ == '__main__': # weight init and load checkpoint file if args_opt.pre_trained: param_dict = load_checkpoint(args_opt.pre_trained) - _load_param_into_net(net, param_dict) + load_nonquant_param_into_quant_net(net, param_dict) epoch_size = config.epoch_size - config.pretrained_epoch_size else: for _, cell in net.cells_and_names():