@@ -250,3 +250,35 @@ def without_fold_batchnorm(weight, cell_quant):
    weight = weight * _gamma / _sigma
    bias = beta - gamma * mean / sigma
    return weight, bias


def load_nonquant_param_into_quant_net(quant_model, params_dict):
    """
    Load fp32 model parameters into a quantization aware model.

    Args:
        quant_model: quantization aware model
        params_dict: dict of fp32 parameters loaded from a checkpoint

    Returns:
        None
    """
    iterable_dict = {
        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
        'moving_variance': iter(
            [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
    }
    for name, param in quant_model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict.keys():
            raise ValueError(f"Can't find a matching parameter in ckpt, param name = {name}")
        value_param = next(iterable_dict[key_name], None)
        if value_param is not None:
            param.set_parameter_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')
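For context, a minimal usage sketch of the new helper. The `LeNet5Fusion` network and checkpoint path are illustrative, and the `convert_quant_network` arguments are the ones used in the example scripts changed later in this PR:

```python
from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
from src.lenet_fusion import LeNet5 as LeNet5Fusion  # example fusion network from this repo

# Build the fusion network and load an fp32 checkpoint trained without quantization.
network = LeNet5Fusion(num_class=10)
param_dict = load_checkpoint("checkpoint_lenet.ckpt")  # path is illustrative

# Copy the fp32 parameters into the fusion network by suffix matching,
# then convert it into a quantization aware network for fine-tuning.
load_nonquant_param_into_quant_net(network, param_dict)
network = quant.convert_quant_network(network, quant_delay=900,
                                      per_channel=[True, False], symmetric=[False, False])
```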
@@ -308,6 +308,7 @@ def load_param_into_net(net, parameter_dict):
        logger.debug("%s", param_name)
    logger.info("Load parameter into net finish, {} parameters have not been loaded.".format(len(param_not_load)))
    return param_not_load
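`load_param_into_net` now returns the list of parameters that could not be loaded, so callers can fail fast. A minimal sketch of checking the return value (the network and checkpoint path are assumed to exist), mirroring how the example scripts below use it:

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

param_dict = load_checkpoint("checkpoint_lenet.ckpt")  # path is illustrative
not_load_param = load_param_into_net(network, param_dict)
if not_load_param:
    raise ValueError("Load param into net fail!")
```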


def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
@@ -93,65 +93,6 @@ Get the MNIST from scratch dataset.
ds_train = create_dataset(os.path.join(args.data_path, "train"),
                          cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size()
```
### Train model

Load the LeNet fusion network, then build the training network using the loss `nn.SoftmaxCrossEntropyWithLogits` and the optimizer `nn.Momentum`.

```Python
# Define the network
network = LeNet5Fusion(cfg.num_classes)
# Define the loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# Define optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
# Define model using loss and optimization.
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
```

Now we can start training.

```Python
model.train(cfg['epoch_size'], ds_train,
            callbacks=[time_cb, ckpoint_cb, LossMonitor()],
            dataset_sink_mode=args.dataset_sink_mode)
```

During training, the loss value of each step is printed as follows:

```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```

Alternatively, you can run training with this command instead:

```bash
python train.py --data_path MNIST_Data --device_target Ascend
```

### Evaluate fusion model

After training completes, we obtain the fusion model checkpoint file, for example `checkpoint_lenet.ckpt`. We can then evaluate this fusion model.

```bash
python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
```

The top-1 accuracy is displayed in the shell.

```bash
>>> Accuracy: 98.53.
```

## Train quantization aware model
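The quantization aware training step continues from the definitions above. The sketch below is based on the `train_quant.py` changes in this PR, and the `convert_quant_network` argument values are the ones used there:

```Python
# Load the fp32 checkpoint produced by the fusion training above and
# copy its parameters into the fusion network by suffix matching.
param_dict = load_checkpoint("checkpoint_lenet.ckpt")
load_nonquant_param_into_quant_net(network, param_dict)

# Convert the fusion network to a quantization aware network, then train as usual.
network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False,
                                      per_channel=[True, False], symmetric=[False, False])
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
model.train(cfg['epoch_size'], ds_train,
            callbacks=[time_cb, ckpoint_cb, LossMonitor()],
            dataset_sink_mode=args.dataset_sink_mode)
```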
@@ -1,65 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval lenet example ########################
eval lenet according to model file:
python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion

parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
                    help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
                    help='dataset_sink_mode is False or True')
args = parser.parse_args()

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
    step_size = ds_eval.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # define loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    # call back and monitor
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # load check point into network
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)

    print("============== Starting Testing ==============")
    acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
    print("============== {} ==============".format(acc))
@@ -63,7 +63,9 @@ if __name__ == "__main__":
    # load quantization aware network checkpoint
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    not_load_param = load_param_into_net(network, param_dict)
    if not_load_param:
        raise ValueError("Load param into net fail!")

    print("============== Starting Testing ==============")
    acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
@@ -1,64 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.nn as nn


class LeNet5(nn.Cell):
    """
    Lenet network

    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.bn2 = nn.BatchNorm2d(16)
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
@@ -36,8 +36,8 @@ class LeNet5(nn.Cell):
        self.num_class = num_class
        # change `nn.Conv2d` to `nn.Conv2dBnAct`
        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', has_bn=True, activation='relu')
        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', has_bn=True, activation='relu')
        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        # change `nn.Dense` to `nn.DenseBnAct`
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
@@ -1,68 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt):
python train.py --data_path /YourDataPath
"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
from src.loss_monitor import LossMonitor

parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
                    help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
                    help='dataset_sink_mode is False or True')
args = parser.parse_args()

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # call back and monitor
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=args.dataset_sink_mode)
    print("============== End Training ==============")
@@ -22,11 +22,12 @@ import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
@@ -54,10 +55,11 @@ if __name__ == "__main__":
    # load quantization aware network checkpoint
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)
    load_nonquant_param_into_quant_net(network, param_dict)

    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, quant_delay=900, per_channel=[True, False], symmetric=[False, False])
    network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
                                          symmetric=[False, False])

    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@@ -68,7 +68,9 @@ if __name__ == '__main__':
    # load checkpoint
    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(network, param_dict)
        not_load_param = load_param_into_net(network, param_dict)
        if not_load_param:
            raise ValueError("Load param into net fail!")
    network.set_train(False)

    # define model
@@ -25,39 +25,6 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


def _load_param_into_net(model, params_dict):
    """
    load fp32 model parameters to quantization model.

    Args:
        model: quantization model
        params_dict: f32 param

    Returns:
        None
    """
    iterable_dict = {
        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
        'moving_variance': iter(
            [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
    }
    for name, param in model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict.keys():
            raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
        value_param = next(iterable_dict[key_name], None)
        if value_param is not None:
            param.set_parameter_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')


class Monitor(Callback):
    """
    Monitor loss and time.
@@ -28,6 +28,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
import mindspore.dataset.engine as de
from src.dataset import create_dataset
@@ -35,7 +36,6 @@ from src.lr_generator import get_lr
from src.utils import Monitor, CrossEntropyWithLabelSmooth
from src.config import config_ascend_quant, config_gpu_quant
from src.mobilenetV2 import mobilenetV2
from src.utils import _load_param_into_net

random.seed(1)
np.random.seed(1)
@@ -101,7 +101,7 @@ def train_on_ascend():
    # load pre trained ckpt
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        _load_param_into_net(network, param_dict)
        load_nonquant_param_into_quant_net(network, param_dict)
    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network,
                                          bn_fold=True,
@@ -163,7 +163,7 @@ def train_on_gpu():
    # resume
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        _load_param_into_net(network, param_dict)
        load_nonquant_param_into_quant_net(network, param_dict)

    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network,
@@ -20,12 +20,11 @@ import argparse
from src.config import quant_set, config_quant, config_noquant
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net
from models.resnet_quant import resnet50_quant
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant

parser = argparse.ArgumentParser(description='Image classification')
@@ -66,7 +65,9 @@ if __name__ == '__main__':
    # load checkpoint
    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        _load_param_into_net(net, param_dict)
        not_load_param = load_param_into_net(net, param_dict)
        if not_load_param:
            raise ValueError("Load param into net fail!")
    net.set_train(False)

    # define model
@@ -1,46 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utils script"""


def _load_param_into_net(model, params_dict):
    """
    load fp32 model parameters to quantization model.

    Args:
        model: quantization model
        params_dict: f32 param

    Returns:
        None
    """
    iterable_dict = {
        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
        'moving_variance': iter(
            [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
    }
    for name, param in model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict.keys():
            continue
        value_param = next(iterable_dict[key_name], None)
        if value_param is not None:
            param.set_parameter_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')
@@ -26,6 +26,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
@@ -35,7 +36,6 @@ from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import config_quant
from src.crossentropy import CrossEntropy
from src.utils import _load_param_into_net

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
@@ -85,7 +85,7 @@ if __name__ == '__main__':
    # weight init and load checkpoint file
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        _load_param_into_net(net, param_dict)
        load_nonquant_param_into_quant_net(net, param_dict)
        epoch_size = config.epoch_size - config.pretrained_epoch_size
    else:
        for _, cell in net.cells_and_names():