
add simclr

pull/14270/head
MapleGrove 4 years ago
commit acf0b05afc
15 changed files with 1901 additions and 0 deletions
  1. +220 -0 model_zoo/official/cv/simclr/README.md
  2. +62 -0 model_zoo/official/cv/simclr/export.py
  3. +215 -0 model_zoo/official/cv/simclr/linear_eval.py
  4. +64 -0 model_zoo/official/cv/simclr/scripts/run_distribution_ascend.sh
  5. +37 -0 model_zoo/official/cv/simclr/scripts/run_standalone_eval_ascend.sh
  6. +31 -0 model_zoo/official/cv/simclr/scripts/run_standalone_train_ascend.sh
  7. +0 -0 model_zoo/official/cv/simclr/src/__init__.py
  8. +94 -0 model_zoo/official/cv/simclr/src/dataset.py
  9. +198 -0 model_zoo/official/cv/simclr/src/lr_generator.py
  10. +91 -0 model_zoo/official/cv/simclr/src/nt_xent.py
  11. +52 -0 model_zoo/official/cv/simclr/src/optimizer.py
  12. +135 -0 model_zoo/official/cv/simclr/src/reporter.py
  13. +485 -0 model_zoo/official/cv/simclr/src/resnet.py
  14. +53 -0 model_zoo/official/cv/simclr/src/simclr_model.py
  15. +164 -0 model_zoo/official/cv/simclr/train.py

+220 -0 model_zoo/official/cv/simclr/README.md

@@ -0,0 +1,220 @@
# Contents

- [SimCLR Description](#simclr-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

## [SimCLR Description](#contents)

SimCLR is a simple framework for contrastive learning of visual representations.
[Paper](https://arxiv.org/pdf/2002.05709.pdf): Ting Chen and Simon Kornblith and Mohammad Norouzi and Geoffrey Hinton. A Simple Framework for Contrastive Learning of Visual Representations. *arXiv preprint arXiv:2002.05709*. 2020.

## [Model Architecture](#contents)

SimCLR learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss in the latent space. The framework comprises four major components: a stochastic data augmentation module, a neural network base encoder, a small neural network projection head, and a contrastive loss function.
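
How these pieces fit together can be sketched as follows (a minimal illustration with hypothetical names, not the exact code in `src/simclr_model.py`):

```python
import mindspore.nn as nn

class SimCLRSketch(nn.Cell):
    """Base encoder f(.) followed by a small projection head g(.)."""
    def __init__(self, encoder, n_features, projection_dim):
        super(SimCLRSketch, self).__init__()
        self.encoder = encoder                  # e.g. ResNet-50 backbone
        self.projector = nn.SequentialCell([    # 2-layer MLP projection head
            nn.Dense(n_features, n_features),
            nn.ReLU(),
            nn.Dense(n_features, projection_dim),
        ])

    def construct(self, x_i, x_j):
        # x_i, x_j: two stochastically augmented views of the same batch
        z_i = self.projector(self.encoder(x_i))
        z_j = self.projector(self.encoder(x_j))
        return z_i, z_j                         # compared by the NT-Xent loss
```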

## [Dataset](#contents)

The following sections describe how to run the scripts using the dataset below.

Dataset used: [CIFAR-10](<http://www.cs.toronto.edu/~kriz/cifar.html>)

- Dataset size: 175M, 60,000 32×32 color images in 10 classes
    - Train: 146M, 50,000 images
    - Test: 29.3M, 10,000 images
- Data format: binary files
    - Note: data will be processed in dataset.py
- Download the dataset; the directory structure is as follows:

```bash
├─cifar-10-batches-bin
└─cifar-10-verify-bin
```

## [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

## [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# enter script dir, train SimCLR
sh run_standalone_train_ascend.sh [cifar10] [TRAIN_DATASET_PATH] [DEVICE_ID]
or
sh run_distribution_ascend.sh [DEVICENUM] [RANK_TABLE_FILE] [cifar10] [TRAIN_DATASET_PATH]
# enter script dir, evaluate SimCLR
sh run_standalone_eval_ascend.sh [cifar10] [DEVICE_ID] [SIMCLR_MODEL_PATH] [TRAIN_DATASET_PATH] [EVAL_DATASET_PATH]
```

## [Script Description](#contents)

### [Script and Sample Code](#contents)

```bash
├── cv
    ├── SimCLR
        ├── README.md                          // descriptions about SimCLR
        ├── requirements.txt                   // packages needed
        ├── scripts
        │   ├── run_distribution_ascend.sh     // distributed training on Ascend
        │   ├── run_standalone_train_ascend.sh // standalone training on Ascend
        │   ├── run_standalone_eval_ascend.sh  // evaluation on Ascend
        ├── src
        │   ├── dataset.py                     // creating dataset
        │   ├── lr_generator.py                // generating learning rate
        │   ├── nt_xent.py                     // contrastive cross entropy loss
        │   ├── optimizer.py                   // generating optimizer
        │   ├── reporter.py                    // logging and checkpoint saving
        │   ├── resnet.py                      // base encoder network
        │   ├── simclr_model.py                // simclr architecture
        ├── train.py                           // training script
        ├── linear_eval.py                     // linear evaluation script
        ├── export.py                          // export model for inference
```

### [Script Parameters](#contents)

```python
Major parameters in train.py are as follows:

--device_target: Device target; currently only Ascend is supported.
--run_cloudbrain: Whether it is running on the CloudBrain platform.
--run_distribute: Run distributed training.
--device_num: Device number.
--device_id: Device id, default is 0.
--dataset_name: Dataset; currently only cifar10 is supported.
--train_url: CloudBrain location of training outputs. This parameter needs to be set when running on the CloudBrain platform.
--data_url: CloudBrain location of data. This parameter needs to be set when running on the CloudBrain platform.
--train_dataset_path: Dataset path for training. This parameter needs to be set when running on the host.
--train_output_path: Location of ckpt and log. This parameter needs to be set when running on the host.
--batch_size: Batch size, default is 128.
--epoch_size: Epoch size for training, default is 100.
--projection_dimension: Projection output dimensionality, default is 128.
--width_multiplier: Width multiplier for ResNet50, default is 1.
--temperature: Temperature for the contrastive cross entropy loss.
--pre_trained_path: Pretrained checkpoint path.
--pretrain_epoch_size: Epochs already trained; the remaining epochs to run are epoch_size - pretrain_epoch_size.
--save_checkpoint_epochs: Save checkpoint epochs, default is 1.
--save_graphs: Whether to save graphs, default is False.
--optimizer: Optimizer; currently only Adam is supported.
--weight_decay: Weight decay.
--warmup_epochs: Warmup epochs.

Major parameters in linear_eval.py are as follows:

--device_target: Device target; currently only Ascend is supported.
--run_cloudbrain: Whether it is running on the CloudBrain platform.
--run_distribute: Run distributed training.
--device_num: Device number.
--device_id: Device id, default is 0.
--dataset_name: Dataset; currently only cifar10 is supported.
--train_url: CloudBrain location of training outputs. This parameter needs to be set when running on the CloudBrain platform.
--data_url: CloudBrain location of data. This parameter needs to be set when running on the CloudBrain platform.
--train_dataset_path: Dataset path for training the classifier. This parameter needs to be set when running on the host.
--eval_dataset_path: Dataset path for evaluating the classifier. This parameter needs to be set when running on the host.
--train_output_path: Location of ckpt and log. This parameter needs to be set when running on the host.
--class_num: Dataset classification number, default is 10 for cifar10.
--batch_size: Batch size, default is 128.
--epoch_size: Epoch size for training, default is 100.
--projection_dimension: Projection output dimensionality, default is 128.
--width_multiplier: Width multiplier for ResNet50, default is 1.
--pre_classifier_checkpoint_path: Classifier checkpoint file path.
--encoder_checkpoint_path: Encoder checkpoint file path.
--save_checkpoint_epochs: Save checkpoint epochs, default is 10.
--print_iter: Log print interval, default is 100.
--save_graphs: Whether to save graphs, default is False.
```

### [Training Process](#contents)

#### Training

- running on Ascend

```bash
sh run_distribution_ascend.sh [DEVICENUM] [RANK_TABLE_FILE] [cifar10] [TRAIN_DATASET_PATH]
```

After training, the loss values will be printed as follows:

```bash
# grep "loss is " log
epoch: 1 step: 48, loss is 9.5758915
epoch time: 253236.075 ms, per step time: 5275.752 ms
epoch: 1 step: 48, loss is 9.363186
epoch time: 253739.376 ms, per step time: 5286.237 ms
epoch: 1 step: 48, loss is 9.36029
epoch time: 253711.625 ms, per step time: 5285.659 ms
...
epoch: 100 step: 48, loss is 7.453776
epoch time: 12341.851 ms, per step time: 257.122 ms
epoch: 100 step: 48, loss is 7.499168
epoch time: 12420.060 ms, per step time: 258.751 ms
epoch: 100 step: 48, loss is 7.442362
epoch time: 12725.863 ms, per step time: 265.122 ms
...
```

The model checkpoint will be saved in the outputs directory.

### [Evaluation Process](#contents)

#### Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

- running on Ascend

```bash
sh run_standalone_eval_ascend.sh [cifar10] [DEVICE_ID] [SIMCLR_MODEL_PATH] [TRAIN_DATASET_PATH] [EVAL_DATASET_PATH]
```

You can view the results through the file "eval_log". The accuracy of the test dataset will be as follows:

```bash
# grep "Average accuracy: " eval_log
'Accuracy': 0.84505
```

## [Model Description](#contents)

### [Performance](#contents)

#### Evaluation Performance

| Parameters                 | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; memory 755 GB            |
| Uploaded Date              | 30/03/2021 (day/month/year)                                  |
| MindSpore Version          | 1.1.1                                                        |
| Dataset                    | CIFAR-10                                                     |
| Training Parameters        | epoch=100, batch_size=128, device_num=8                      |
| Optimizer                  | Adam                                                         |
| Loss Function              | NT-Xent Loss                                                 |
| Linear Eval Accuracy       | 84.505%                                                      |
| Total Time                 | 25m04s                                                       |
| Scripts                    | [SimCLR Script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/simclr) |

## [Description of Random Situation](#contents)

We set the seed inside dataset.py. We also use random seed in train.py.

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

+62 -0 model_zoo/official/cv/simclr/export.py

@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export

from src.simclr_model import SimCLR
from src.resnet import resnet50 as resnet

parser = argparse.ArgumentParser(description='SimCLR')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10'],
                    help='Dataset, currently only cifar10 is supported.')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'],
                    help='Device target, currently only Ascend is supported.')
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="simclr", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
if args_opt.device_target == "Ascend":
    context.set_context(device_id=args_opt.device_id)

if __name__ == '__main__':
    if args_opt.dataset_name == 'cifar10':
        width_multiplier = 1
        cifar_stem = True
        projection_dimension = 128
        image_height = 32
        image_width = 32
    else:
        raise ValueError("dataset is not supported.")

    base_net = resnet(1, width_multiplier=width_multiplier, cifar_stem=cifar_stem)
    net = SimCLR(base_net, projection_dimension, base_net.end_point.in_channels)

    param_dict = load_checkpoint(args_opt.ckpt_file)
    load_param_into_net(net, param_dict)

    # a dummy input with the training shape drives graph compilation for export
    input_arr = Tensor(np.zeros([args_opt.batch_size, 3, image_height, image_width]), ms.float32)
    export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)

+215 -0 model_zoo/official/cv/simclr/linear_eval.py

@@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval SimCLR example ########################
eval SimCLR according to model file:
python eval.py --encoder_checkpoint_path Your.ckpt --train_dataset_path /YourDataPath1
--eval_dataset_path /YourDataPath2
"""
import ast
import os
import argparse
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import ops
from mindspore import context
from mindspore.common.initializer import TruncatedNormal
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from src.dataset import create_dataset
from src.simclr_model import SimCLR
from src.resnet import resnet50 as resnet
from src.reporter import Reporter
from src.optimizer import get_eval_optimizer as get_optimizer



parser = argparse.ArgumentParser(description='Linear Evaluation Protocol')
parser.add_argument('--device_target', type=str, default='Ascend',
                    help='Device target, currently only Ascend is supported.')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Running distributed evaluation.')
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True,
                    help='Whether it is running on CloudBrain platform.')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--device_id', type=int, default=0, help='device id, default is 0.')
parser.add_argument('--dataset_name', type=str, default='cifar10', help='Dataset, currently only cifar10 is supported.')
parser.add_argument('--train_url', default=None, help='CloudBrain location of training outputs. '
                    'This parameter needs to be set when running on the CloudBrain platform.')
parser.add_argument('--data_url', default=None, help='CloudBrain location of data. '
                    'This parameter needs to be set when running on the CloudBrain platform.')
parser.add_argument('--train_dataset_path', type=str, default='./cifar/train',
                    help='Dataset path for training the classifier. '
                         'This parameter needs to be set when running on the host.')
parser.add_argument('--eval_dataset_path', type=str, default='./cifar/eval',
                    help='Dataset path for evaluating the classifier. '
                         'This parameter needs to be set when running on the host.')
parser.add_argument('--train_output_path', type=str, default='./outputs', help='Location of ckpt and log. '
                    'This parameter needs to be set when running on the host.')
parser.add_argument('--class_num', type=int, default=10, help='dataset classification number, default is 10.')
parser.add_argument('--batch_size', type=int, default=128, help='batch_size for training classifier, default is 128.')
parser.add_argument('--epoch_size', type=int, default=100, help='epoch size for training classifier, default is 100.')
parser.add_argument('--projection_dimension', type=int, default=128,
                    help='Projection output dimensionality, default is 128.')
parser.add_argument('--width_multiplier', type=int, default=1,
                    help='Width multiplier for ResNet50; e.g. width_multiplier=4 gives ResNet50x4.')
parser.add_argument('--pre_classifier_checkpoint_path', type=str, default=None,
                    help='Classifier checkpoint file path.')
parser.add_argument('--encoder_checkpoint_path', type=str, help='Encoder checkpoint file path.')
parser.add_argument('--save_checkpoint_epochs', type=int, default=10, help='Save checkpoint epochs, default is 10.')
parser.add_argument('--print_iter', type=int, default=100, help='log print iter, default is 100.')
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False,
                    help='whether save graphs, default is False.')
parser.add_argument('--use_norm', type=ast.literal_eval, default=False, help='Dataset normalize.')

args = parser.parse_args()
set_seed(1)
local_data_url = './cache/data'
local_train_url = './cache/train'
_local_train_url = local_train_url

if args.device_target != "Ascend":
    raise ValueError("Unsupported device target.")
if args.run_distribute:
    device_id = os.getenv("DEVICE_ID", default=None)
    if device_id is None:
        raise ValueError("Unsupported device id.")
    args.device_id = int(device_id)
    rank_size = os.getenv("RANK_SIZE", default=None)
    if rank_size is None:
        raise ValueError("Unsupported rank size.")
    if args.device_num > int(rank_size) or args.device_num == 1:
        args.device_num = int(rank_size)
    context.set_context(device_id=args.device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=args.save_graphs)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True, device_num=args.device_num)
    init()
    args.rank = get_rank()
    local_data_url = os.path.join(local_data_url, str(args.device_id))
    local_train_url = os.path.join(local_train_url, str(args.device_id))
    args.train_output_path = os.path.join(args.train_output_path, str(args.device_id))
else:
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                        save_graphs=args.save_graphs, device_id=args.device_id)
    args.rank = 0
    args.device_num = 1

if args.run_cloudbrain:
    import moxing as mox
    args.train_dataset_path = os.path.join(local_data_url, "train")
    args.eval_dataset_path = os.path.join(local_data_url, "val")
    args.train_output_path = local_train_url
    mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)

class LogisticRegression(nn.Cell):
    """
    Logistic regression
    """
    def __init__(self, n_features, n_classes):
        super(LogisticRegression, self).__init__()
        self.model = nn.Dense(n_features, n_classes, TruncatedNormal(0.02), TruncatedNormal(0.02))

    def construct(self, x):
        x = self.model(x)
        return x


class Linear_Eval():
    """
    Linear classifier
    """
    def __init__(self, net, loss):
        super(Linear_Eval, self).__init__()
        self.net = net
        self.softmax = nn.Softmax()
        self.loss = loss

    def __call__(self, x, y):
        x = self.net(x)
        loss = self.loss(x, y)
        x = self.softmax(x)
        predicts = ops.Argmax(output_type=mstype.int32)(x)
        acc = np.sum(predicts.asnumpy() == y.asnumpy()) / len(y.asnumpy())
        return loss.asnumpy(), acc


class Linear_Train(nn.Cell):
    """
    Train linear classifier
    """
    def __init__(self, net, loss, opt):
        super(Linear_Train, self).__init__()
        self.netwithloss = nn.WithLossCell(net, loss)
        self.train_net = nn.TrainOneStepCell(self.netwithloss, opt)
        self.train_net.set_train()

    def construct(self, x, y):
        return self.train_net(x, y)

if __name__ == "__main__":
base_net = resnet(1, args.width_multiplier, cifar_stem=args.dataset_name == "cifar10")
simclr_model = SimCLR(base_net, args.projection_dimension, base_net.end_point.in_channels)
if args.run_cloudbrain:
mox.file.copy_parallel(src_url=args.encoder_checkpoint_path, dst_url=local_data_url+'/encoder.ckpt')
simclr_param = load_checkpoint(local_data_url+'/encoder.ckpt')
else:
simclr_param = load_checkpoint(args.encoder_checkpoint_path)
load_param_into_net(simclr_model.encoder, simclr_param)
classifier = LogisticRegression(simclr_model.n_features, args.class_num)
dataset = create_dataset(args, dataset_mode="train_classifier")
optimizer = get_optimizer(classifier, dataset.get_dataset_size(), args)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_Train = Linear_Train(net=classifier, loss=criterion, opt=optimizer)
reporter = Reporter(args, linear_eval=True)
reporter.dataset_size = dataset.get_dataset_size()
reporter.linear_eval = True
if args.pre_classifier_checkpoint_path:
if args.run_cloudbrain:
mox.file.copy_parallel(src_url=args.pre_classifier_checkpoint_path,
dst_url=local_data_url+'/pre_classifier.ckpt')
classifier_param = load_checkpoint(local_data_url+'/pre_classifier.ckpt')
else:
classifier_param = load_checkpoint(args.pre_classifier_checkpoint_path)
load_param_into_net(classifier, classifier_param)
else:
dataset_train = []
for _, data in enumerate(dataset, start=1):
_, images, labels = data
features = simclr_model.inference(images)
dataset_train.append([features, labels])
reporter.info('==========start training linear classifier===============')
# Train.
for _ in range(args.epoch_size):
reporter.epoch_start()
for idx, data in enumerate(dataset_train, start=1):
features, labels = data
out = net_Train(features, labels)
reporter.step_end(out)
reporter.epoch_end(classifier)
reporter.info('==========end training linear classifier===============')

dataset = create_dataset(args, dataset_mode="eval_classifier")
reporter.dataset_size = dataset.get_dataset_size()
net_Eval = Linear_Eval(net=classifier, loss=criterion)
# Eval.
reporter.info('==========start evaluating linear classifier===============')
reporter.start_predict()
for idx, data in enumerate(dataset, start=1):
_, images, labels = data
features = simclr_model.inference(images)
batch_loss, batch_acc = net_Eval(features, labels)
reporter.predict_step_end(batch_loss, batch_acc)
reporter.end_predict()
reporter.info('==========end evaluating linear classifier===============')
if args.run_cloudbrain:
mox.file.copy_parallel(src_url=_local_train_url, dst_url=args.train_url)

+64 -0 model_zoo/official/cv/simclr/scripts/run_distribution_ascend.sh

@@ -0,0 +1,64 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
if [ $# != 4 ]
then
    echo "Usage: sh run_distribution_ascend.sh [DEVICENUM] [RANK_TABLE_FILE] [cifar10] [TRAIN_DATASET_PATH]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

if [ ! -f $2 ]
then
    echo "error: RANK_TABLE_FILE=$2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=$1
export RANK_SIZE=$1
RANK_TABLE_FILE=$(get_real_path $2)
export RANK_TABLE_FILE
export DATASET_NAME=$3
export TRAIN_DATASET_PATH=$(get_real_path $4)
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$((rank_start + i))
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cp ../train.py ./train_parallel$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    cd ./train_parallel$i || exit
    env > env.log
    python train.py --device_id=$i --dataset_name=$DATASET_NAME --train_dataset_path=$TRAIN_DATASET_PATH \
        --run_cloudbrain=False --run_distribute=True > log 2>&1 &
    cd ..
done

+37 -0 model_zoo/official/cv/simclr/scripts/run_standalone_eval_ascend.sh

@@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
if [ $# != 5 ]
then
    echo "Usage: sh run_standalone_eval_ascend.sh [cifar10] [DEVICE_ID] [SIMCLR_MODEL_PATH] [TRAIN_DATASET_PATH] [EVAL_DATASET_PATH]"
    exit 1
fi

script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export DATASET_NAME=$1
export DEVICE_ID=$2
export SIMCLR_MODEL_PATH=$3
export TRAIN_DATASET_PATH=$4
export EVAL_DATASET_PATH=$5


python ${self_path}/../linear_eval.py --dataset_name=$DATASET_NAME \
    --encoder_checkpoint_path=$SIMCLR_MODEL_PATH \
    --train_dataset_path=$TRAIN_DATASET_PATH \
    --eval_dataset_path=$EVAL_DATASET_PATH \
    --device_id=$DEVICE_ID --device_target="Ascend" \
    --run_distribute=False --run_cloudbrain=False > eval_log 2>&1 &

+31 -0 model_zoo/official/cv/simclr/scripts/run_standalone_train_ascend.sh

@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A simple tutorial follows; more parameters can be set.
if [ $# != 3 ]
then
    echo "Usage: sh run_standalone_train_ascend.sh [cifar10] [TRAIN_DATASET_PATH] [DEVICE_ID]"
    exit 1
fi

script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export DATASET_NAME=$1
export TRAIN_DATASET_PATH=$2
export DEVICE_ID=$3

python ${self_path}/../train.py --dataset_name=$DATASET_NAME --train_dataset_path=$TRAIN_DATASET_PATH \
    --device_id=$DEVICE_ID --device_target="Ascend" \
    --run_cloudbrain=False --run_distribute=False > log 2>&1 &

+0 -0 model_zoo/official/cv/simclr/src/__init__.py


+94 -0 model_zoo/official/cv/simclr/src/dataset.py

@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore.dataset.vision import Inter
import cv2
import numpy as np

ds.config.set_seed(0)

def gaussian_blur(im):
    """Gaussian blur with a kernel of roughly 1/10 of the image width, forced odd."""
    sigma = 0
    _, w = im.shape[:2]
    kernel_size = int(w // 10)
    if kernel_size % 2 == 0:
        kernel_size -= 1  # cv2.GaussianBlur requires an odd kernel size
    return np.array(cv2.GaussianBlur(im, (kernel_size, kernel_size), sigma))

def copy_column(x, y):
    # duplicate the image column so each sample yields two views to augment
    return x, x, y

def create_dataset(args, dataset_mode, repeat_num=1):
    """
    create a train or evaluate cifar10 dataset for SimCLR
    """
    if args.dataset_name != "cifar10":
        raise ValueError("Unsupported dataset.")
    # note: "train_endcoder" (sic) is the mode string used by the callers
    if dataset_mode in ("train_endcoder", "train_classifier"):
        dataset_path = args.train_dataset_path
    else:
        dataset_path = args.eval_dataset_path
    if args.run_distribute and args.device_target == "Ascend":
        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                     num_shards=args.device_num, shard_id=args.device_id)
    else:
        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
    # define map operations
    trans = []
    if dataset_mode == "train_endcoder":
        if args.use_crop:
            trans += [C.Resize(256, interpolation=Inter.BICUBIC)]
            trans += [C.RandomResizedCrop(size=(32, 32), scale=(0.31, 1),
                                          interpolation=Inter.BICUBIC, max_attempts=100)]
        if args.use_flip:
            trans += [C.RandomHorizontalFlip(prob=0.5)]
        if args.use_color_jitter:
            scale = 0.6
            color_jitter = C.RandomColorAdjust(0.8 * scale, 0.8 * scale, 0.8 * scale, 0.2 * scale)
            trans += [C2.RandomApply([color_jitter], prob=0.8)]
        if args.use_color_gray:
            trans += [py_vision.ToPIL(),
                      py_vision.RandomGrayscale(prob=0.2),
                      np.array]  # convert the PIL image back to a NumPy array for the C++ operations
        if args.use_blur:
            trans += [C2.RandomApply([gaussian_blur], prob=0.8)]
        if args.use_norm:
            trans += [C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]
        trans += [C2.TypeCast(mstype.float32), C.HWC2CHW()]
    else:
        trans += [C.Resize(32)]
        trans += [C2.TypeCast(mstype.float32)]
        if args.use_norm:
            trans += [C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]
        trans += [C.HWC2CHW()]
    type_cast_op = C2.TypeCast(mstype.int32)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    data_set = data_set.map(operations=copy_column, input_columns=["image", "label"],
                            output_columns=["image1", "image2", "label"],
                            column_order=["image1", "image2", "label"], num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns=["image1"], num_parallel_workers=8)
    data_set = data_set.map(operations=trans, input_columns=["image2"], num_parallel_workers=8)
    # apply batch operations
    data_set = data_set.batch(args.batch_size, drop_remainder=True)
    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set
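
# A usage sketch (illustrative paths and flags): create_dataset reads the
# fields below from the args namespace; "train_endcoder" (sic) matches the
# mode string used above.
if __name__ == "__main__":
    from types import SimpleNamespace
    demo_args = SimpleNamespace(
        dataset_name="cifar10", run_distribute=False, device_target="Ascend",
        train_dataset_path="./cifar-10-batches-bin", eval_dataset_path="./cifar-10-verify-bin",
        batch_size=128, use_crop=True, use_flip=True, use_color_jitter=True,
        use_color_gray=True, use_blur=True, use_norm=False)
    demo_set = create_dataset(demo_args, dataset_mode="train_endcoder")
    print(demo_set.get_dataset_size())  # 50000 // 128 = 390 batches with drop_remainder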

+198 -0 model_zoo/official/cv/simclr/src/lr_generator.py

@@ -0,0 +1,198 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np


def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
    """
    Applies three-step decay to generate the learning rate array.
    """
    decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
        lr_each_step.append(lr)
    return lr_each_step


def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies polynomial decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr = float(lr_max) * base * base
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)
    return lr_each_step


def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies cosine decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    decay_steps = total_steps - warmup_steps
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
            lr = float(lr_init) + lr_inc * (i + 1)
        else:
            linear_decay = (total_steps - i) / decay_steps
            cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
            decayed = linear_decay * cosine_decay + 0.00001
            lr = lr_max * decayed
        lr_each_step.append(lr)
    return lr_each_step


def _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies linear decay to generate learning rate array.

    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate.
        lr_max(float): max learning rate.
        total_steps(int): all steps in training.
        warmup_steps(int): all steps in warmup epochs.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
        lr_each_step.append(lr)
    return lr_each_step


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    generate learning rate array

    Args:
        lr_init(float): init learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or linear(default)

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    if lr_decay_mode == 'steps':
        lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
    elif lr_decay_mode == 'poly':
        lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    elif lr_decay_mode == 'cosine':
        lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
    else:
        lr_each_step = _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)

    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step


def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    lr = float(init_lr) + lr_inc * current_step
    return lr


def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
    """
    generate learning rate array with cosine

    Args:
        lr(float): base learning rate
        steps_per_epoch(int): steps size of one epoch
        warmup_epochs(int): number of warmup epochs
        max_epoch(int): total epochs of training
        global_step(int): the current start index of lr array
    Returns:
        np.array, learning rate array
    """
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    decay_steps = total_steps - warmup_steps

    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            linear_decay = (total_steps - i) / decay_steps
            cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
            decayed = linear_decay * cosine_decay + 0.00001
            lr = base_lr * decayed
        lr_each_step.append(lr)

    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[global_step:]
    return learning_rate
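
# A quick sketch (illustrative numbers): two warmup epochs then cosine decay
# for a 100-epoch run at 48 steps per epoch, one value per training step.
if __name__ == "__main__":
    schedule = get_lr(lr_init=1e-4, lr_end=1e-6, lr_max=9e-4,
                      warmup_epochs=2, total_epochs=100,
                      steps_per_epoch=48, lr_decay_mode="cosine")
    print(schedule.shape)             # (4800,)
    print(schedule[0], schedule[95])  # warmup start vs. peak at the end of warmup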

+91 -0 model_zoo/official/cv/simclr/src/nt_xent.py

@@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SimCLR Loss class."""

from mindspore import Tensor
from mindspore import ops as P
from mindspore.common import dtype as mstype
import mindspore.nn as nn


class CrossEntropyLoss(nn.Cell):
    """
    Cross Entropy Loss.
    """
    def __init__(self, reduction="mean"):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        if reduction == "sum":
            self.reduction = P.ReduceSum()
        if reduction == "mean":
            self.reduction = P.ReduceMean()
        self.one_hot = P.OneHot()
        self.one = Tensor(1.0, mstype.float32)
        self.zero = Tensor(0.0, mstype.float32)

    def construct(self, logits, label):
        loss = self.cross_entropy(logits, label)[0]
        loss = self.reduction(loss, (-1,))
        return loss


class NT_Xent_Loss(nn.Cell):
    """
    Normalized temperature-scaled cross entropy (NT-Xent) loss for SimCLR.
    """
    def __init__(self, batch_size, temperature=1, world_size=1):
        super(NT_Xent_Loss, self).__init__()
        # Parameters.
        self.LARGE_NUM = 1e9
        self.batch_size = batch_size
        self.temperature = temperature
        self.world_size = world_size
        self.N = 2 * self.batch_size * self.world_size
        # Tail_Loss.
        self.criterion = CrossEntropyLoss(reduction="mean")
        self.norm = P.L2Normalize(axis=1)
        self.one_hot = P.OneHot()
        self.range = nn.Range(0, self.batch_size)
        self.one = Tensor(1.0, mstype.float32)
        self.zero = Tensor(0.0, mstype.float32)
        self.transpose = P.Transpose()
        self.matmul = nn.MatMul()
        # Operations.
        self.ones = P.Ones()
        self.zeros = P.Zeros()
        self.cat1 = P.Concat(axis=1)

    def construct(self, z_i, z_j):
        """
        Forward: z_i and z_j are the projections of the two augmented views.
        """
        hidden1 = self.norm(z_i)
        hidden2 = self.norm(z_j)
        hidden1_large = hidden1
        hidden2_large = hidden2
        ones_mask = self.range()
        zeros_mask = self.zeros((self.batch_size, self.batch_size), mstype.float32)
        # masks marks the positive-pair positions; labels pads them with zeros
        # for the intra-view logits concatenated on the right
        masks = self.one_hot(ones_mask, self.batch_size, self.one, self.zero)
        labels = self.cat1((masks, zeros_mask))
        # cosine similarities scaled by the temperature; self-similarities are
        # pushed far negative via LARGE_NUM so they never act as negatives
        logits_aa = self.matmul(hidden1, self.transpose(hidden1_large, (1, 0))) / self.temperature
        logits_aa = logits_aa - masks * self.LARGE_NUM
        logits_bb = self.matmul(hidden2, self.transpose(hidden2_large, (1, 0))) / self.temperature
        logits_bb = logits_bb - masks * self.LARGE_NUM
        logits_ab = self.matmul(hidden1, self.transpose(hidden2_large, (1, 0))) / self.temperature
        logits_ba = self.matmul(hidden2, self.transpose(hidden1_large, (1, 0))) / self.temperature
        loss_a = self.criterion(self.cat1((logits_ab, logits_aa)), labels)
        loss_b = self.criterion(self.cat1((logits_ba, logits_bb)), labels)
        loss = loss_a + loss_b
        return loss
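
# Usage sketch (illustrative shapes): the loss over two batches of projections.
if __name__ == "__main__":
    import numpy as np
    loss_fn = NT_Xent_Loss(batch_size=128, temperature=0.5)
    z_i = Tensor(np.random.randn(128, 128), mstype.float32)  # view-1 projections
    z_j = Tensor(np.random.randn(128, 128), mstype.float32)  # view-2 projections
    print(loss_fn(z_i, z_j))  # scalar contrastive loss (requires a MindSpore backend)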

+52 -0 model_zoo/official/cv/simclr/src/optimizer.py

@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""optimizer generator"""
from mindspore import nn, Tensor
from .lr_generator import get_lr

def get_train_optimizer(net, steps_per_epoch, args):
    """
    generate optimizer for updating the weights.
    """
    if args.optimizer == "Adam":
        lr = get_lr(lr_init=1e-4, lr_end=1e-6, lr_max=9e-4,
                    warmup_epochs=args.warmup_epochs, total_epochs=args.epoch_size,
                    steps_per_epoch=steps_per_epoch,
                    lr_decay_mode="linear")
        lr = Tensor(lr)
        # apply weight decay to weights only, not to BatchNorm/bias parameters
        decayed_params = []
        no_decayed_params = []
        for param in net.trainable_params():
            if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
                decayed_params.append(param)
            else:
                no_decayed_params.append(param)
        group_params = [{'params': decayed_params, 'weight_decay': args.weight_decay},
                        {'params': no_decayed_params},
                        {'order_params': net.trainable_params()}]
        optimizer = nn.Adam(params=group_params, learning_rate=lr)
    else:
        raise ValueError("Unsupported optimizer.")

    return optimizer


def get_eval_optimizer(net, steps_per_epoch, args):
    """
    generate optimizer for training the linear classifier.
    """
    lr = get_lr(lr_init=1e-3, lr_end=6e-6, lr_max=1e-2,
                warmup_epochs=5, total_epochs=args.epoch_size,
                steps_per_epoch=steps_per_epoch,
                lr_decay_mode="linear")
    lr = Tensor(lr)
    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr)
    return optimizer
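
# Usage sketch (hypothetical args namespace; train.py builds the real one):
if __name__ == "__main__":
    from types import SimpleNamespace
    demo_net = nn.Dense(10, 2)
    demo_args = SimpleNamespace(optimizer="Adam", warmup_epochs=2,
                                epoch_size=100, weight_decay=1e-4)
    opt = get_train_optimizer(demo_net, steps_per_epoch=48, args=demo_args)
    print(type(opt).__name__)  # Adam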

+135 -0 model_zoo/official/cv/simclr/src/reporter.py

@@ -0,0 +1,135 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reporter class."""
import logging
import os
import time
from datetime import datetime
from mindspore.train.serialization import save_checkpoint

class Reporter(logging.Logger):
    """
    This class includes several functions that can save images/checkpoints and print/save logging information.
    """
    def __init__(self, args, linear_eval):
        super(Reporter, self).__init__("clean")
        self.log_dir = os.path.join(args.train_output_path, 'log')
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir, exist_ok=True)
        self.ckpts_dir = os.path.join(args.train_output_path, "checkpoint")
        if not os.path.exists(self.ckpts_dir):
            os.makedirs(self.ckpts_dir, exist_ok=True)
        self.rank = args.rank
        self.save_checkpoint_epochs = args.save_checkpoint_epochs
        formatter = logging.Formatter('%(message)s')
        # console handler
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        console.setFormatter(formatter)
        self.addHandler(console)
        # file handler
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(self.rank)
        self.log_fn = os.path.join(self.log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        self.addHandler(fh)
        if args:
            self.save_args(args)
        self.step = 0
        self.epoch = 0
        self.dataset_size = 0
        self.print_iter = args.print_iter
        self.contrastive_loss = []
        self.linear_eval = linear_eval
        self.Loss = 0
        self.Acc = 0

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*' * 70 + '\n') * line_width
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += ('*' * 70 + '\n') * line_width
            self.info(important_msg, *args, **kwargs)

    def epoch_start(self):
        self.step_start_time = time.time()
        self.epoch_start_time = time.time()
        self.step = 0
        self.epoch += 1
        self.contrastive_loss = []

    def step_end(self, loss):
        """print log when step end."""
        self.step += 1
        self.contrastive_loss.append(loss.asnumpy())
        if self.step % self.print_iter == 0:
            step_cost = (time.time() - self.step_start_time) * 1000 / self.print_iter
            self.info("Epoch[{}] [{}/{}] step cost: {:.2f} ms, loss: {}".format(
                self.epoch, self.step, self.dataset_size, step_cost, loss))
            self.step_start_time = time.time()

    def epoch_end(self, net):
        """print log and save checkpoints when epoch end."""
        epoch_cost = (time.time() - self.epoch_start_time) * 1000
        per_step_time = epoch_cost / self.dataset_size
        mean_loss = sum(self.contrastive_loss) / self.dataset_size
        self.info("Epoch [{}] total cost: {:.2f} ms, per step: {:.2f} ms, mean_loss: {:.2f}"
                  .format(self.epoch, epoch_cost, per_step_time, mean_loss))
        if self.epoch % self.save_checkpoint_epochs == 0:
            if self.linear_eval:
                save_checkpoint(net, os.path.join(self.ckpts_dir, f"linearClassifier_{self.epoch}.ckpt"))
            else:
                save_checkpoint(net, os.path.join(self.ckpts_dir, f"simclr_{self.epoch}.ckpt"))

    def start_predict(self):
        self.predict_start_time = time.time()
        self.step = 0
        self.info('==========start predict===============')

    def end_predict(self):
        avg_loss = self.Loss / self.step
        avg_acc = self.Acc / self.step
        self.info('Average loss {:.5f}, Average accuracy {:.5f}'.format(avg_loss, avg_acc))
        self.info('==========end predict===============\n')

    def predict_step_end(self, loss, acc):
        self.step += 1
        self.Loss = self.Loss + loss
        self.Acc = self.Acc + acc
        if self.step % self.print_iter == 0:
            current_loss = self.Loss / self.step
            current_acc = self.Acc / self.step
            self.info('[{}/{}] Current average loss {:.5f}, Current average accuracy {:.5f}'
                      .format(self.step, self.dataset_size, current_loss, current_acc))

+485 -0 model_zoo/official/cv/simclr/src/resnet.py

@@ -0,0 +1,485 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SimCLR ResNet."""
import math
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from scipy.stats import truncnorm

def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    fan_in = in_channel * kernel_size * kernel_size
    scale = 1.0
    scale /= max(1., fan_in)
    stddev = (scale ** 0.5) / .87962566103423978
    mu, sigma = 0, stddev
    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
    return Tensor(weight, dtype=mstype.float32)

def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
    return Tensor(init_value)

def calculate_gain(nonlinearity, param=None):
    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        res = 1
    elif nonlinearity == 'tanh':
        res = 5.0 / 3
    elif nonlinearity == 'relu':
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res


def _calculate_fan_in_and_fan_out(tensor):
    """_calculate_fan_in_and_fan_out"""
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = tensor[1]
        fan_out = tensor[0]
    else:
        num_input_fmaps = tensor[1]
        num_output_fmaps = tensor[0]
        receptive_field_size = 1
        if dimensions > 2:
            receptive_field_size = tensor[2] * tensor[3]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    fan = _calculate_correct_fan(inputs_shape, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
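
# Worked example (sketch): He-normal weights for a 3x3 conv mapping 64 -> 128
# channels. With mode="fan_out", fan = 128 * 3 * 3 = 1152 and gain = sqrt(2)
# for ReLU, so std = sqrt(2 / 1152) ~ 0.0417.
if __name__ == "__main__":
    w = kaiming_normal((128, 64, 3, 3), mode="fan_out", nonlinearity="relu")
    print(w.shape, w.std())  # (128, 64, 3, 3) ~0.0417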


def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
    if use_se:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    else:
        weight_shape = (out_channel, in_channel, 3, 3)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
    if use_se:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    else:
        weight_shape = (out_channel, in_channel, 1, 1)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
    if use_se:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    else:
        weight_shape = (out_channel, in_channel, 7, 7)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, use_se=False):
    if use_se:
        weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
    else:
        weight_shape = (out_channel, in_channel)
        weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): enable SE-ResNet50 net. Default: False.
        se_block(bool): use se block in SE-ResNet50 net. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False, se_block=False):
        super(ResidualBlock, self).__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        channel = out_channel // self.expansion
        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(channel)
        if self.use_se and self.stride != 1:
            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
                                         nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
        else:
            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(channel)

        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn_last(out_channel)
        if self.se_block:
            self.se_global_pool = P.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
            self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = P.Mul()
        self.relu = nn.ReLU()

        self.down_sample = False

        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            if self.use_se:
                if stride == 1:
                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
                                                                         stride, use_se=self.use_se),
                                                                _bn(out_channel)])
                else:
                    self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2,
                                                                             pad_mode='same'),
                                                                _conv1x1(in_channel, out_channel, 1,
                                                                         use_se=self.use_se),
                                                                _bn(out_channel)])
            else:
                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
                                                                     use_se=self.use_se), _bn(out_channel)])
        self.add = F.tensor_add

    def construct(self, x):
        """
        Forward.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.use_se and self.stride != 1:
            out = self.e2(out)
        else:
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.se_block:
            out_se = out
            out = self.se_global_pool(out, (2, 3))
            out = self.se_dense_0(out)
            out = self.relu(out)
            out = self.se_dense_1(out)
            out = self.se_sigmoid(out)
            out = F.reshape(out, F.shape(out) + (1, 1))
            out = self.se_mul(out, out_se)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out
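
# Shape-check sketch (illustrative; requires a configured MindSpore backend):
# a bottleneck block that halves the spatial size.
if __name__ == "__main__":
    demo_block = ResidualBlock(256, 512, stride=2)
    dummy = Tensor(np.zeros((1, 256, 32, 32)), mstype.float32)
    print(demo_block(dummy).shape)  # expected: (1, 512, 16, 16)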

class Identity(nn.Cell):
def construct(self, x):
return x

class ResNet(nn.Cell):
"""
ResNet architecture.

Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
Returns:
Tensor, output tensor.

Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""

def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes,
width_multiplier,
cifar_stem,
use_se=False):
super(ResNet, self).__init__()

if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.use_se = use_se
self.se_block = False
if self.use_se:
self.se_block = True

if self.use_se:
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
self.bn1_0 = _bn(32)
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
self.bn1_1 = _bn(32)
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
else:
if cifar_stem:
self.conv1 = _conv3x3(3, 64 * width_multiplier, stride=1) # cifar
else:
self.conv1 = _conv7x7(3, 64 * width_multiplier, stride=2)

self.bn1 = _bn(64 * width_multiplier)
self.relu = P.ReLU()
if cifar_stem:
self.maxpool = Identity() # cifar
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

in_channels = [i * width_multiplier for i in in_channels]
out_channels = [i * width_multiplier for i in out_channels]
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
use_se=self.use_se)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
use_se=self.use_se)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2],
use_se=self.use_se,
se_block=self.se_block)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
use_se=self.use_se,
se_block=self.se_block)

self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
"""
Make stage network of ResNet.

Args:
            block (Cell): ResNet block.
            layer_num (int): Number of blocks in the stage.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_se (bool): Enable the SE-ResNet50 variant. Default: False.
            se_block (bool): Attach an SE block to the last block of the stage.
                Default: False.

        Returns:
            SequentialCell, the output layer.

Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []

resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
layers.append(resnet_block)
if se_block:
for _ in range(1, layer_num - 1):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
layers.append(resnet_block)
else:
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
return nn.SequentialCell(layers)

def construct(self, x):
"""
Forward.
"""
if self.use_se:
x = self.conv1_0(x)
x = self.bn1_0(x)
x = self.relu(x)
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
else:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)

c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
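        # Global average pooling over H and W follows, then flatten and the fc head.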

out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)

return out


def resnet50(class_num=10, width_multiplier=1, cifar_stem=True):
"""
Get ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 10.
        width_multiplier (int): Channel width multiplier. Default: 1.
        cifar_stem (bool): Use the CIFAR stem instead of the ImageNet stem. Default: True.

Returns:
Cell, cell instance of ResNet50 neural network.

Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
width_multiplier,
cifar_stem)

def se_resnet50(class_num=1001, width_multiplier=1):
"""
Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 1001.
        width_multiplier (int): Channel width multiplier. Default: 1.

Returns:
Cell, cell instance of SE-ResNet50 neural network.

Examples:
    >>> net = se_resnet50(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
width_multiplier,
use_se=True)

def resnet101(class_num=1001, width_multiplier=1):
"""
Get ResNet101 neural network.

    Args:
        class_num (int): Class number. Default: 1001.
        width_multiplier (int): Channel width multiplier. Default: 1.

Returns:
Cell, cell instance of ResNet101 neural network.

Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
width_multiplier)
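
For reference, a minimal usage sketch of these factory functions (assuming MindSpore is installed and this file is importable as `src.resnet`; the shape check is illustrative only, not part of the PR):

```python
import numpy as np
from mindspore import Tensor, dtype as mstype
from src.resnet import resnet50

# CIFAR stem: 3x3 conv with stride 1 and no stem max-pooling, as in the SimCLR paper.
net = resnet50(class_num=10, width_multiplier=1, cifar_stem=True)
x = Tensor(np.random.randn(2, 3, 32, 32), mstype.float32)
print(net(x).shape)  # expected: (2, 10)
```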

+ 53
- 0
model_zoo/official/cv/simclr/src/simclr_model.py View File

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SimCLR Model class."""
from mindspore import nn
from .resnet import _fc

class Identity(nn.Cell):
    """Pass-through Cell used to strip the encoder's classification head."""
    def construct(self, x):
        return x

class SimCLR(nn.Cell):
"""
    SimCLR model: a base encoder followed by a two-layer MLP projection head.
"""
def __init__(self, encoder, project_dim, n_features):
super(SimCLR, self).__init__()
self.encoder = encoder
self.n_features = n_features
self.encoder.end_point = Identity()
self.dense1 = _fc(self.n_features, self.n_features)
self.relu = nn.ReLU()
self.end_point = _fc(self.n_features, project_dim)

# Projector MLP.
def projector(self, x):
out = self.dense1(x)
out = self.relu(out)
out = self.end_point(out)
return out

def construct(self, x_i, x_j):
h_i = self.encoder(x_i)
z_i = self.projector(h_i)

h_j = self.encoder(x_j)
z_j = self.projector(h_j)
return h_i, h_j, z_i, z_j

def inference(self, x):
h = self.encoder(x)
return h
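
A minimal wiring sketch for this class (assumptions: `src.resnet` and `src.simclr_model` are importable; the names mirror the training script):

```python
from src.resnet import resnet50
from src.simclr_model import SimCLR

encoder = resnet50(class_num=1, width_multiplier=1, cifar_stem=True)
# Read the fc input width before SimCLR replaces end_point with Identity.
n_features = encoder.end_point.in_channels  # 2048 for width_multiplier=1
model = SimCLR(encoder, project_dim=128, n_features=n_features)
# model(x_i, x_j) returns (h_i, h_j, z_i, z_j): the h's feed linear evaluation,
# the z's feed the NT-Xent contrastive loss.
```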

+ 164
- 0
model_zoo/official/cv/simclr/train.py View File

@@ -0,0 +1,164 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train SimCLR example ########################
Train SimCLR and save network checkpoint files (.ckpt):
python train.py --train_dataset_path /YourDataPath
"""
import ast
import argparse
import os
from mindspore import nn
from mindspore import context
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.common import initializer as weight_init
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.nt_xent import NT_Xent_Loss
from src.optimizer import get_train_optimizer as get_optimizer
from src.dataset import create_dataset
from src.simclr_model import SimCLR
from src.resnet import resnet50 as resnet

parser = argparse.ArgumentParser(description='MindSpore SimCLR')
parser.add_argument('--device_target', type=str, default='Ascend',
                    help='Device target. Currently only Ascend is supported.')
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True,
help='Whether it is running on CloudBrain platform.')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distributed training.')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--device_id', type=int, default=0, help='device id, default is 0.')
parser.add_argument('--dataset_name', type=str, default='cifar10', help='Dataset. Currently only cifar10 is supported.')
parser.add_argument('--train_url', default=None, help='Location of training outputs on CloudBrain.\
This parameter must be set when running on the CloudBrain platform.')
parser.add_argument('--data_url', default=None, help='Location of data on CloudBrain.\
This parameter must be set when running on the CloudBrain platform.')
parser.add_argument('--train_dataset_path', type=str, default='./cifar/train',
                    help='Dataset path for pre-training the encoder. '
                         'This parameter needs to be set when running on the host.')
parser.add_argument('--train_output_path', type=str, default='./outputs', help='Location of ckpt and log.\
This parameter needs to be set when running on the host.')
parser.add_argument('--batch_size', type=int, default=128, help='batch_size, default is 128.')
parser.add_argument('--epoch_size', type=int, default=100, help='epoch size for training, default is 100.')
parser.add_argument('--projection_dimension', type=int, default=128,
help='Projection output dimensionality, default is 128.')
parser.add_argument('--width_multiplier', type=int, default=1, help='width_multiplier for ResNet50')
parser.add_argument('--temperature', type=float, default=0.5, help='temperature for loss')
parser.add_argument('--pre_trained_path', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--pretrain_epoch_size', type=int, default=0,
help='real_epoch_size = epoch_size - pretrain_epoch_size.')
parser.add_argument('--save_checkpoint_epochs', type=int, default=1, help='Save checkpoint epochs, default is 1.')
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False,
help='whether save graphs, default is False.')
parser.add_argument('--optimizer', type=str, default='Adam', help='Optimizer. Currently only Adam is supported.')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--warmup_epochs', type=int, default=15, help='warmup epochs.')
parser.add_argument('--use_crop', type=ast.literal_eval, default=True, help='RandomResizedCrop')
parser.add_argument('--use_flip', type=ast.literal_eval, default=True, help='RandomHorizontalFlip')
parser.add_argument('--use_color_jitter', type=ast.literal_eval, default=True, help='RandomColorAdjust')
parser.add_argument('--use_color_gray', type=ast.literal_eval, default=True, help='RandomGrayscale')
parser.add_argument('--use_blur', type=ast.literal_eval, default=False, help='GaussianBlur')
parser.add_argument('--use_norm', type=ast.literal_eval, default=False, help='Normalize')

args = parser.parse_args()
local_data_url = './cache/data'
local_train_url = './cache/train'
_local_train_url = local_train_url

if args.device_target != "Ascend":
raise ValueError("Unsupported device target.")
if args.run_distribute:
device_id = os.getenv("DEVICE_ID", default=None)
if device_id is None:
raise ValueError("Unsupported device id.")
args.device_id = int(device_id)
rank_size = os.getenv("RANK_SIZE", default=None)
if rank_size is None:
raise ValueError("Unsupported rank size.")
if args.device_num > int(rank_size) or args.device_num == 1:
args.device_num = int(rank_size)
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=args.save_graphs)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, device_num=args.device_num)
init()
args.rank = get_rank()
local_data_url = os.path.join(local_data_url, str(args.device_id))
local_train_url = os.path.join(local_train_url, str(args.device_id))
args.train_output_path = os.path.join(args.train_output_path, str(args.device_id))
else:
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
save_graphs=args.save_graphs, device_id=args.device_id)
args.rank = 0
args.device_num = 1

if args.run_cloudbrain:
import moxing as mox
args.train_dataset_path = os.path.join(local_data_url, "train")
args.train_output_path = local_train_url
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)

set_seed(1)

class NetWithLossCell(nn.Cell):
    """Wrap the SimCLR network with the contrastive loss.

    The dataset yields (x_i, x_j, label); the label is ignored during
    self-supervised pre-training.
    """
    def __init__(self, backbone, loss_fn):
        super(NetWithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data_x, data_y, label):
        _, _, x_pred, y_pred = self._backbone(data_x, data_y)
        return self._loss_fn(x_pred, y_pred)

if __name__ == "__main__":
    dataset = create_dataset(args, dataset_mode="train_endcoder")  # mode key as spelled in src/dataset.py
# Net.
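    # class_num=1 is a placeholder: SimCLR replaces the encoder's final fc with
    # Identity, so the value is never used; end_point.in_channels is read first
    # to size the projection head.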
base_net = resnet(1, args.width_multiplier, cifar_stem=args.dataset_name == "cifar10")
net = SimCLR(base_net, args.projection_dimension, base_net.end_point.in_channels)
    # Initialize weights: load a pretrained checkpoint if given, else Xavier/TruncatedNormal.
if args.pre_trained_path:
if args.run_cloudbrain:
mox.file.copy_parallel(src_url=args.pre_trained_path, dst_url=local_data_url+'/pre_train.ckpt')
param_dict = load_checkpoint(local_data_url+'/pre_train.ckpt')
else:
param_dict = load_checkpoint(args.pre_trained_path)
load_param_into_net(net, param_dict)
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
cell.weight.shape,
cell.weight.dtype))
if isinstance(cell, nn.Dense):
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.shape,
cell.weight.dtype))
optimizer = get_optimizer(net, dataset.get_dataset_size(), args)
loss = NT_Xent_Loss(args.batch_size, args.temperature)
net_loss = NetWithLossCell(net, loss)
train_net = nn.TrainOneStepCell(net_loss, optimizer)
model = Model(train_net)
time_cb = TimeMonitor(data_size=dataset.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_epochs * dataset.get_dataset_size())
ckpts_dir = os.path.join(args.train_output_path, "checkpoint")
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_simclr", directory=ckpts_dir, config=config_ck)
print("============== Starting Training ==============")
model.train(args.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
if args.run_cloudbrain and args.device_id == 0:
mox.file.copy_parallel(src_url=_local_train_url, dst_url=args.train_url)
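
After training, a saved checkpoint can be sanity-checked by rebuilding the network and extracting features (a hedged sketch; the checkpoint filename is illustrative):

```python
import numpy as np
from mindspore import Tensor, dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.resnet import resnet50
from src.simclr_model import SimCLR

encoder = resnet50(class_num=1, width_multiplier=1, cifar_stem=True)
net = SimCLR(encoder, 128, encoder.end_point.in_channels)
load_param_into_net(net, load_checkpoint("checkpoint_simclr-100_390.ckpt"))
feats = net.inference(Tensor(np.random.randn(4, 3, 32, 32), mstype.float32))
print(feats.shape)  # (4, 2048) for width_multiplier=1
```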
