Merge pull request !2368 from chenzhongming/abctags/v0.5.0-beta
| @@ -0,0 +1,61 @@ | |||
| # LeNet Quantization Example | |||
| ## Description | |||
| Training LeNet with MNIST dataset in MindSpore with aware quantization trainging. | |||
| This is the simple and basic tutorial for constructing a network in MindSpore with quantization. | |||
| ## Requirements | |||
| - Install [MindSpore](https://www.mindspore.cn/install/en). | |||
| - Download the MNIST dataset, the directory structure is as follows: | |||
| ``` | |||
| └─MNIST_Data | |||
| ├─test | |||
| │ t10k-images.idx3-ubyte | |||
| │ t10k-labels.idx1-ubyte | |||
| └─train | |||
| train-images.idx3-ubyte | |||
| train-labels.idx1-ubyte | |||
| ``` | |||
| ## Running the example | |||
| ```python | |||
| # train LeNet, hyperparameter setting in config.py | |||
| python train.py --data_path MNIST_Data | |||
| ``` | |||
| You will get the loss value of each step as following: | |||
| ```bash | |||
| Epoch: [ 1/ 10] step: [ 1 / 900], loss: [2.3040/2.5234], time: [1.300234] | |||
| ... | |||
| Epoch: [ 10/ 10] step: [887 / 900], loss: [0.0113/0.0223], time: [1.300234] | |||
| Epoch: [ 10/ 10] step: [888 / 900], loss: [0.0334/0.0223], time: [1.300234] | |||
| Epoch: [ 10/ 10] step: [889 / 900], loss: [0.0233/0.0223], time: [1.300234] | |||
| ... | |||
| ``` | |||
| Then, evaluate LeNet according to network model | |||
| ```python | |||
| python eval.py --data_path MNIST_Data --ckpt_path checkpoint_lenet-1_1875.ckpt | |||
| ``` | |||
| ## Note | |||
| Here are some optional parameters: | |||
| ```bash | |||
| --device_target {Ascend,GPU,CPU} | |||
| device where the code will be implemented (default: Ascend) | |||
| --data_path DATA_PATH | |||
| path where the dataset is saved | |||
| --dataset_sink_mode DATASET_SINK_MODE | |||
| dataset_sink_mode is False or True | |||
| ``` | |||
| You can run ```python train.py -h``` or ```python eval.py -h``` to get more information. | |||
| @@ -0,0 +1,64 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| ######################## eval lenet example ######################## | |||
| eval lenet according to model file: | |||
| python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt | |||
| """ | |||
| import os | |||
| import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from src.dataset import create_dataset | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | |||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||
| choices=['Ascend', 'GPU', 'CPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||
| help='path where the dataset is saved') | |||
| parser.add_argument('--ckpt_path', type=str, default="", | |||
| help='if mode is test, must provide path where the trained ckpt file') | |||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||
| help='dataset_sink_mode is False or True') | |||
| args = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) | |||
| step_size = ds_eval.get_dataset_size() | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| repeat_size = cfg.epoch_size | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| load_param_into_net(network, param_dict) | |||
| print("============== Starting Testing ==============") | |||
| acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== {} ==============".format(acc)) | |||
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| ######################## eval lenet example ######################## | |||
| eval lenet according to model file: | |||
| python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt | |||
| """ | |||
| import os | |||
| import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from mindspore.train.quant import quant | |||
| from src.dataset import create_dataset | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | |||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||
| choices=['Ascend', 'GPU', 'CPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||
| help='path where the dataset is saved') | |||
| parser.add_argument('--ckpt_path', type=str, default="", | |||
| help='if mode is test, must provide path where the trained ckpt file') | |||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||
| help='dataset_sink_mode is False or True') | |||
| args = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) | |||
| step_size = ds_eval.get_dataset_size() | |||
| # define funsion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # convert funsion netwrok to aware quantizaiton network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| # load aware quantizaiton network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| load_param_into_net(network, param_dict) | |||
| print("============== Starting Testing ==============") | |||
| acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== {} ==============".format(acc)) | |||
| @@ -0,0 +1,31 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| network config setting, will be used in train.py | |||
| """ | |||
| from easydict import EasyDict as edict | |||
| mnist_cfg = edict({ | |||
| 'num_classes': 10, | |||
| 'lr': 0.01, | |||
| 'momentum': 0.9, | |||
| 'epoch_size': 10, | |||
| 'batch_size': 64, | |||
| 'buffer_size': 1000, | |||
| 'image_height': 32, | |||
| 'image_width': 32, | |||
| 'keep_checkpoint_max': 10, | |||
| }) | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Produce the dataset | |||
| """ | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.transforms.vision.c_transforms as CV | |||
| import mindspore.dataset.transforms.c_transforms as C | |||
| from mindspore.dataset.transforms.vision import Inter | |||
| from mindspore.common import dtype as mstype | |||
| def create_dataset(data_path, batch_size=32, repeat_size=1, | |||
| num_parallel_workers=1): | |||
| """ | |||
| create dataset for train or test | |||
| """ | |||
| # define dataset | |||
| mnist_ds = ds.MnistDataset(data_path) | |||
| resize_height, resize_width = 32, 32 | |||
| rescale = 1.0 / 255.0 | |||
| shift = 0.0 | |||
| rescale_nml = 1 / 0.3081 | |||
| shift_nml = -1 * 0.1307 / 0.3081 | |||
| # define map operations | |||
| resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode | |||
| rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) | |||
| rescale_op = CV.Rescale(rescale, shift) | |||
| hwc2chw_op = CV.HWC2CHW() | |||
| type_cast_op = C.TypeCast(mstype.int32) | |||
| # apply map operations on images | |||
| mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers) | |||
| mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers) | |||
| mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers) | |||
| mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers) | |||
| mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers) | |||
| # apply DatasetOps | |||
| buffer_size = 10000 | |||
| mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script | |||
| mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) | |||
| mnist_ds = mnist_ds.repeat(repeat_size) | |||
| return mnist_ds | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LeNet.""" | |||
| import mindspore.nn as nn | |||
| class LeNet5(nn.Cell): | |||
| """ | |||
| Lenet network | |||
| Args: | |||
| num_class (int): Num classes. Default: 10. | |||
| Returns: | |||
| Tensor, output tensor | |||
| Examples: | |||
| >>> LeNet(num_class=10) | |||
| """ | |||
| def __init__(self, num_class=10, channel=1): | |||
| super(LeNet5, self).__init__() | |||
| self.num_class = num_class | |||
| self.conv1 = nn.Conv2d(channel, 6, 5) | |||
| self.conv2 = nn.Conv2d(6, 16, 5) | |||
| self.fc1 = nn.Dense(16 * 5 * 5, 120) | |||
| self.fc2 = nn.Dense(120, 84) | |||
| self.fc3 = nn.Dense(84, self.num_class) | |||
| self.relu = nn.ReLU() | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.flatten = nn.Flatten() | |||
| def construct(self, x): | |||
| x = self.conv1(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.conv2(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.flatten(x) | |||
| x = self.fc1(x) | |||
| x = self.relu(x) | |||
| x = self.fc2(x) | |||
| x = self.relu(x) | |||
| x = self.fc3(x) | |||
| return x | |||
| @@ -0,0 +1,57 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LeNet.""" | |||
| import mindspore.nn as nn | |||
| class LeNet5(nn.Cell): | |||
| """ | |||
| Lenet network | |||
| Args: | |||
| num_class (int): Num classes. Default: 10. | |||
| Returns: | |||
| Tensor, output tensor | |||
| Examples: | |||
| >>> LeNet(num_class=10) | |||
| """ | |||
| def __init__(self, num_class=10, channel=1): | |||
| super(LeNet5, self).__init__() | |||
| self.num_class = num_class | |||
| # change `nn.Conv2d` to `nn.Conv2dBnAct` | |||
| self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu') | |||
| self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu') | |||
| # change `nn.Dense` to `nn.DenseBnAct` | |||
| self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') | |||
| self.fc2 = nn.DenseBnAct(120, 84, activation='relu') | |||
| self.fc3 = nn.DenseBnAct(84, self.num_class) | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.flatten = nn.Flatten() | |||
| def construct(self, x): | |||
| x = self.conv1(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.conv2(x) | |||
| x = self.max_pool2d(x) | |||
| x = self.flatten(x) | |||
| x = self.fc1(x) | |||
| x = self.fc2(x) | |||
| x = self.fc3(x) | |||
| return x | |||
| @@ -0,0 +1,61 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| ######################## train lenet example ######################## | |||
| train lenet and get network model files(.ckpt) : | |||
| python train.py --data_path /YourDataPath | |||
| """ | |||
| import os | |||
| import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from src.dataset import create_dataset | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | |||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||
| choices=['Ascend', 'GPU', 'CPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||
| help='path where the dataset is saved') | |||
| parser.add_argument('--ckpt_path', type=str, default="", | |||
| help='if mode is test, must provide path where the trained ckpt file') | |||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||
| help='dataset_sink_mode is False or True') | |||
| args = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) | |||
| step_size = ds_train.get_dataset_size() | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], | |||
| dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== End Training ==============") | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| ######################## train lenet example ######################## | |||
| train lenet and get network model files(.ckpt) : | |||
| python train.py --data_path /YourDataPath | |||
| """ | |||
| import os | |||
| import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from mindspore.train.quant import quant | |||
| from src.dataset import create_dataset | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet_fusion import LeNet5 as LeNet5Fusion | |||
| parser = argparse.ArgumentParser(description='MindSpore MNIST Example') | |||
| parser.add_argument('--device_target', type=str, default="Ascend", | |||
| choices=['Ascend', 'GPU', 'CPU'], | |||
| help='device where the code will be implemented (default: Ascend)') | |||
| parser.add_argument('--data_path', type=str, default="./MNIST_Data", | |||
| help='path where the dataset is saved') | |||
| parser.add_argument('--ckpt_path', type=str, default="", | |||
| help='if mode is test, must provide path where the trained ckpt file') | |||
| parser.add_argument('--dataset_sink_mode', type=bool, default=True, | |||
| help='dataset_sink_mode is False or True') | |||
| args = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) | |||
| step_size = ds_train.get_dataset_size() | |||
| # define funsion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # load aware quantizaiton network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| load_param_into_net(network, param_dict) | |||
| # convert funsion netwrok to aware quantizaiton network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], | |||
| dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== End Training ==============") | |||