From acf0b05afcd598b46d18d1f82b6ee2de63247829 Mon Sep 17 00:00:00 2001 From: MapleGrove Date: Sat, 27 Mar 2021 22:07:56 +0800 Subject: [PATCH] add simclr --- model_zoo/official/cv/simclr/README.md | 220 ++++++++ model_zoo/official/cv/simclr/export.py | 62 +++ model_zoo/official/cv/simclr/linear_eval.py | 215 ++++++++ .../simclr/scripts/run_distribution_ascend.sh | 64 +++ .../scripts/run_standalone_eval_ascend.sh | 37 ++ .../scripts/run_standalone_train_ascend.sh | 31 ++ model_zoo/official/cv/simclr/src/__init__.py | 0 model_zoo/official/cv/simclr/src/dataset.py | 94 ++++ .../official/cv/simclr/src/lr_generator.py | 198 +++++++ model_zoo/official/cv/simclr/src/nt_xent.py | 91 ++++ model_zoo/official/cv/simclr/src/optimizer.py | 52 ++ model_zoo/official/cv/simclr/src/reporter.py | 135 +++++ model_zoo/official/cv/simclr/src/resnet.py | 485 ++++++++++++++++++ .../official/cv/simclr/src/simclr_model.py | 53 ++ model_zoo/official/cv/simclr/train.py | 164 ++++++ 15 files changed, 1901 insertions(+) create mode 100644 model_zoo/official/cv/simclr/README.md create mode 100644 model_zoo/official/cv/simclr/export.py create mode 100644 model_zoo/official/cv/simclr/linear_eval.py create mode 100644 model_zoo/official/cv/simclr/scripts/run_distribution_ascend.sh create mode 100644 model_zoo/official/cv/simclr/scripts/run_standalone_eval_ascend.sh create mode 100644 model_zoo/official/cv/simclr/scripts/run_standalone_train_ascend.sh create mode 100644 model_zoo/official/cv/simclr/src/__init__.py create mode 100644 model_zoo/official/cv/simclr/src/dataset.py create mode 100644 model_zoo/official/cv/simclr/src/lr_generator.py create mode 100644 model_zoo/official/cv/simclr/src/nt_xent.py create mode 100644 model_zoo/official/cv/simclr/src/optimizer.py create mode 100644 model_zoo/official/cv/simclr/src/reporter.py create mode 100644 model_zoo/official/cv/simclr/src/resnet.py create mode 100644 model_zoo/official/cv/simclr/src/simclr_model.py create mode 100644 model_zoo/official/cv/simclr/train.py diff --git a/model_zoo/official/cv/simclr/README.md b/model_zoo/official/cv/simclr/README.md new file mode 100644 index 0000000000..678f755380 --- /dev/null +++ b/model_zoo/official/cv/simclr/README.md @@ -0,0 +1,220 @@ +# Contents + +- [SimCLR Description](#simclr-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Environment Requirements](#environment-requirements) +- [Quick Start](#quick-start) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) + - [Script Parameters](#script-parameters) + - [Training Process](#training-process) + - [Training](#training) + - [Evaluation Process](#evaluation-process) + - [Evaluation](#evaluation) +- [Model Description](#model-description) + - [Performance](#performance) + - [Evaluation Performance](#evaluation-performance) +- [ModelZoo Homepage](#modelzoo-homepage) + +## [SimCLR Description](#contents) + +SimCLR: a simple framework for contrastive learning of visual representations. +[Paper](https://arxiv.org/pdf/2002.05709.pdf): Ting Chen and Simon Kornblith and Mohammad Norouzi and Geoffrey Hinton. A Simple Framework for Contrastive Learning of Visual Representations. *arXiv preprint arXiv:2002.05709*. 2020. + +## [Model Architecture](#contents) + +SimCLR learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss in the latent space. 
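+
+As a rough illustration (a minimal NumPy sketch; the function name nt_xent and the inputs z1/z2 are illustrative only, not the MindSpore implementation shipped in src/nt_xent.py), the contrastive objective over the projections of two augmented views can be summarized as:
+
+```python
+import numpy as np
+
+def nt_xent(z1, z2, temperature=0.5):
+    """Toy NT-Xent loss for two batches of projections with shape (N, D)."""
+    z = np.concatenate([z1, z2], axis=0)
+    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # L2-normalize each projection
+    sim = z @ z.T / temperature                       # pairwise cosine similarities
+    np.fill_diagonal(sim, -1e9)                       # mask self-similarity
+    n = z1.shape[0]
+    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])  # index of each positive pair
+    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
+    return -log_prob[np.arange(2 * n), pos].mean()
+```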
+
+This framework comprises the following four major components: a stochastic data augmentation module, a neural network base encoder, a small neural network projection head and a contrastive loss function.
+
+## [Dataset](#contents)
+
+In the following sections, we will introduce how to run the scripts using the related dataset below.
+
+Dataset used: [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar.html)
+
+- Dataset size: 175M, 60,000 32*32 color images in 10 classes
+  - Train: 146M, 50,000 images
+  - Test: 29.3M, 10,000 images
+- Data format: binary files
+  - Note: Data will be processed in dataset.py
+- Download the dataset; the directory structure is as follows:
+
+```bash
+├─cifar-10-batches-bin
+│
+└─cifar-10-verify-bin
+```
+
+## [Environment Requirements](#contents)
+
+- Hardware (Ascend)
+  - Prepare a hardware environment with an Ascend processor.
+- Framework
+  - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+  - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+  - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
+
+## [Quick Start](#contents)
+
+After installing MindSpore via the official website, you can start training and evaluation as follows:
+
+```bash
+# enter the script dir, train SimCLR
+sh run_standalone_train_ascend.sh [cifar10] [TRAIN_DATASET_PATH] [DEVICE_ID]
+# or
+sh run_distribution_ascend.sh [DEVICENUM] [RANK_TABLE_FILE] [cifar10] [TRAIN_DATASET_PATH]
+# enter the script dir, evaluate SimCLR
+sh run_standalone_eval_ascend.sh [cifar10] [DEVICE_ID] [SIMCLR_MODEL_PATH] [TRAIN_DATASET_PATH] [EVAL_DATASET_PATH]
+```
+
+## [Script Description](#contents)
+
+### [Script and Sample Code](#contents)
+
+```bash
+├── cv
+    ├── SimCLR
+        ├── README.md                           // descriptions about SimCLR
+        ├── requirements.txt                    // packages needed
+        ├── scripts
+        │   ├──run_distribution_ascend.sh       // distributed training on Ascend
+        │   ├──run_standalone_train_ascend.sh   // standalone training on Ascend
+        │   ├──run_standalone_eval_ascend.sh    // evaluation on Ascend
+        ├── src
+        │   ├──dataset.py                       // creating dataset
+        │   ├──lr_generator.py                  // generating learning rate
+        │   ├──nt_xent.py                       // contrastive cross entropy loss
+        │   ├──optimizer.py                     // generating optimizer
+        │   ├──reporter.py                      // logging and checkpoint reporter
+        │   ├──resnet.py                        // base encoder network
+        │   ├──simclr_model.py                  // simclr architecture
+        ├── train.py                            // training script
+        ├── linear_eval.py                      // linear evaluation script
+        ├── export.py                           // export model for inference
+```
+
+### [Script Parameters](#contents)
+
+```python
+Major parameters in train.py are as follows:
+--device_target: Device target. Currently only Ascend is supported.
+--run_cloudbrain: Whether it is running on the CloudBrain platform.
+--run_distribute: Run distributed training.
+--device_num: Device number.
+--device_id: Device id, default is 0.
+--dataset_name: Dataset. Currently only cifar10 is supported.
+--train_url: CloudBrain location of training outputs. This parameter needs to be set when running on the CloudBrain platform.
+--data_url: CloudBrain location of data. This parameter needs to be set when running on the CloudBrain platform.
+--train_dataset_path: Dataset path for training the classifier. This parameter needs to be set when running on the host.
+--train_output_path: Location of ckpt and log. This parameter needs to be set when running on the host.
+--batch_size: Batch size, default is 128.
+--epoch_size: Epoch size for training, default is 100.
+--projection_dimension: Projection output dimensionality, default is 128.
+--width_multiplier: Width multiplier for ResNet50, default is 1.
+--temperature: Temperature for the contrastive cross entropy loss.
+--pre_trained_path: Pretrained checkpoint path.
+--pretrain_epoch_size: Epochs already covered by the pretrained checkpoint; training then runs for epoch_size - pretrain_epoch_size epochs.
+--save_checkpoint_epochs: Save a checkpoint every given number of epochs, default is 1.
+--save_graphs: Whether to save graphs, default is False.
+--optimizer: Optimizer. Currently only Adam is supported.
+--weight_decay: Weight decay.
+--warmup_epochs: Warmup epochs.
+
+Major parameters in linear_eval.py are as follows:
+--device_target: Device target. Currently only Ascend is supported.
+--run_cloudbrain: Whether it is running on the CloudBrain platform.
+--run_distribute: Run distributed training.
+--device_num: Device number.
+--device_id: Device id, default is 0.
+--dataset_name: Dataset. Currently only cifar10 is supported.
+--train_url: CloudBrain location of training outputs. This parameter needs to be set when running on the CloudBrain platform.
+--data_url: CloudBrain location of data. This parameter needs to be set when running on the CloudBrain platform.
+--train_dataset_path: Dataset path for training the classifier. This parameter needs to be set when running on the host.
+--eval_dataset_path: Dataset path for evaluating the classifier. This parameter needs to be set when running on the host.
+--train_output_path: Location of ckpt and log. This parameter needs to be set when running on the host.
+--class_num: Number of dataset classes, default is 10 for cifar10.
+--batch_size: Batch size, default is 128.
+--epoch_size: Epoch size for training, default is 100.
+--projection_dimension: Projection output dimensionality, default is 128.
+--width_multiplier: Width multiplier for ResNet50, default is 1.
+--pre_classifier_checkpoint_path: Classifier checkpoint file path.
+--encoder_checkpoint_path: Encoder checkpoint file path.
+--save_checkpoint_epochs: Save a checkpoint every given number of epochs, default is 10.
+--print_iter: Logging interval in steps, default is 100.
+--save_graphs: Whether to save graphs, default is False.
+```
+
+### [Training Process](#contents)
+
+#### Training
+
+- running on Ascend
+
+  ```bash
+  sh run_distribution_ascend.sh [DEVICENUM] [RANK_TABLE_FILE] [cifar10] [TRAIN_DATASET_PATH]
+  ```
+
+  During training, the loss values will be printed as follows:
+
+  ```bash
+  # grep "loss is " log
+  epoch: 1 step: 48, loss is 9.5758915
+  epoch time: 253236.075 ms, per step time: 5275.752 ms
+  epoch: 1 step: 48, loss is 9.363186
+  epoch time: 253739.376 ms, per step time: 5286.237 ms
+  epoch: 1 step: 48, loss is 9.36029
+  epoch time: 253711.625 ms, per step time: 5285.659 ms
+  ...
+  epoch: 100 step: 48, loss is 7.453776
+  epoch time: 12341.851 ms, per step time: 257.122 ms
+  epoch: 100 step: 48, loss is 7.499168
+  epoch time: 12420.060 ms, per step time: 258.751 ms
+  epoch: 100 step: 48, loss is 7.442362
+  epoch time: 12725.863 ms, per step time: 265.122 ms
+  ...
+  ```
+
+  The model checkpoints will be saved in the outputs directory.
+
+### [Evaluation Process](#contents)
+
+#### Evaluation
+
+Before running the command below, please check the checkpoint path used for evaluation.
+
+- running on Ascend
+
+  ```bash
+  sh run_standalone_eval_ascend.sh [cifar10] [DEVICE_ID] [SIMCLR_MODEL_PATH] [TRAIN_DATASET_PATH] [EVAL_DATASET_PATH]
+  ```
+
+  You can view the results through the file "eval_log".
+  The accuracy of the test dataset will be as follows:
+
+  ```bash
+  # grep "Accuracy" eval_log
+  'Accuracy': 0.84505
+  ```
+
+## [Model Description](#contents)
+
+### [Performance](#contents)
+
+#### Evaluation Performance
+
+| Parameters                 | Ascend                                                       |
+| -------------------------- | ------------------------------------------------------------ |
+| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory, 755G             |
+| Uploaded Date              | 03/30/2021 (month/day/year)                                  |
+| MindSpore Version          | 1.1.1                                                        |
+| Dataset                    | CIFAR-10                                                     |
+| Training Parameters        | epoch=100, batch_size=128, device_num=8                      |
+| Optimizer                  | Adam                                                         |
+| Loss Function              | NT-Xent Loss                                                 |
+| Linear Eval Accuracy       | 84.505%                                                      |
+| Total Time                 | 25m04s                                                       |
+| Scripts                    | [SimCLR Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/simclr) |
+
+## [Description of Random Situation](#contents)
+
+We set the seed inside dataset.py. A random seed is also set in train.py.
+
+## [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/official/cv/simclr/export.py b/model_zoo/official/cv/simclr/export.py
new file mode 100644
index 0000000000..268b7e5542
--- /dev/null
+++ b/model_zoo/official/cv/simclr/export.py
@@ -0,0 +1,62 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +""" +##############export checkpoint file into air, onnx, mindir models################# +python export.py +""" +import argparse +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export + +from src.simclr_model import SimCLR +from src.resnet import resnet50 as resnet + +parser = argparse.ArgumentParser(description='SimCLR') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=128, help="batch size") +parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10'], + help='Dataset, Currently only cifar10 is supported.') +parser.add_argument('--device_target', type=str, default="Ascend", + choices=['Ascend'], + help='Device target, Currently only Ascend is supported.') +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--file_name", type=str, default="simclr", help="output file name.") +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) +if args_opt.device_target == "Ascend": + context.set_context(device_id=args_opt.device_id) + +if __name__ == '__main__': + if args_opt.dataset_name == 'cifar10': + width_multiplier = 1 + cifar_stem = True + projection_dimension = 128 + image_height = 32 + image_width = 32 + else: + raise ValueError("dataset is not support.") + + base_net = resnet(1, width_multiplier=width_multiplier, cifar_stem=cifar_stem) + net = SimCLR(base_net, projection_dimension, base_net.end_point.in_channels) + + param_dict = load_checkpoint(args_opt.ckpt_file) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.zeros([args_opt.batch_size, 3, image_height, image_width]), ms.float32) + export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format) diff --git a/model_zoo/official/cv/simclr/linear_eval.py b/model_zoo/official/cv/simclr/linear_eval.py new file mode 100644 index 0000000000..0a78d3a86e --- /dev/null +++ b/model_zoo/official/cv/simclr/linear_eval.py @@ -0,0 +1,215 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +######################## eval SimCLR example ######################## +eval SimCLR according to model file: +python eval.py --encoder_checkpoint_path Your.ckpt --train_dataset_path /YourDataPath1 + --eval_dataset_path /YourDataPath2 +""" +import ast +import os +import argparse +import numpy as np +import mindspore.common.dtype as mstype +from mindspore import nn +from mindspore import ops +from mindspore import context +from mindspore.common.initializer import TruncatedNormal +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed +from mindspore.context import ParallelMode +from mindspore.communication.management import init, get_rank +from src.dataset import create_dataset +from src.simclr_model import SimCLR +from src.resnet import resnet50 as resnet +from src.reporter import Reporter +from src.optimizer import get_eval_optimizer as get_optimizer + + + +parser = argparse.ArgumentParser(description='Linear Evaluation Protocol') +parser.add_argument('--device_target', type=str, default='Ascend', + help='Device target, Currently only Ascend is supported.') +parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Running distributed evaluation.') +parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True, + help='Whether it is running on CloudBrain platform.') +parser.add_argument('--device_num', type=int, default=1, help='Device num.') +parser.add_argument('--device_id', type=int, default=0, help='device id, default is 0.') +parser.add_argument('--dataset_name', type=str, default='cifar10', help='Dataset, Currently only cifar10 is supported.') +parser.add_argument('--train_url', default=None, help='Cloudbrain Location of training outputs.\ + This parameter needs to be set when running on the cloud brain platform.') +parser.add_argument('--data_url', default=None, help='Cloudbrain Location of data.\ + This parameter needs to be set when running on the cloud brain platform.') +parser.add_argument('--train_dataset_path', type=str, default='./cifar/train',\ + help='Dataset path for training classifier.\ + This parameter needs to be set when running on the host.') +parser.add_argument('--eval_dataset_path', type=str, default='./cifar/eval',\ + help='Dataset path for evaluating classifier.\ + This parameter needs to be set when running on the host.') +parser.add_argument('--train_output_path', type=str, default='./outputs', help='Location of ckpt and log.\ + This parameter needs to be set when running on the host.') +parser.add_argument('--class_num', type=int, default=10, help='dataset classification number, default is 10.') +parser.add_argument('--batch_size', type=int, default=128, help='batch_size for training classifier, default is 128.') +parser.add_argument('--epoch_size', type=int, default=100, help='epoch size for training classifier, default is 100.') +parser.add_argument('--projection_dimension', type=int, default=128, + help='Projection output dimensionality, default is 128.') +parser.add_argument('--width_multiplier', type=int, default=1, help='width_multiplier=4,resnet50x4') +parser.add_argument('--pre_classifier_checkpoint_path', type=str, default=None, help='Classifier Checkpoint file path.') +parser.add_argument('--encoder_checkpoint_path', type=str, help='Encoder Checkpoint file path.') +parser.add_argument('--save_checkpoint_epochs', type=int, default=10, help='Save checkpoint epochs, 
default is 10.') +parser.add_argument('--print_iter', type=int, default=100, help='log print iter, default is 100.') +parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, + help='whether save graphs, default is False.') +parser.add_argument('--use_norm', type=ast.literal_eval, default=False, help='Dataset normalize.') + +args = parser.parse_args() +set_seed(1) +local_data_url = './cache/data' +local_train_url = './cache/train' +_local_train_url = local_train_url + +if args.device_target != "Ascend": + raise ValueError("Unsupported device target.") +if args.run_distribute: + device_id = os.getenv("DEVICE_ID", default=None) + if device_id is None: + raise ValueError("Unsupported device id.") + args.device_id = int(device_id) + rank_size = os.getenv("RANK_SIZE", default=None) + if rank_size is None: + raise ValueError("Unsupported rank size.") + if args.device_num > int(rank_size) or args.device_num == 1: + args.device_num = int(rank_size) + context.set_context(device_id=args.device_id) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=args.save_graphs) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True, device_num=args.device_num) + init() + args.rank = get_rank() + local_data_url = os.path.join(local_data_url, str(args.device_id)) + local_train_url = os.path.join(local_train_url, str(args.device_id)) + args.train_output_path = os.path.join(args.train_output_path, str(args.device_id)) +else: + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, + save_graphs=args.save_graphs, device_id=args.device_id) + args.rank = 0 + args.device_num = 1 + +if args.run_cloudbrain: + import moxing as mox + args.train_dataset_path = os.path.join(local_data_url, "train") + args.eval_dataset_path = os.path.join(local_data_url, "val") + args.train_output_path = local_train_url + mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url) + +class LogisticRegression(nn.Cell): + """ + Logistic regression + """ + def __init__(self, n_features, n_classes): + super(LogisticRegression, self).__init__() + self.model = nn.Dense(n_features, n_classes, TruncatedNormal(0.02), TruncatedNormal(0.02)) + + def construct(self, x): + x = self.model(x) + return x + +class Linear_Eval(): + """ + Linear classifier + """ + def __init__(self, net, loss): + super(Linear_Eval, self).__init__() + self.net = net + self.softmax = nn.Softmax() + self.loss = loss + def __call__(self, x, y): + x = self.net(x) + loss = self.loss(x, y) + x = self.softmax(x) + predicts = ops.Argmax(output_type=mstype.int32)(x) + acc = np.sum(predicts.asnumpy() == y.asnumpy())/len(y.asnumpy()) + return loss.asnumpy(), acc + +class Linear_Train(nn.Cell): + """ + Train linear classifier + """ + def __init__(self, net, loss, opt): + super(Linear_Train, self).__init__() + self.netwithloss = nn.WithLossCell(net, loss) + self.train_net = nn.TrainOneStepCell(self.netwithloss, opt) + self.train_net.set_train() + def construct(self, x, y): + return self.train_net(x, y) + +if __name__ == "__main__": + base_net = resnet(1, args.width_multiplier, cifar_stem=args.dataset_name == "cifar10") + simclr_model = SimCLR(base_net, args.projection_dimension, base_net.end_point.in_channels) + if args.run_cloudbrain: + mox.file.copy_parallel(src_url=args.encoder_checkpoint_path, dst_url=local_data_url+'/encoder.ckpt') + simclr_param = load_checkpoint(local_data_url+'/encoder.ckpt') + else: + 
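+        # Running on the host: load the encoder checkpoint directly from the given local path.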
simclr_param = load_checkpoint(args.encoder_checkpoint_path) + load_param_into_net(simclr_model.encoder, simclr_param) + classifier = LogisticRegression(simclr_model.n_features, args.class_num) + dataset = create_dataset(args, dataset_mode="train_classifier") + optimizer = get_optimizer(classifier, dataset.get_dataset_size(), args) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_Train = Linear_Train(net=classifier, loss=criterion, opt=optimizer) + reporter = Reporter(args, linear_eval=True) + reporter.dataset_size = dataset.get_dataset_size() + reporter.linear_eval = True + if args.pre_classifier_checkpoint_path: + if args.run_cloudbrain: + mox.file.copy_parallel(src_url=args.pre_classifier_checkpoint_path, + dst_url=local_data_url+'/pre_classifier.ckpt') + classifier_param = load_checkpoint(local_data_url+'/pre_classifier.ckpt') + else: + classifier_param = load_checkpoint(args.pre_classifier_checkpoint_path) + load_param_into_net(classifier, classifier_param) + else: + dataset_train = [] + for _, data in enumerate(dataset, start=1): + _, images, labels = data + features = simclr_model.inference(images) + dataset_train.append([features, labels]) + reporter.info('==========start training linear classifier===============') + # Train. + for _ in range(args.epoch_size): + reporter.epoch_start() + for idx, data in enumerate(dataset_train, start=1): + features, labels = data + out = net_Train(features, labels) + reporter.step_end(out) + reporter.epoch_end(classifier) + reporter.info('==========end training linear classifier===============') + + dataset = create_dataset(args, dataset_mode="eval_classifier") + reporter.dataset_size = dataset.get_dataset_size() + net_Eval = Linear_Eval(net=classifier, loss=criterion) + # Eval. + reporter.info('==========start evaluating linear classifier===============') + reporter.start_predict() + for idx, data in enumerate(dataset, start=1): + _, images, labels = data + features = simclr_model.inference(images) + batch_loss, batch_acc = net_Eval(features, labels) + reporter.predict_step_end(batch_loss, batch_acc) + reporter.end_predict() + reporter.info('==========end evaluating linear classifier===============') + if args.run_cloudbrain: + mox.file.copy_parallel(src_url=_local_train_url, dst_url=args.train_url) diff --git a/model_zoo/official/cv/simclr/scripts/run_distribution_ascend.sh b/model_zoo/official/cv/simclr/scripts/run_distribution_ascend.sh new file mode 100644 index 0000000000..23d65a093c --- /dev/null +++ b/model_zoo/official/cv/simclr/scripts/run_distribution_ascend.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +# an simple tutorial as follows, more parameters can be setting +if [ $# != 4 ] +then + echo "Usage: sh run_distribution_ascend.sh [DEVICENUM] [RANK_TABLE_FILE] [cifar10] [TRAIN_DATASET_PATH]" +exit 1 +fi + +# +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +# +if [ ! -f $2 ] +then + echo "error: RANK_TABLE_FILE=$2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=$1 +export RANK_SIZE=$1 +RANK_TABLE_FILE=$(get_real_path $2) +export RANK_TABLE_FILE +export DATASET_NAME=$3 +export TRAIN_DATASET_PATH=$(get_real_path $4) +echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}" + +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$((rank_start + i)) + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp -r ../src ./train_parallel$i + cp ../train.py ./train_parallel$i + echo "start training for rank $RANK_ID, device $DEVICE_ID" + cd ./train_parallel$i ||exit + env > env.log + python train.py --device_id=$i --dataset_name=$DATASET_NAME --train_dataset_path=$TRAIN_DATASET_PATH \ + --run_cloudbrain=False --run_distribute=True > log 2>&1 & + cd .. +done diff --git a/model_zoo/official/cv/simclr/scripts/run_standalone_eval_ascend.sh b/model_zoo/official/cv/simclr/scripts/run_standalone_eval_ascend.sh new file mode 100644 index 0000000000..cd76ec12c3 --- /dev/null +++ b/model_zoo/official/cv/simclr/scripts/run_standalone_eval_ascend.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +# an simple tutorial as follows, more parameters can be setting +if [ $# != 5 ] +then + echo "Usage: sh run_standalone_eval_ascend.sh [cifar10] [DEVICE_ID] [SIMCLR_MODEL_PATH] [TRAIN_DATASET_PATH] [EVAL_DATASET_PATH]" +exit 1 +fi + +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export DATASET_NAME=$1 +export DEVICE_ID=$2 +export SIMCLR_MODEL_PATH=$3 +export TRAIN_DATASET_PATH=$4 +export EVAL_DATASET_PATH=$5 + + +python ${self_path}/../linear_eval.py --dataset_name=$DATASET_NAME \ + --encoder_checkpoint_path=$SIMCLR_MODEL_PATH \ + --train_dataset_path=$TRAIN_DATASET_PATH \ + --eval_dataset_path=$EVAL_DATASET_PATH \ + --device_id=$DEVICE_ID --device_target="Ascend" \ + --run_distribute=False --run_cloudbrain=False > eval_log 2>&1 & diff --git a/model_zoo/official/cv/simclr/scripts/run_standalone_train_ascend.sh b/model_zoo/official/cv/simclr/scripts/run_standalone_train_ascend.sh new file mode 100644 index 0000000000..81d9f3bfd4 --- /dev/null +++ b/model_zoo/official/cv/simclr/scripts/run_standalone_train_ascend.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# an simple tutorial as follows, more parameters can be setting +if [ $# != 3 ] +then + echo "Usage: sh run_standalone_train_ascend.sh [cifar10] [TRAIN_DATASET_PATH] [DEVICE_ID]" +exit 1 +fi + +script_self=$(readlink -f "$0") +self_path=$(dirname "${script_self}") +export DATASET_NAME=$1 +export TRAIN_DATASET_PATH=$2 +export DEVICE_ID=$3 + +python ${self_path}/../train.py --dataset_name=$DATASET_NAME --train_dataset_path=$TRAIN_DATASET_PATH \ + --device_id=$DEVICE_ID --device_target="Ascend" \ + --run_cloudbrain=False --run_distribute=False > log 2>&1 & \ No newline at end of file diff --git a/model_zoo/official/cv/simclr/src/__init__.py b/model_zoo/official/cv/simclr/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/simclr/src/dataset.py b/model_zoo/official/cv/simclr/src/dataset.py new file mode 100644 index 0000000000..f62e5bbcc9 --- /dev/null +++ b/model_zoo/official/cv/simclr/src/dataset.py @@ -0,0 +1,94 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +create train or eval dataset. 
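+Each record is mapped from (image, label) to (image1, image2, label): the image column is duplicated
+and the two copies are augmented independently, which yields the two views needed by the contrastive loss.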
+""" +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as C2 +import mindspore.dataset.vision.py_transforms as py_vision +from mindspore.dataset.vision import Inter +import cv2 +import numpy as np + +ds.config.set_seed(0) + +def gaussian_blur(im): + sigma = 0 + _, w = im.shape[:2] + kernel_size = int(w // 10) + if kernel_size % 2 == 0: + kernel_size -= 1 + return np.array(cv2.GaussianBlur(im, (kernel_size, kernel_size), sigma)) + +def copy_column(x, y): + return x, x, y + +def create_dataset(args, dataset_mode, repeat_num=1): + """ + create a train or evaluate cifar10 dataset for SimCLR + """ + if args.dataset_name != "cifar10": + raise ValueError("Unsupported dataset.") + if dataset_mode in ("train_endcoder", "train_classifier"): + dataset_path = args.train_dataset_path + else: + dataset_path = args.eval_dataset_path + if args.run_distribute and args.device_target == "Ascend": + data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=args.device_num, shard_id=args.device_id) + else: + data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True) + # define map operations + trans = [] + if dataset_mode == "train_endcoder": + if args.use_crop: + trans += [C.Resize(256, interpolation=Inter.BICUBIC)] + trans += [C.RandomResizedCrop(size=(32, 32), scale=(0.31, 1), + interpolation=Inter.BICUBIC, max_attempts=100)] + if args.use_flip: + trans += [C.RandomHorizontalFlip(prob=0.5)] + if args.use_color_jitter: + scale = 0.6 + color_jitter = C.RandomColorAdjust(0.8 * scale, 0.8 * scale, 0.8 * scale, 0.2 * scale) + trans += [C2.RandomApply([color_jitter], prob=0.8)] + if args.use_color_gray: + trans += [py_vision.ToPIL(), + py_vision.RandomGrayscale(prob=0.2), + np.array] # need to convert PIL image to a NumPy array to pass it to C++ operation + if args.use_blur: + trans += [C2.RandomApply([gaussian_blur], prob=0.8)] + if args.use_norm: + trans += [C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])] + trans += [C2.TypeCast(mstype.float32), C.HWC2CHW()] + else: + trans += [C.Resize(32)] + trans += [C2.TypeCast(mstype.float32)] + if args.use_norm: + trans += [C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])] + trans += [C.HWC2CHW()] + type_cast_op = C2.TypeCast(mstype.int32) + data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + data_set = data_set.map(operations=copy_column, input_columns=["image", "label"], + output_columns=["image1", "image2", "label"], + column_order=["image1", "image2", "label"], num_parallel_workers=8) + data_set = data_set.map(operations=trans, input_columns=["image1"], num_parallel_workers=8) + data_set = data_set.map(operations=trans, input_columns=["image2"], num_parallel_workers=8) + # apply batch operations + data_set = data_set.batch(args.batch_size, drop_remainder=True) + # apply dataset repeat operation + data_set = data_set.repeat(repeat_num) + return data_set diff --git a/model_zoo/official/cv/simclr/src/lr_generator.py b/model_zoo/official/cv/simclr/src/lr_generator.py new file mode 100644 index 0000000000..ee443f0e85 --- /dev/null +++ b/model_zoo/official/cv/simclr/src/lr_generator.py @@ -0,0 +1,198 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate generator""" +import math +import numpy as np + + +def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps): + """ + Applies three steps decay to generate learning rate array. + """ + decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps] + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + if i < decay_epoch_index[0]: + lr = lr_max + elif i < decay_epoch_index[1]: + lr = lr_max * 0.1 + elif i < decay_epoch_index[2]: + lr = lr_max * 0.01 + else: + lr = lr_max * 0.001 + lr_each_step.append(lr) + return lr_each_step + + +def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): + """ + Applies polynomial decay to generate learning rate array. + + Args: + lr_init(float): init learning rate. + lr_end(float): end learning rate + lr_max(float): max learning rate. + total_steps(int): all steps in training. + warmup_steps(int): all steps in warmup epochs. + + Returns: + np.array, learning rate array. + """ + lr_each_step = [] + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr = float(lr_init) + inc_each_step * float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr = float(lr_max) * base * base + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + return lr_each_step + + +def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): + """ + Applies cosine decay to generate learning rate array. + + Args: + lr_init(float): init learning rate. + lr_end(float): end learning rate + lr_max(float): max learning rate. + total_steps(int): all steps in training. + warmup_steps(int): all steps in warmup epochs. + + Returns: + np.array, learning rate array. + """ + decay_steps = total_steps - warmup_steps + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) + lr = float(lr_init) + lr_inc * (i + 1) + else: + linear_decay = (total_steps - i) / decay_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps)) + decayed = linear_decay * cosine_decay + 0.00001 + lr = lr_max * decayed + lr_each_step.append(lr) + return lr_each_step + + +def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): + """ + Applies liner decay to generate learning rate array. + + Args: + lr_init(float): init learning rate. + lr_end(float): end learning rate + lr_max(float): max learning rate. + total_steps(int): all steps in training. + warmup_steps(int): all steps in warmup epochs. + + Returns: + np.array, learning rate array. 
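+
+    Note: after the warmup phase, the learning rate decays linearly from lr_max to lr_end.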
+ """ + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) + lr_each_step.append(lr) + return lr_each_step + + + +def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): + """ + generate learning rate array + + Args: + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(int): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or liner(default) + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + + if lr_decay_mode == 'steps': + lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps) + elif lr_decay_mode == 'poly': + lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps) + elif lr_decay_mode == 'cosine': + lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps) + else: + lr_each_step = _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps) + + lr_each_step = np.array(lr_each_step).astype(np.float32) + return lr_each_step + + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0): + """ + generate learning rate array with cosine + + Args: + lr(float): base learning rate + steps_per_epoch(int): steps size of one epoch + warmup_epochs(int): number of warmup epochs + max_epoch(int): total epochs of training + global_step(int): the current start index of lr array + Returns: + np.array, learning rate array + """ + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + decay_steps = total_steps - warmup_steps + + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + linear_decay = (total_steps - i) / decay_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps)) + decayed = linear_decay * cosine_decay + 0.00001 + lr = base_lr * decayed + lr_each_step.append(lr) + + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[global_step:] + return learning_rate diff --git a/model_zoo/official/cv/simclr/src/nt_xent.py b/model_zoo/official/cv/simclr/src/nt_xent.py new file mode 100644 index 0000000000..472cfb1864 --- /dev/null +++ b/model_zoo/official/cv/simclr/src/nt_xent.py @@ -0,0 +1,91 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SimCLR Loss class.""" + +from mindspore import Tensor +from mindspore import ops as P +from mindspore.common import dtype as mstype +import mindspore.nn as nn + + +class CrossEntropyLoss(nn.Cell): + """ + Cross Entropy Loss. + """ + def __init__(self, reduction="mean"): + super(CrossEntropyLoss, self).__init__() + self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() + if reduction == "sum": + self.reduction = P.ReduceSum() + if reduction == "mean": + self.reduction = P.ReduceMean() + self.one_hot = P.OneHot() + self.one = Tensor(1.0, mstype.float32) + self.zero = Tensor(0.0, mstype.float32) + + def construct(self, logits, label): + loss = self.cross_entropy(logits, label)[0] + loss = self.reduction(loss, (-1,)) + return loss + + +class NT_Xent_Loss(nn.Cell): + """ + Loss for SimCLR. + """ + def __init__(self, batch_size, temperature=1, world_size=1): + super(NT_Xent_Loss, self).__init__() + # Parameters. + self.LARGE_NUM = 1e9 + self.batch_size = batch_size + self.temperature = temperature + self.world_size = world_size + self.N = 2 * self.batch_size * self.world_size + # Tail_Loss. + self.criterion = CrossEntropyLoss(reduction="mean") + self.norm = P.L2Normalize(axis=1) + self.one_hot = P.OneHot() + self.range = nn.Range(0, self.batch_size) + self.one = Tensor(1.0, mstype.float32) + self.zero = Tensor(0.0, mstype.float32) + self.transpose = P.Transpose() + self.matmul = nn.MatMul() + # Operations. + self.ones = P.Ones() + self.zeros = P.Zeros() + self.cat1 = P.Concat(axis=1) + + def construct(self, z_i, z_j): + """ + Forward. + """ + hidden1 = self.norm(z_i) + hidden2 = self.norm(z_j) + hidden1_large = hidden1 + hidden2_large = hidden2 + ones_mask = self.range() + zeros_mask = self.zeros((self.batch_size, self.batch_size), mstype.float32) + masks = self.one_hot(ones_mask, self.batch_size, self.one, self.zero) + labels = self.cat1((masks, zeros_mask)) + logits_aa = self.matmul(hidden1, self.transpose(hidden1_large, (1, 0))) / self.temperature + logits_aa = logits_aa - masks * self.LARGE_NUM + logits_bb = self.matmul(hidden2, self.transpose(hidden2_large, (1, 0))) / self.temperature + logits_bb = logits_bb - masks * self.LARGE_NUM + logits_ab = self.matmul(hidden1, self.transpose(hidden2_large, (1, 0))) / self.temperature + logits_ba = self.matmul(hidden2, self.transpose(hidden1_large, (1, 0))) / self.temperature + loss_a = self.criterion(self.cat1((logits_ab, logits_aa)), labels) + loss_b = self.criterion(self.cat1((logits_ba, logits_bb)), labels) + loss = loss_a + loss_b + return loss diff --git a/model_zoo/official/cv/simclr/src/optimizer.py b/model_zoo/official/cv/simclr/src/optimizer.py new file mode 100644 index 0000000000..6345a44646 --- /dev/null +++ b/model_zoo/official/cv/simclr/src/optimizer.py @@ -0,0 +1,52 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""optimizer generator""" +from mindspore import nn, Tensor +from .lr_generator import get_lr + +def get_train_optimizer(net, steps_per_epoch, args): + """ + generate optimizer for updating the weights. + """ + if args.optimizer == "Adam": + lr = get_lr(lr_init=1e-4, lr_end=1e-6, lr_max=9e-4, + warmup_epochs=args.warmup_epochs, total_epochs=args.epoch_size, + steps_per_epoch=steps_per_epoch, + lr_decay_mode="linear") + lr = Tensor(lr) + decayed_params = [] + no_decayed_params = [] + for param in net.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + group_params = [{'params': decayed_params, 'weight_decay': args.weight_decay}, + {'params': no_decayed_params}, + {'order_params': net.trainable_params()}] + optimizer = nn.Adam(params=group_params, learning_rate=lr) + else: + raise ValueError("Unsupported optimizer.") + + return optimizer + +def get_eval_optimizer(net, steps_per_epoch, args): + lr = get_lr(lr_init=1e-3, lr_end=6e-6, lr_max=1e-2, + warmup_epochs=5, total_epochs=args.epoch_size, + steps_per_epoch=steps_per_epoch, + lr_decay_mode="linear") + lr = Tensor(lr) + optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr) + return optimizer diff --git a/model_zoo/official/cv/simclr/src/reporter.py b/model_zoo/official/cv/simclr/src/reporter.py new file mode 100644 index 0000000000..6420cf4d94 --- /dev/null +++ b/model_zoo/official/cv/simclr/src/reporter.py @@ -0,0 +1,135 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Reporter class.""" +import logging +import os +import time +from datetime import datetime +from mindspore.train.serialization import save_checkpoint + +class Reporter(logging.Logger): + """ + This class includes several functions that can save images/checkpoints and print/save logging information. 
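+
+    When linear_eval is True, classifier checkpoints are saved under <train_output_path>/checkpoint.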
+ """ + def __init__(self, args, linear_eval): + super(Reporter, self).__init__("clean") + self.log_dir = os.path.join(args.train_output_path, 'log') + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir, exist_ok=True) + if linear_eval: + self.ckpts_dir = os.path.join(args.train_output_path, "checkpoint") + if not os.path.exists(self.ckpts_dir): + os.makedirs(self.ckpts_dir, exist_ok=True) + self.rank = args.rank + self.save_checkpoint_epochs = args.save_checkpoint_epochs + formatter = logging.Formatter('%(message)s') + # console handler + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(message)s') + console.setFormatter(formatter) + self.addHandler(console) + # file handler + log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(self.rank) + self.log_fn = os.path.join(self.log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + self.addHandler(fh) + if args: + self.save_args(args) + self.step = 0 + self.epoch = 0 + self.dataset_size = 0 + self.print_iter = args.print_iter + self.contrastive_loss = [] + self.linear_eval = False + self.Loss = 0 + self.Acc = 0 + + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO): + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info('Args:') + args_dict = vars(args) + for key in args_dict.keys(): + self.info('--> %s: %s', key, args_dict[key]) + self.info('') + + def important_info(self, msg, *args, **kwargs): + if self.logger.isEnabledFor(logging.INFO) and self.rank == 0: + line_width = 2 + important_msg = '\n' + important_msg += ('*'*70 + '\n')*line_width + important_msg += ('*'*line_width + '\n')*2 + important_msg += '*'*line_width + ' '*8 + msg + '\n' + important_msg += ('*'*line_width + '\n')*2 + important_msg += ('*'*70 + '\n')*line_width + self.info(important_msg, *args, **kwargs) + + def epoch_start(self): + self.step_start_time = time.time() + self.epoch_start_time = time.time() + self.step = 0 + self.epoch += 1 + self.contrastive_loss = [] + + def step_end(self, loss): + """print log when step end.""" + self.step += 1 + self.contrastive_loss.append(loss.asnumpy()) + if self.step % self.print_iter == 0: + step_cost = (time.time() - self.step_start_time) * 1000 / self.print_iter + self.info("Epoch[{}] [{}/{}] step cost: {:.2f} ms, loss: {}".format( + self.epoch, self.step, self.dataset_size, step_cost, loss)) + self.step_start_time = time.time() + + def epoch_end(self, net): + """print log and save cgeckpoints when epoch end.""" + epoch_cost = (time.time() - self.epoch_start_time) * 1000 + pre_step_time = epoch_cost / self.dataset_size + mean_loss = sum(self.contrastive_loss) / self.dataset_size + + self.info("Epoch [{}] total cost: {:.2f} ms, pre step: {:.2f} ms, mean_loss: {:.2f}"\ + .format(self.epoch, epoch_cost, pre_step_time, mean_loss)) + if self.epoch % self.save_checkpoint_epochs == 0: + if self.linear_eval: + save_checkpoint(net, os.path.join(self.ckpts_dir, f"linearClassifier_{self.epoch}.ckpt")) + else: + save_checkpoint(net, os.path.join(self.ckpts_dir, f"simclr_{self.epoch}.ckpt")) + + def start_predict(self): + self.predict_start_time = time.time() + self.step = 0 + self.info('==========start predict===============') + + def end_predict(self): + avg_loss = self.Loss / self.step + avg_acc = self.Acc / self.step + self.info('Average loss {:.5f}, Average accuracy {:.5f}'.format(avg_loss, avg_acc)) + 
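+        # self.Loss and self.Acc are running totals accumulated in predict_step_end;
+        # they are divided by the step count only when reported.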
self.info('==========end predict===============\n') + + def predict_step_end(self, loss, acc): + self.step += 1 + self.Loss = self.Loss + loss + self.Acc = self.Acc + acc + if self.step % self.print_iter == 0: + current_loss = self.Loss / self.step + current_acc = self.Acc / self.step + self.info('[{}/{}] Current total loss {:.5f}, Current total accuracy {:.5f}'\ + .format(self.step, self.dataset_size, current_loss, current_acc)) diff --git a/model_zoo/official/cv/simclr/src/resnet.py b/model_zoo/official/cv/simclr/src/resnet.py new file mode 100644 index 0000000000..733318b1e6 --- /dev/null +++ b/model_zoo/official/cv/simclr/src/resnet.py @@ -0,0 +1,485 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""SimCLR ResNet.""" +import math +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.tensor import Tensor +from scipy.stats import truncnorm + +def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size): + fan_in = in_channel * kernel_size * kernel_size + scale = 1.0 + scale /= max(1., fan_in) + stddev = (scale ** 0.5) / .87962566103423978 + mu, sigma = 0, stddev + weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size) + weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size)) + return Tensor(weight, dtype=mstype.float32) + +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + +def calculate_gain(nonlinearity, param=None): + """calculate_gain""" + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + res = 0 + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + res = 1 + elif nonlinearity == 'tanh': + res = 5.0 / 3 + elif nonlinearity == 'relu': + res = math.sqrt(2.0) + elif nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + res = math.sqrt(2.0 / (1 + negative_slope ** 2)) + else: + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + return res + + +def _calculate_fan_in_and_fan_out(tensor): + """_calculate_fan_in_and_fan_out""" + dimensions = len(tensor) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + if dimensions == 2: # Linear + fan_in = tensor[1] + fan_out = tensor[0] + else: + num_input_fmaps = tensor[1] + num_output_fmaps = tensor[0] + receptive_field_size = 1 + if dimensions > 2: + receptive_field_size = 
tensor[2] * tensor[3] + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + return fan_in, fan_out + + +def _calculate_correct_fan(tensor, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == 'fan_in' else fan_out + + +def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'): + fan = _calculate_correct_fan(inputs_shape, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + return np.random.normal(0, std, size=inputs_shape).astype(np.float32) + + +def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'): + fan = _calculate_correct_fan(inputs_shape, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32) + + +def _conv3x3(in_channel, out_channel, stride=1, use_se=False): + if use_se: + weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3) + else: + weight_shape = (out_channel, in_channel, 3, 3) + weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu')) + return nn.Conv2d(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv1x1(in_channel, out_channel, stride=1, use_se=False): + if use_se: + weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1) + else: + weight_shape = (out_channel, in_channel, 1, 1) + weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu')) + return nn.Conv2d(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv7x7(in_channel, out_channel, stride=1, use_se=False): + if use_se: + weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7) + else: + weight_shape = (out_channel, in_channel, 7, 7) + weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu')) + return nn.Conv2d(in_channel, out_channel, + kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _fc(in_channel, out_channel, use_se=False): + if use_se: + weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel) + weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32) + else: + weight_shape = (out_channel, in_channel) + weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5))) + return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) + + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + use_se (bool): enable SE-ResNet50 net. Default: False. + se_block(bool): use se block in SE-ResNet50 net. Default: False. 
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ResidualBlock(3, 256, stride=2)
+    """
+    expansion = 4
+
+    def __init__(self,
+                 in_channel,
+                 out_channel,
+                 stride=1,
+                 use_se=False, se_block=False):
+        super(ResidualBlock, self).__init__()
+        self.stride = stride
+        self.use_se = use_se
+        self.se_block = se_block
+        channel = out_channel // self.expansion
+        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
+        self.bn1 = _bn(channel)
+        if self.use_se and self.stride != 1:
+            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
+                                         nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
+        else:
+            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
+            self.bn2 = _bn(channel)
+
+        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
+        self.bn3 = _bn_last(out_channel)
+        if self.se_block:
+            self.se_global_pool = P.ReduceMean(keep_dims=False)
+            self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se)
+            self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se)
+            self.se_sigmoid = nn.Sigmoid()
+            self.se_mul = P.Mul()
+        self.relu = nn.ReLU()
+
+        self.down_sample = False
+
+        if stride != 1 or in_channel != out_channel:
+            self.down_sample = True
+        self.down_sample_layer = None
+
+        if self.down_sample:
+            if self.use_se:
+                if stride == 1:
+                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
+                                                                         stride, use_se=self.use_se), _bn(out_channel)])
+                else:
+                    self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
+                                                                _conv1x1(in_channel, out_channel, 1,
+                                                                         use_se=self.use_se), _bn(out_channel)])
+            else:
+                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
+                                                                     use_se=self.use_se), _bn(out_channel)])
+        self.add = F.tensor_add
+
+    def construct(self, x):
+        """
+        Forward.
+        """
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        if self.use_se and self.stride != 1:
+            out = self.e2(out)
+        else:
+            out = self.conv2(out)
+            out = self.bn2(out)
+            out = self.relu(out)
+        out = self.conv3(out)
+        out = self.bn3(out)
+        if self.se_block:
+            out_se = out
+            out = self.se_global_pool(out, (2, 3))
+            out = self.se_dense_0(out)
+            out = self.relu(out)
+            out = self.se_dense_1(out)
+            out = self.se_sigmoid(out)
+            out = F.reshape(out, F.shape(out) + (1, 1))
+            out = self.se_mul(out, out_se)
+
+        if self.down_sample:
+            identity = self.down_sample_layer(identity)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+class Identity(nn.Cell):
+    """Pass-through cell, used instead of max-pooling in the CIFAR stem."""
+    def construct(self, x):
+        return x
+
+class ResNet(nn.Cell):
+    """
+    ResNet architecture.
+
+    Args:
+        block (Cell): Block for network.
+        layer_nums (list): Numbers of block in different layers.
+        in_channels (list): Input channel in each layer.
+        out_channels (list): Output channel in each layer.
+        strides (list): Stride size in each layer.
+        num_classes (int): The number of classes that the training images belong to.
+        width_multiplier (int): Multiplier applied to the channel width of every layer.
+        cifar_stem (bool): If True, use a 3x3, stride-1 stem without max-pooling
+            (suitable for 32x32 CIFAR images) instead of the 7x7, stride-2 ImageNet stem.
+        use_se (bool): enable SE-ResNet50 net. Default: False.
+        se_block(bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ResNet(ResidualBlock,
+        >>>        [3, 4, 6, 3],
+        >>>        [64, 256, 512, 1024],
+        >>>        [256, 512, 1024, 2048],
+        >>>        [1, 2, 2, 2],
+        >>>        10)
+    """
+
+    def __init__(self,
+                 block,
+                 layer_nums,
+                 in_channels,
+                 out_channels,
+                 strides,
+                 num_classes,
+                 width_multiplier,
+                 cifar_stem,
+                 use_se=False):
+        super(ResNet, self).__init__()
+
+        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
+            raise ValueError("the lengths of layer_nums, in_channels and out_channels must all be 4!")
+        self.use_se = use_se
+        self.se_block = False
+        if self.use_se:
+            self.se_block = True
+
+        if self.use_se:
+            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
+            self.bn1_0 = _bn(32)
+            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
+            self.bn1_1 = _bn(32)
+            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
+        else:
+            if cifar_stem:
+                self.conv1 = _conv3x3(3, 64 * width_multiplier, stride=1)  # cifar
+            else:
+                self.conv1 = _conv7x7(3, 64 * width_multiplier, stride=2)
+
+        self.bn1 = _bn(64 * width_multiplier)
+        self.relu = P.ReLU()
+        if cifar_stem:
+            self.maxpool = Identity()  # cifar
+        else:
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+
+        in_channels = [i * width_multiplier for i in in_channels]
+        out_channels = [i * width_multiplier for i in out_channels]
+        self.layer1 = self._make_layer(block,
+                                       layer_nums[0],
+                                       in_channel=in_channels[0],
+                                       out_channel=out_channels[0],
+                                       stride=strides[0],
+                                       use_se=self.use_se)
+        self.layer2 = self._make_layer(block,
+                                       layer_nums[1],
+                                       in_channel=in_channels[1],
+                                       out_channel=out_channels[1],
+                                       stride=strides[1],
+                                       use_se=self.use_se)
+        self.layer3 = self._make_layer(block,
+                                       layer_nums[2],
+                                       in_channel=in_channels[2],
+                                       out_channel=out_channels[2],
+                                       stride=strides[2],
+                                       use_se=self.use_se,
+                                       se_block=self.se_block)
+        self.layer4 = self._make_layer(block,
+                                       layer_nums[3],
+                                       in_channel=in_channels[3],
+                                       out_channel=out_channels[3],
+                                       stride=strides[3],
+                                       use_se=self.use_se,
+                                       se_block=self.se_block)
+
+        self.mean = P.ReduceMean(keep_dims=True)
+        self.flatten = nn.Flatten()
+        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)
+
+    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
+        """
+        Make stage network of ResNet.
+
+        Args:
+            block (Cell): Resnet block.
+            layer_num (int): Layer number.
+            in_channel (int): Input channel.
+            out_channel (int): Output channel.
+            stride (int): Stride size for the first convolutional layer.
+            use_se (bool): Enable SE-ResNet50 net. Default: False.
+            se_block(bool): Use se block in SE-ResNet50 net. Default: False.
+        Returns:
+            SequentialCell, the output layer.
+
+        Examples:
+            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
+        """
+        layers = []
+
+        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
+        layers.append(resnet_block)
+        if se_block:
+            for _ in range(1, layer_num - 1):
+                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
+                layers.append(resnet_block)
+            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
+            layers.append(resnet_block)
+        else:
+            for _ in range(1, layer_num):
+                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
+                layers.append(resnet_block)
+        return nn.SequentialCell(layers)
+
+    def construct(self, x):
+        """
+        Forward.
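+        Stem convolutions (and max-pool when not using the CIFAR stem),
+        then layer1-layer4, global average pooling, flatten and the final
+        dense layer (end_point).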
+        """
+        if self.use_se:
+            x = self.conv1_0(x)
+            x = self.bn1_0(x)
+            x = self.relu(x)
+            x = self.conv1_1(x)
+            x = self.bn1_1(x)
+            x = self.relu(x)
+            x = self.conv1_2(x)
+        else:
+            x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        c1 = self.maxpool(x)
+
+        c2 = self.layer1(c1)
+        c3 = self.layer2(c2)
+        c4 = self.layer3(c3)
+        c5 = self.layer4(c4)
+
+        out = self.mean(c5, (2, 3))
+        out = self.flatten(out)
+        out = self.end_point(out)
+
+        return out
+
+
+def resnet50(class_num=10, width_multiplier=1, cifar_stem=True):
+    """
+    Get ResNet50 neural network.
+
+    Args:
+        class_num (int): Class number. Default: 10.
+        width_multiplier (int): Channel width multiplier. Default: 1.
+        cifar_stem (bool): Use the 3x3, stride-1 CIFAR stem. Default: True.
+
+    Returns:
+        Cell, cell instance of ResNet50 neural network.
+
+    Examples:
+        >>> net = resnet50(10)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 6, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num,
+                  width_multiplier,
+                  cifar_stem)
+
+def se_resnet50(class_num=1001, width_multiplier=1):
+    """
+    Get SE-ResNet50 neural network.
+
+    Args:
+        class_num (int): Class number. Default: 1001.
+        width_multiplier (int): Channel width multiplier. Default: 1.
+
+    Returns:
+        Cell, cell instance of SE-ResNet50 neural network.
+
+    Examples:
+        >>> net = se_resnet50(1001)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 6, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num,
+                  width_multiplier,
+                  cifar_stem=False,  # ImageNet-style 7x7 stem
+                  use_se=True)
+
+def resnet101(class_num=1001, width_multiplier=1):
+    """
+    Get ResNet101 neural network.
+
+    Args:
+        class_num (int): Class number. Default: 1001.
+        width_multiplier (int): Channel width multiplier. Default: 1.
+
+    Returns:
+        Cell, cell instance of ResNet101 neural network.
+
+    Examples:
+        >>> net = resnet101(1001)
+    """
+    return ResNet(ResidualBlock,
+                  [3, 4, 23, 3],
+                  [64, 256, 512, 1024],
+                  [256, 512, 1024, 2048],
+                  [1, 2, 2, 2],
+                  class_num,
+                  width_multiplier,
+                  cifar_stem=False)  # ImageNet-style 7x7 stem
diff --git a/model_zoo/official/cv/simclr/src/simclr_model.py b/model_zoo/official/cv/simclr/src/simclr_model.py
new file mode 100644
index 0000000000..0475eb7de6
--- /dev/null
+++ b/model_zoo/official/cv/simclr/src/simclr_model.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SimCLR Model class."""
+from mindspore import nn
+from .resnet import _fc
+
+class Identity(nn.Cell):
+    """Pass-through cell that replaces the encoder's classification head."""
+    def construct(self, x):
+        return x
+
+class SimCLR(nn.Cell):
+    """
+    SimCLR Model: a base encoder followed by a projection head.
+    """
+    def __init__(self, encoder, project_dim, n_features):
+        super(SimCLR, self).__init__()
+        self.encoder = encoder
+        self.n_features = n_features
+        self.encoder.end_point = Identity()
+        self.dense1 = _fc(self.n_features, self.n_features)
+        self.relu = nn.ReLU()
+        self.end_point = _fc(self.n_features, project_dim)
+
+    # Projector MLP.
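+    # The projector g(.) is a 2-layer MLP, z = W2·ReLU(W1·h) (plus biases),
+    # mapping the encoder representation h into the latent space where the
+    # NT-Xent contrastive loss is computed, as in the SimCLR paper.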
+ def projector(self, x): + out = self.dense1(x) + out = self.relu(out) + out = self.end_point(out) + return out + + def construct(self, x_i, x_j): + h_i = self.encoder(x_i) + z_i = self.projector(h_i) + + h_j = self.encoder(x_j) + z_j = self.projector(h_j) + return h_i, h_j, z_i, z_j + + def inference(self, x): + h = self.encoder(x) + return h diff --git a/model_zoo/official/cv/simclr/train.py b/model_zoo/official/cv/simclr/train.py new file mode 100644 index 0000000000..05038ee276 --- /dev/null +++ b/model_zoo/official/cv/simclr/train.py @@ -0,0 +1,164 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +######################## train SimCLR example ######################## +train simclr and get network model files(.ckpt) : +python train.py --train_dataset_path /YourDataPath +""" +import ast +import argparse +import os +from src.nt_xent import NT_Xent_Loss +from src.optimizer import get_train_optimizer as get_optimizer +from src.dataset import create_dataset +from src.simclr_model import SimCLR +from src.resnet import resnet50 as resnet +from mindspore import nn +from mindspore import context +from mindspore.train import Model +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.common import initializer as weight_init +from mindspore.common import set_seed +from mindspore.context import ParallelMode +from mindspore.communication.management import init, get_rank +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +parser = argparse.ArgumentParser(description='MindSpore SimCLR') +parser.add_argument('--device_target', type=str, default='Ascend', + help='Device target, Currently only Ascend is supported.') +parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True, + help='Whether it is running on CloudBrain platform.') +parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distributed training.') +parser.add_argument('--device_num', type=int, default=1, help='Device num.') +parser.add_argument('--device_id', type=int, default=0, help='device id, default is 0.') +parser.add_argument('--dataset_name', type=str, default='cifar10', help='Dataset, Currently only cifar10 is supported.') +parser.add_argument('--train_url', default=None, help='Cloudbrain Location of training outputs.\ + This parameter needs to be set when running on the cloud brain platform.') +parser.add_argument('--data_url', default=None, help='Cloudbrain Location of data.\ + This parameter needs to be set when running on the cloud brain platform.') +parser.add_argument('--train_dataset_path', type=str, default='./cifar/train', + help='Dataset path for training classifier. 
' + 'This parameter needs to be set when running on the host.') +parser.add_argument('--train_output_path', type=str, default='./outputs', help='Location of ckpt and log.\ + This parameter needs to be set when running on the host.') +parser.add_argument('--batch_size', type=int, default=128, help='batch_size, default is 128.') +parser.add_argument('--epoch_size', type=int, default=100, help='epoch size for training, default is 100.') +parser.add_argument('--projection_dimension', type=int, default=128, + help='Projection output dimensionality, default is 128.') +parser.add_argument('--width_multiplier', type=int, default=1, help='width_multiplier for ResNet50') +parser.add_argument('--temperature', type=float, default=0.5, help='temperature for loss') +parser.add_argument('--pre_trained_path', type=str, default=None, help='Pretrained checkpoint path') +parser.add_argument('--pretrain_epoch_size', type=int, default=0, + help='real_epoch_size = epoch_size - pretrain_epoch_size.') +parser.add_argument('--save_checkpoint_epochs', type=int, default=1, help='Save checkpoint epochs, default is 1.') +parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, + help='whether save graphs, default is False.') +parser.add_argument('--optimizer', type=str, default='Adam', help='Optimizer, Currently only Adam is supported.') +parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay') +parser.add_argument('--warmup_epochs', type=int, default=15, help='warmup epochs.') +parser.add_argument('--use_crop', type=ast.literal_eval, default=True, help='RandomResizedCrop') +parser.add_argument('--use_flip', type=ast.literal_eval, default=True, help='RandomHorizontalFlip') +parser.add_argument('--use_color_jitter', type=ast.literal_eval, default=True, help='RandomColorAdjust') +parser.add_argument('--use_color_gray', type=ast.literal_eval, default=True, help='RandomGrayscale') +parser.add_argument('--use_blur', type=ast.literal_eval, default=False, help='GaussianBlur') +parser.add_argument('--use_norm', type=ast.literal_eval, default=False, help='Normalize') + +args = parser.parse_args() +local_data_url = './cache/data' +local_train_url = './cache/train' +_local_train_url = local_train_url + +if args.device_target != "Ascend": + raise ValueError("Unsupported device target.") +if args.run_distribute: + device_id = os.getenv("DEVICE_ID", default=None) + if device_id is None: + raise ValueError("Unsupported device id.") + args.device_id = int(device_id) + rank_size = os.getenv("RANK_SIZE", default=None) + if rank_size is None: + raise ValueError("Unsupported rank size.") + if args.device_num > int(rank_size) or args.device_num == 1: + args.device_num = int(rank_size) + context.set_context(device_id=args.device_id) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=args.save_graphs) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True, device_num=args.device_num) + init() + args.rank = get_rank() + local_data_url = os.path.join(local_data_url, str(args.device_id)) + local_train_url = os.path.join(local_train_url, str(args.device_id)) + args.train_output_path = os.path.join(args.train_output_path, str(args.device_id)) +else: + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, + save_graphs=args.save_graphs, device_id=args.device_id) + args.rank = 0 + args.device_num = 1 + +if args.run_cloudbrain: + import moxing as mox + 
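+    # Stage the dataset from OBS (args.data_url) into the local cache before
+    # training; outputs are copied back to args.train_url once training ends.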
+    args.train_dataset_path = os.path.join(local_data_url, "train")
+    args.train_output_path = local_train_url
+    mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
+
+set_seed(1)
+
+class NetWithLossCell(nn.Cell):
+    """Wraps the SimCLR network together with the NT-Xent loss for TrainOneStepCell."""
+    def __init__(self, backbone, loss_fn):
+        super(NetWithLossCell, self).__init__(auto_prefix=False)
+        self._backbone = backbone
+        self._loss_fn = loss_fn
+
+    def construct(self, data_x, data_y, label):
+        # The label column comes from the dataset pipeline but is unused:
+        # contrastive pretraining only needs the two augmented views.
+        _, _, x_pred, y_pred = self._backbone(data_x, data_y)
+        return self._loss_fn(x_pred, y_pred)
+
+if __name__ == "__main__":
+    # "train_endcoder" (sic) is the mode string expected by src/dataset.py.
+    dataset = create_dataset(args, dataset_mode="train_endcoder")
+    # Net. The class number is a dummy value: SimCLR replaces the encoder's
+    # final dense layer (end_point) with an Identity cell.
+    base_net = resnet(1, args.width_multiplier, cifar_stem=args.dataset_name == "cifar10")
+    net = SimCLR(base_net, args.projection_dimension, base_net.end_point.in_channels)
+    # Initialize weights when no pretrained checkpoint is given.
+    if args.pre_trained_path:
+        if args.run_cloudbrain:
+            mox.file.copy_parallel(src_url=args.pre_trained_path, dst_url=local_data_url+'/pre_train.ckpt')
+            param_dict = load_checkpoint(local_data_url+'/pre_train.ckpt')
+        else:
+            param_dict = load_checkpoint(args.pre_trained_path)
+        load_param_into_net(net, param_dict)
+    else:
+        for _, cell in net.cells_and_names():
+            if isinstance(cell, nn.Conv2d):
+                cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+            if isinstance(cell, nn.Dense):
+                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
+                                                             cell.weight.shape,
+                                                             cell.weight.dtype))
+    optimizer = get_optimizer(net, dataset.get_dataset_size(), args)
+    loss = NT_Xent_Loss(args.batch_size, args.temperature)
+    net_loss = NetWithLossCell(net, loss)
+    train_net = nn.TrainOneStepCell(net_loss, optimizer)
+    model = Model(train_net)
+    time_cb = TimeMonitor(data_size=dataset.get_dataset_size())
+    # Save a checkpoint every `save_checkpoint_epochs` epochs: CheckpointConfig
+    # counts steps, so convert epochs to steps via the dataset size.
+    config_ck = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_epochs * dataset.get_dataset_size())
+    ckpts_dir = os.path.join(args.train_output_path, "checkpoint")
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_simclr", directory=ckpts_dir, config=config_ck)
+    print("============== Starting Training ==============")
+    model.train(args.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
+    if args.run_cloudbrain and args.device_id == 0:
+        mox.file.copy_parallel(src_url=_local_train_url, dst_url=args.train_url)
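+# Example standalone run on a single Ascend device (hypothetical local paths):
+#   python train.py --run_cloudbrain=False --run_distribute=False \
+#       --device_id=0 --train_dataset_path=./cifar-10-batches-bin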