From: @caojian05
Reviewed-by: @wuxuejian, @oacjiewen
Signed-off-by: @wuxuejian
Tags: v1.2.0-rc1
@@ -40,6 +40,14 @@ Dataset used can refer to paper.
- Data format: RGB images.
- Note: Data will be processed in src/dataset.py
Dataset used: [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar.html)
- Dataset size: 175M, 60,000 32\*32 color images in 10 classes
- Train: 146M, 50,000 images
- Test: 29M, 10,000 images
- Data format: binary files
- Note: Data will be processed in src/dataset.py
# [Features](#contents)
## [Mixed Precision(Ascend)](#contents)
@@ -67,8 +75,13 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
└─Inception-v3
  ├─README.md
  ├─scripts
    ├─run_standalone_train_cpu.sh    # launch standalone training with CPU platform
    ├─run_standalone_train_gpu.sh    # launch standalone training with GPU platform (1p)
    ├─run_distribute_train_gpu.sh    # launch distributed training with GPU platform (8p)
    ├─run_standalone_train.sh        # launch standalone training with Ascend platform (1p)
    ├─run_distribute_train.sh        # launch distributed training with Ascend platform (8p)
    ├─run_eval_cpu.sh                # launch evaluation with CPU platform
    ├─run_eval_gpu.sh                # launch evaluation with GPU platform
    └─run_eval.sh                    # launch evaluation with Ascend platform
  ├─src
    ├─config.py                      # parameter configuration
@@ -93,6 +106,8 @@ Major parameters in train.py and config.py are:
'batch_size'                 # input batch size
'epoch_size'                 # total number of epochs
'num_classes'                # number of dataset classes
'ds_type'                    # dataset type, such as: imagenet, cifar10
'ds_sink_mode'               # whether to enable dataset sink mode
'smooth_factor'              # label smoothing factor
'aux_factor'                 # loss factor of the aux logits
'lr_init'                    # initial learning rate
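The `smooth_factor` parameter above controls label smoothing. A minimal sketch of one common formulation (the helper name and pure-Python form are illustrative; the exact variant used in `src/loss.py` may differ):

```python
def smooth_one_hot(label, num_classes, smooth_factor):
    """One common label-smoothing formulation: the true class keeps
    1 - smooth_factor, and the remaining probability mass is spread
    evenly over the other classes, so the vector still sums to 1."""
    off = smooth_factor / (num_classes - 1)
    on = 1.0 - smooth_factor
    return [on if i == label else off for i in range(num_classes)]

# With smooth_factor=0.1 and 10 classes (the CIFAR-10 setting):
target = smooth_one_hot(3, 10, 0.1)  # class 3 gets 0.9, others get 0.1/9 each
```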
@@ -127,6 +142,13 @@ sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
```
- CPU:
```shell
# standalone training
sh scripts/run_standalone_train_cpu.sh DATA_PATH
```
> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained from [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV3, it is better to export the environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend the HCCL connection-checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could time out, since compilation time increases with model size.
>
> The script binds processor cores according to `device_num` and the total number of processor cores. If you do not want this behavior, remove the `taskset` operations in `scripts/run_distribute_train.sh`.
@@ -137,6 +159,7 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
# training example
python:
    Ascend: python train.py --dataset_path DATA_PATH --platform Ascend
    CPU: python train.py --dataset_path DATA_PATH --platform CPU
shell:
    Ascend:
@@ -144,12 +167,17 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
    sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
    # standalone training example
    sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
    CPU:
    sh scripts/run_standalone_train_cpu.sh DATA_PATH
```
### Result
Training results will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./log.txt`, as follows.
#### Ascend
```python
epoch: 0 step: 1251, loss is 5.7787247
epoch time: 360760.985 ms, per step time: 288.378 ms
@@ -157,6 +185,18 @@ epoch: 1 step: 1251, loss is 4.392868
epoch time: 160917.911 ms, per step time: 128.631 ms
```
#### CPU
```bash
epoch: 1 step: 390, loss is 2.7072601
epoch time: 6334572.124 ms, per step time: 16242.493 ms
epoch: 2 step: 390, loss is 2.5908582
epoch time: 6217897.644 ms, per step time: 15943.327 ms
epoch: 3 step: 390, loss is 2.5612416
epoch time: 6358482.104 ms, per step time: 16303.800 ms
...
```
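As a sanity check on the CPU log above, the per-step time is simply the epoch time divided by the step count:

```python
# per-step time (ms) = epoch time (ms) / steps per epoch, from the epoch-1 line above
epoch_time_ms = 6334572.124
steps = 390
per_step_ms = round(epoch_time_ms / steps, 3)
print(per_step_ms)  # 16242.493, matching the logged value
```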
## [Eval process](#contents)
### Usage
@@ -169,15 +209,23 @@ You can start training using python or shell scripts. The usage of shell scripts
sh scripts/run_eval.sh DEVICE_ID DATA_PATH PATH_CHECKPOINT
```
- CPU:
```shell
sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
```
### Launch
```python
# eval example
python:
    Ascend: python eval.py --dataset_path DATA_PATH --checkpoint PATH_CHECKPOINT --platform Ascend
    CPU: python eval.py --dataset_path DATA_PATH --checkpoint PATH_CHECKPOINT --platform CPU
shell:
    Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_PATH PATH_CHECKPOINT
    CPU: sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
```
> The checkpoint can be produced during the training process.
@@ -236,4 +284,4 @@ In dataset.py, we set the seed inside the "create_dataset" function. We also use r
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -51,6 +51,14 @@ The overall network architecture of InceptionV3 is as follows:
- Data format: RGB
- Note: Data will be processed in src/dataset.py.
Dataset used: [CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)
- Dataset size: 175M, 60,000 32*32 color images in 10 classes
- Train: 146M, 50,000 images
- Test: 29M, 10,000 images
- Data format: binary files
- Note: Data will be processed in src/dataset.py.
# Features
## Mixed Precision (Ascend)
@@ -78,9 +86,14 @@ The overall network architecture of InceptionV3 is as follows:
└─Inception-v3
  ├─README.md
  ├─scripts
    ├─run_standalone_train_cpu.sh    # launch standalone training with CPU
    ├─run_standalone_train_gpu.sh    # launch standalone training with GPU (1p)
    ├─run_distribute_train_gpu.sh    # launch distributed training with GPU (8p)
    ├─run_standalone_train.sh        # launch standalone training with Ascend (1p)
    ├─run_distribute_train.sh        # launch distributed training with Ascend (8p)
    ├─run_eval_cpu.sh                # launch evaluation with CPU
    ├─run_eval_gpu.sh                # launch evaluation with GPU
    └─run_eval.sh                    # launch evaluation with Ascend
  ├─src
    ├─config.py                      # parameter configuration
    ├─dataset.py                     # data preprocessing
@@ -106,6 +119,8 @@ The major parameters in train.py and config.py are as follows:
'batch_size'                 # batch size of the input tensor
'epoch_size'                 # total number of epochs
'num_classes'                # number of dataset classes
'ds_type'                    # dataset type, such as: imagenet, cifar10
'ds_sink_mode'               # whether to enable dataset sink mode
'smooth_factor'              # label smoothing factor
'aux_factor'                 # loss factor of the aux logits
'lr_init'                    # initial learning rate
@@ -149,6 +164,7 @@ The major parameters in train.py and config.py are as follows:
# training example
python:
    Ascend: python train.py --dataset_path /dataset/train --platform Ascend
    CPU: python train.py --dataset_path DATA_PATH --platform CPU
shell:
    Ascend:
@@ -156,12 +172,17 @@ The major parameters in train.py and config.py are as follows:
    sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
    # standalone training
    sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
    CPU:
    sh scripts/run_standalone_train_cpu.sh DATA_PATH
```
### Result
Training results are saved in the example path. Checkpoints are saved in `checkpoint` by default, and the training log is redirected to `./log.txt`, as follows:
#### Ascend
```log
epoch:0 step:1251, loss is 5.7787247
Epoch time:360760.985, per step time:288.378
@@ -169,6 +190,18 @@ epoch:1 step:1251, loss is 4.392868
Epoch time:160917.911, per step time:128.631
```
#### CPU
```bash
epoch: 1 step: 390, loss is 2.7072601
epoch time: 6334572.124 ms, per step time: 16242.493 ms
epoch: 2 step: 390, loss is 2.5908582
epoch time: 6217897.644 ms, per step time: 15943.327 ms
epoch: 3 step: 390, loss is 2.5612416
epoch time: 6358482.104 ms, per step time: 16303.800 ms
...
```
## Evaluation process
### Usage
@@ -181,15 +214,23 @@ Epoch time:160917.911, per step time:128.631
sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```
- CPU:
```shell
sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
```
### Launch
```python
# eval example
python:
    Ascend: python eval.py --dataset_path DATA_DIR --checkpoint PATH_CHECKPOINT --platform Ascend
    CPU: python eval.py --dataset_path DATA_PATH --checkpoint PATH_CHECKPOINT --platform CPU
shell:
    Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
    CPU: sh scripts/run_eval_cpu.sh DATA_PATH PATH_CHECKPOINT
```
> Checkpoints can be generated during the training process.
@@ -21,33 +21,48 @@ from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_gpu, config_ascend, config_cpu
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.inception_v3 import InceptionV3
from src.loss import CrossEntropy_Val

CFG_DICT = {
    "Ascend": config_ascend,
    "GPU": config_gpu,
    "CPU": config_cpu,
}

DS_DICT = {
    "imagenet": create_dataset_imagenet,
    "cifar10": create_dataset_cifar10,
}
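The two lookup tables above form a simple double dispatch: the platform selects a config, and the config's `ds_type` selects a dataset builder. A self-contained sketch with stand-in dicts (the real entries are MindSpore config objects and dataset factories, not strings):

```python
# Stand-ins for config_ascend/config_gpu/config_cpu and the dataset factories
CFG_DICT = {
    "Ascend": {"ds_type": "imagenet"},
    "GPU": {"ds_type": "imagenet"},
    "CPU": {"ds_type": "cifar10"},
}
DS_DICT = {
    "imagenet": lambda: "create_dataset_imagenet",
    "cifar10": lambda: "create_dataset_cifar10",
}

def pick(platform):
    cfg = CFG_DICT[platform]          # platform -> config
    return DS_DICT[cfg["ds_type"]]()  # config's ds_type -> dataset builder

print(pick("CPU"))  # create_dataset_cifar10
```

This keeps `eval.py` free of platform-specific `if/elif` chains: adding a new platform only means adding one entry to each table.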
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
    args_opt = parser.parse_args()

    if args_opt.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    cfg = CFG_DICT[args_opt.platform]
    create_dataset = DS_DICT[cfg.ds_type]
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
    net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
    ckpt = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, ckpt)
    net.set_train(False)
    cfg.rank = 0
    cfg.group_size = 1
    dataset = create_dataset(args_opt.dataset_path, False, cfg)
    loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    metrics = model.eval(dataset, dataset_sink_mode=cfg.ds_sink_mode)
    print("metric: ", metrics)
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
PATH_CHECKPOINT=$2
python ./eval.py --platform 'CPU' --dataset_path $DATA_DIR --checkpoint $PATH_CHECKPOINT > eval.log 2>&1 &
@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
python ./train.py --platform 'CPU' --dataset_path $DATA_DIR > train.log 2>&1 &
@@ -26,6 +26,8 @@ config_gpu = edict({
    'batch_size': 128,
    'epoch_size': 250,
    'num_classes': 1000,
    'ds_type': 'imagenet',
    'ds_sink_mode': True,
    'smooth_factor': 0.1,
    'aux_factor': 0.2,
    'lr_init': 0.00004,
@@ -51,6 +53,8 @@ config_ascend = edict({
    'batch_size': 128,
    'epoch_size': 250,
    'num_classes': 1000,
    'ds_type': 'imagenet',
    'ds_sink_mode': True,
    'smooth_factor': 0.1,
    'aux_factor': 0.2,
    'lr_init': 0.00004,
@@ -67,3 +71,30 @@ config_ascend = edict({
    'has_bias': False,
    'amp_level': 'O3'
})
config_cpu = edict({
    'random_seed': 1,
    'work_nums': 8,
    'decay_method': 'cosine',
    'loss_scale': 1024,
    'batch_size': 128,
    'epoch_size': 120,
    'num_classes': 10,
    'ds_type': 'cifar10',
    'ds_sink_mode': False,
    'smooth_factor': 0.1,
    'aux_factor': 0.2,
    'lr_init': 0.00004,
    'lr_max': 0.1,
    'lr_end': 0.000004,
    'warmup_epochs': 1,
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'opt_eps': 1.0,
    'keep_checkpoint_max': 10,
    'ckpt_path': './',
    'is_save_on_master': 0,
    'dropout_keep_prob': 0.8,
    'has_bias': False,
    'amp_level': 'O0',
})
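`lr_init`, `lr_max`, `lr_end`, and `warmup_epochs` above feed the learning-rate schedule built in `src/lr_generator.py` (`decay_method: 'cosine'`). A hedged sketch of one common warmup-plus-cosine scheme using these values; the actual `get_lr` implementation may differ in detail:

```python
import math

def cosine_lr(step, steps_per_epoch, lr_init=0.00004, lr_max=0.1,
              lr_end=0.000004, warmup_epochs=1, total_epochs=120):
    """Linear warmup from lr_init to lr_max over warmup_epochs,
    then cosine decay from lr_max down toward lr_end."""
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    if step < warmup_steps:
        return lr_init + (lr_max - lr_init) * step / warmup_steps
    decay = (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_end + 0.5 * (lr_max - lr_end) * (1 + math.cos(math.pi * decay))

# LR rises through warmup, peaks at lr_max, then decays toward lr_end
print(cosine_lr(0, 390), cosine_lr(390, 390), cosine_lr(120 * 390 - 1, 390))
```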
@@ -15,32 +15,32 @@
"""
Data operations, will be used in train.py and eval.py
"""
import os

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C


def create_dataset_imagenet(dataset_path, do_train, cfg, repeat_num=1):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        cfg (dict): the config for creating dataset.
        repeat_num(int): the repeat times of dataset. Default: 1.

    Returns:
        dataset
    """
    if cfg.group_size == 1:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
    else:
        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
                                         num_shards=cfg.group_size, shard_id=cfg.rank)
    # define map operations
    if do_train:
        trans = [
@@ -67,3 +67,44 @@ def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1):
    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set
def create_dataset_cifar10(dataset_path, do_train, cfg, repeat_num=1):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        cfg (dict): the config for creating dataset.
        repeat_num(int): the repeat times of dataset. Default: 1.

    Returns:
        dataset
    """
    dataset_path = os.path.join(dataset_path, "cifar-10-batches-bin" if do_train else "cifar-10-verify-bin")
    if cfg.group_size == 1:
        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
    else:
        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
                                     num_shards=cfg.group_size, shard_id=cfg.rank)
    # define map operations: random crop/flip only for training
    trans = []
    if do_train:
        trans.append(C.RandomCrop((32, 32), (4, 4, 4, 4)))
        trans.append(C.RandomHorizontalFlip(prob=0.5))
    trans.append(C.Resize((299, 299)))
    trans.append(C.Rescale(1.0 / 255.0, 0.0))
    trans.append(C.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
    trans.append(C.HWC2CHW())
    type_cast_op = C2.TypeCast(mstype.int32)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
    # apply batch operations
    data_set = data_set.batch(cfg.batch_size, drop_remainder=do_train)
    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set
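The `Rescale` and `Normalize` steps in `create_dataset_cifar10` map a raw 0-255 pixel to a standardized value: multiply by 1/255, then subtract the per-channel mean and divide by the per-channel std. A minimal pure-Python sketch of that arithmetic (illustrative only; the real pipeline operates on whole image tensors):

```python
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

def normalize_pixel(rgb):
    """Apply Rescale(1/255, 0) then Normalize(mean, std) to one RGB pixel."""
    scaled = [c / 255.0 for c in rgb]   # Rescale step: 0-255 -> 0-1
    return [(s - m) / d                 # Normalize step: zero-mean, unit-std
            for s, m, d in zip(scaled, CIFAR10_MEAN, CIFAR10_STD)]

# A mid-gray pixel lands near zero in every channel after normalization
print(normalize_pixel((125, 123, 114)))
```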
@@ -29,14 +29,24 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common.initializer import XavierUniform, initializer
from mindspore.common import set_seed
from src.config import config_gpu, config_ascend, config_cpu
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.inception_v3 import InceptionV3
from src.lr_generator import get_lr
from src.loss import CrossEntropy

set_seed(1)

CFG_DICT = {
    "Ascend": config_ascend,
    "GPU": config_gpu,
    "CPU": config_cpu,
}

DS_DICT = {
    "imagenet": create_dataset_imagenet,
    "cifar10": create_dataset_cifar10,
}
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification training')
@@ -44,13 +54,16 @@ if __name__ == '__main__':
    parser.add_argument('--resume', type=str, default='', help='resume training with an existing checkpoint')
    parser.add_argument('--is_distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
    args_opt = parser.parse_args()

    cfg = CFG_DICT[args_opt.platform]
    create_dataset = DS_DICT[cfg.ds_type]
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args_opt.is_distributed:
        init()
@@ -64,7 +77,7 @@ if __name__ == '__main__':
        cfg.group_size = 1

    # dataloader
    dataset = create_dataset(args_opt.dataset_path, True, cfg)
    batches_per_epoch = dataset.get_dataset_size()

    # network
@@ -120,8 +133,8 @@ if __name__ == '__main__':
    if args_opt.is_distributed & cfg.is_save_on_master:
        if cfg.rank == 0:
            callbacks.append(ckpoint_cb)
        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=cfg.ds_sink_mode)
    else:
        callbacks.append(ckpoint_cb)
        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=cfg.ds_sink_mode)
    print("train success")