| @@ -59,6 +59,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework, | |||
| - [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md) | |||
| - [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md) | |||
| - [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md) | |||
| - [CycleGAN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/cycle_gan/README.md) | |||
| - [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp) | |||
| - [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md) | |||
| - [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio) | |||
| @@ -0,0 +1,235 @@ | |||
| # Contents | |||
| - [CycleGAN Description](#cyclegan-description) | |||
| - [Model Architecture](#model-architecture) | |||
| - [Dataset](#dataset) | |||
| - [Environment Requirements](#environment-requirements) | |||
| - [Script Description](#script-description) | |||
| - [Script and Sample Code](#script-and-sample-code) | |||
| - [Script Parameters](#script-parameters) | |||
| - [Training Process](#training-process) | |||
| - [Knowledge Distillation Process](#knowledge-distillation-process) | |||
| - [Prediction Process](#prediction-process) | |||
| - [Evaluation with cityscape dataset](#evaluation-with-cityscape-dataset) | |||
| - [Export MindIR](#export-mindir) | |||
| - [Model Description](#model-description) | |||
| - [Performance](#performance) | |||
| - [Evaluation Performance](#evaluation-performance) | |||
| - [Inference Performance](#inference-performance) | |||
| - [Description of Random Situation](#description-of-random-situation) | |||
| - [ModelZoo Homepage](#modelzoo-homepage) | |||
| # [CycleGAN Description](#contents) | |||
| A Generative Adversarial Network (GAN) is an unsupervised learning method in which two neural networks are trained against each other. CycleGAN is a GAN variant that consists of two generator networks and two discriminator networks. It learns to translate images from one domain to another using unpaired images, which makes it suitable for style transfer. | |||
| [Paper](https://arxiv.org/abs/1703.10593): Zhu J Y , Park T , Isola P , et al. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks[J]. 2017. | |||
| # [Model Architecture](#contents) | |||
| CycleGAN contains two generator networks and two discriminator networks. Two generator architectures are supported: resnet and unet. The resnet architecture contains three convolutions, several residual blocks, two fractionally-strided convolutions with stride 1/2, and one convolution that maps features to RGB. The unet architecture contains three unet blocks for downsampling and upsampling, several further unet blocks, and one convolution that maps features to RGB. The discriminator networks are 70 × 70 PatchGANs, which classify whether overlapping 70 × 70 image patches are real or fake. | |||
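The 70 × 70 figure is the receptive field of one discriminator output. The sketch below shows how that number can be derived. The layer layout (three stride-2 and two stride-1 convolutions, all 4 × 4) is an assumption taken from the reference PatchGAN design, not read from `networks.py`, and `receptive_field` is a hypothetical helper.

```python
def receptive_field(layers):
    """Return the receptive field of a stack of conv layers.

    `layers` is a list of (kernel_size, stride) tuples ordered from input to output.
    """
    rf = 1
    # Walk backwards from the output: each layer widens the field.
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


# Assumed layout: three stride-2 4x4 convolutions, then two stride-1 4x4 convolutions.
patchgan_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan_layers))  # 70
```

Under this assumption, the three stride-2 layers would correspond to the default `dl_num` of 3.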
| # [Dataset](#contents) | |||
| Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below. | |||
| Dataset used: [CityScape](<https://cityscapes-dataset.com>) | |||
| Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. We provide `src/utils/prepare_cityscapes_dataset.py` to process the images. gtFine contains the semantic segmentation labels. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory. leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory. | |||
| The processed images will be placed at --output_dir. | |||
| Example usage: | |||
| ```bash | |||
| python src/utils/prepare_cityscapes_dataset.py --gtFine_dir ./cityscapes/gtFine/ --leftImg8bit_dir ./cityscapes/leftImg8bit --output_dir ./cityscapes/ | |||
| ``` | |||
| The directory structure is as follows: | |||
| ```path | |||
| . | |||
| └─cityscapes | |||
| ├─trainA | |||
| ├─trainB | |||
| ├─testA | |||
| └─testB | |||
| ``` | |||
| # [Environment Requirements](#contents) | |||
| - Hardware GPU | |||
| - Prepare hardware environment with GPU processor. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below: | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| # [Script Description](#contents) | |||
| ## [Script and Sample Code](#contents) | |||
| ```path | |||
| . | |||
| └─ cv | |||
| └─ cyclegan | |||
| ├─ src | |||
| ├─ __init__.py # init file | |||
| ├─ dataset | |||
| ├─ __init__.py # init file | |||
| ├─ cyclegan_dataset.py # create cyclegan dataset | |||
| ├─ datasets.py # UnalignedDataset and ImageFolderDataset class and some image utils | |||
| └─ distributed_sampler.py # iterator of dataset | |||
| ├─ models | |||
| ├─ __init__.py # init file | |||
| ├─ cycle_gan.py # cyclegan model define | |||
| ├─ losses.py # cyclegan losses function define | |||
| ├─ networks.py # cyclegan sub networks define | |||
| ├─ resnet.py # resnet generate network | |||
| └─ unet.py # unet generate network | |||
| └─ utils | |||
| ├─ __init__.py # init file | |||
| ├─ args.py # parse args | |||
| ├─ prepare_cityscapes_dataset.py # prepare cityscapes dataset to cyclegan format | |||
| ├─ cityscapes_utils.py # cityscapes dataset evaluation utils | |||
| ├─ reporter.py # Reporter class | |||
| └─ tools.py # utils for cyclegan | |||
| ├─ cityscape_eval.py # cityscape dataset eval script | |||
| ├─ predict.py # generate images from A->B and B->A | |||
| ├─ train.py # train script | |||
| ├─ export.py # export mindir script | |||
| ├─ README.md # descriptions about CycleGAN | |||
| └─ mindspore_hub_conf.py # mindspore hub interface | |||
| ``` | |||
| ## [Script Parameters](#contents) | |||
| ```python | |||
| Major parameters in train.py and config.py as follows: | |||
| "model": "resnet" # generator model, should be in [resnet, unet]. | |||
| "platform": "GPU" # run platform, support GPU, CPU and Ascend. | |||
| "device_id": 0 # device id, default is 0. | |||
| "lr": 0.0002 # init learning rate, default is 0.0002. | |||
| "pool_size": 50 # the size of image buffer that stores previously generated images, default is 50. | |||
| "lr_policy": "linear" # learning rate policy, default is linear. | |||
| "image_size": 256 # input image_size, default is 256. | |||
| "batch_size": 1 # batch_size, default is 1. | |||
| "max_epoch": 200 # epoch size for training, default is 200. | |||
| "n_epochs": 100 # number of epochs with the initial learning rate, default is 100 | |||
| "beta1": 0.5 # Adam beta1, default is 0.5. | |||
| "init_type": normal # network initialization, default is normal. | |||
| "init_gain": 0.02 # scaling factor for normal, xavier and orthogonal, default is 0.02. | |||
| "in_planes": 3 # input channels, default is 3. | |||
| "ngf": 64 # generator model filter numbers, default is 64. | |||
| "gl_num": 9 # generator model residual block numbers, default is 9. | |||
| "ndf": 64 # discriminator model filter numbers, default is 64. | |||
| "dl_num": 3 # discriminator model residual block numbers, default is 3. | |||
| "slope": 0.2 # leakyrelu slope, default is 0.2. | |||
| "norm_mode":"instance" # norm mode, should be [batch, instance], default is instance. | |||
| "lambda_A": 10 # weight for cycle loss (A -> B -> A), default is 10. | |||
| "lambda_B": 10 # weight for cycle loss (B -> A -> B), default is 10. | |||
| "lambda_idt": 0.5 # if lambda_idt > 0 use identity mapping. | |||
| "gan_mode": lsgan # the type of GAN loss, should be [lsgan, vanilla], default is lsgan. | |||
| "pad_mode": REFLECT # the type of Pad, should be [CONSTANT, REFLECT, SYMMETRIC], default is REFLECT. | |||
| "need_dropout": True # whether need dropout, default is True. | |||
| "kd": False # knowledge distillation learning or not, default is False. | |||
| "t_ngf": 64 # teacher network generator model filter numbers when `kd` is True, default is 64. | |||
| "t_gl_num":9 # teacher network generator model residual block numbers when `kd` is True, default is 9. | |||
| "t_slope": 0.2 # teacher network leakyrelu slope when `kd` is True, default is 0.2. | |||
| "t_norm_mode": "instance" # teacher network norm mode when `kd` is True, default is instance. | |||
| "print_iter": 100 # log print iter, default is 100. | |||
| "outputs_dir": "outputs" # models are saved here, default is ./outputs. | |||
| "dataroot": None # path of images (should have subfolders trainA, trainB, testA, testB, etc). | |||
| "save_imgs": True # whether save imgs when epoch end, if True result images will generate in `outputs_dir/imgs`, default is True. | |||
| "GT_A_ckpt": None # teacher network pretrained checkpoint file path of G_A when `kd` is True. | |||
| "GT_B_ckpt": None # teacher network pretrained checkpoint file path of G_B when `kd` is True. | |||
| "G_A_ckpt": None # pretrained checkpoint file path of G_A. | |||
| "G_B_ckpt": None # pretrained checkpoint file path of G_B. | |||
| "D_A_ckpt": None # pretrained checkpoint file path of D_A. | |||
| "D_B_ckpt": None # pretrained checkpoint file path of D_B. | |||
| ``` | |||
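To make the interaction between `lr`, `n_epochs` and `max_epoch` concrete, here is a minimal sketch of a linear policy. It assumes the rate is held at `lr` for the first `n_epochs` epochs and then decays linearly to zero by `max_epoch`, as in the reference CycleGAN implementation. The schedule code of this repository is not shown here, so `linear_lr` is a hypothetical helper and the exact decay may differ.

```python
def linear_lr(epoch, base_lr=0.0002, max_epoch=200, n_epochs=100):
    """Learning rate for a 0-based epoch index under the assumed linear policy."""
    decay_epochs = max_epoch - n_epochs
    if epoch < n_epochs or decay_epochs <= 0:
        return base_lr
    # Fraction of the decay phase that is still ahead at this epoch.
    remaining = (max_epoch - epoch) / decay_epochs
    return base_lr * max(0.0, remaining)


for epoch in (0, 99, 100, 150, 199):
    print(epoch, linear_lr(epoch))
```

With the defaults above, the rate at epoch 150 is half of the initial value.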
| ## [Training Process](#contents) | |||
| ```bash | |||
| python train.py --platform [PLATFORM] --dataroot [DATA_PATH] | |||
| ``` | |||
| **Note: pad_mode should be CONSTANT when using Ascend or CPU. When using unet as the generator network, gl_num should be less than 7.** | |||
| ## [Knowledge Distillation Process](#contents) | |||
| ```bash | |||
| python train.py --platform [PLATFORM] --dataroot [DATA_PATH] --ngf [NGF] --kd True --GT_A_ckpt [G_A_CKPT] --GT_B_ckpt [G_B_CKPT] | |||
| ``` | |||
| **Note: the student network ngf should be 1/2 or 1/4 of the teacher network ngf. If you changed the default args when training the teacher generator networks, set the corresponding t_xx args in the knowledge distillation process.** | |||
| ## [Prediction Process](#contents) | |||
| ```bash | |||
| python predict.py --platform [PLATFORM] --dataroot [DATA_PATH] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT] | |||
| ``` | |||
| **Note: the results will be saved at `outputs_dir/predict`.** | |||
| ## [Evaluation with cityscape dataset](#contents) | |||
| ```bash | |||
| python cityscape_eval.py --cityscapes_dir [LABEL_PATH] --result_dir [FAKEB_PATH] | |||
| ``` | |||
| **Note: Please run cityscape_eval.py after prediction process.** | |||
| ## [Export MindIR](#contents) | |||
| ```bash | |||
| python export.py --platform [PLATFORM] --G_A_ckpt [G_A_CKPT] --G_B_ckpt [G_B_CKPT] --file_name [FILE_NAME] --file_format [FILE_FORMAT] | |||
| ``` | |||
| **Note: The file_name parameter is a prefix; the exported files are named [FILE_NAME]_AtoB.[FILE_FORMAT] and [FILE_NAME]_BtoA.[FILE_FORMAT].** | |||
| # [Model Description](#contents) | |||
| ## [Performance](#contents) | |||
| ### Evaluation Performance | |||
| | Parameters | GPU | | |||
| | -------------------------- | ----------------------------------------------------------- | | |||
| | Model Version | CycleGAN | | |||
| | Resource | NV SMX2 V100-32G | | |||
| | Uploaded Date | 12/10/2020 (month/day/year) | | |||
| | MindSpore Version | 1.1.0 | | |||
| | Dataset | Cityscapes | | |||
| | Training Parameters | epoch=200, steps=2975, batch_size=1, lr=0.002 | | |||
| | Optimizer | Adam | | |||
| | Loss Function | Mean Square Loss & L1 Loss | | |||
| | outputs | probability | | |||
| | Speed | 1pc: 264 ms/step; | | |||
| | Total time | 1pc: 43.6h; | | |||
| | Parameters (M) | 11.378 M | | |||
| | Checkpoint for Fine tuning | 44M (.ckpt file) | | |||
| | Scripts | [CycleGAN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/cycle_gan) | | |||
| ### Inference Performance | |||
| | Parameters | GPU | | |||
| | ------------------- | --------------------------- | | |||
| | Model Version | CycleGAN | | |||
| | Resource | GPU | | |||
| | Uploaded Date | 12/10/2020 (month/day/year) | | |||
| | MindSpore Version | 1.1.0 | | |||
| | Dataset | Cityscapes | | |||
| | batch_size | 1 | | |||
| | outputs | probability | | |||
| | Accuracy | mean_pixel_acc: 54.8, mean_class_acc: 21.3, mean_class_iou: 16.1 | | |||
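The three accuracy figures are standard segmentation metrics derived from a confusion matrix. The sketch below illustrates the usual definitions on a toy example. `confusion_matrix` and `scores` are hypothetical stand-ins, since the repository's `fast_hist` and `get_scores` are not shown here and may differ in detail. Predictions are assumed to lie in `[0, num_classes)`.

```python
import numpy as np


def confusion_matrix(label, pred, num_classes):
    """Count (true class, predicted class) pairs; labels outside [0, num_classes) are ignored."""
    valid = (label >= 0) & (label < num_classes)
    idx = num_classes * label[valid].astype(int) + pred[valid].astype(int)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def scores(hist):
    """Return mean pixel accuracy, mean class accuracy and mean class IoU."""
    tp = np.diag(hist).astype(float)
    with np.errstate(divide="ignore", invalid="ignore"):
        pixel_acc = tp.sum() / hist.sum()
        class_acc = tp / hist.sum(axis=1)
        class_iou = tp / (hist.sum(axis=1) + hist.sum(axis=0) - tp)
    # Classes that never occur produce NaN and are skipped by nanmean.
    return pixel_acc, np.nanmean(class_acc), np.nanmean(class_iou)


label = np.array([0, 0, 1, 1])
pred = np.array([0, 1, 1, 1])
hist = confusion_matrix(label, pred, num_classes=2)
pixel_acc, class_acc, class_iou = scores(hist)
print(hist)
print(pixel_acc, class_acc, class_iou)
```

Accumulating one matrix over all images and scoring once at the end, as `cityscape_eval.py` does with `hist_perframe`, weights every pixel equally regardless of which image it came from.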
| # [Description of Random Situation](#contents) | |||
| In cyclegan_dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py. | |||
| # [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Eval use cityscape dataset.""" | |||
| import os | |||
| import argparse | |||
| import numpy as np | |||
| from src.dataset import make_dataset | |||
| from src.utils import CityScapes, fast_hist, get_scores | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset") | |||
| parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated") | |||
| args = parser.parse_args() | |||
| def main(): | |||
| CS = CityScapes() | |||
| cityscapes = make_dataset(args.cityscapes_dir) | |||
| hist_perframe = np.zeros((CS.class_num, CS.class_num)) | |||
| for i, img_path in enumerate(cityscapes): | |||
| if i % 100 == 0: | |||
| print('Evaluating: %d/%d' % (i, len(cityscapes))) | |||
| img_name = os.path.split(img_path)[1] | |||
| ids1 = CS.get_id(os.path.join(args.cityscapes_dir, img_name)) | |||
| ids2 = CS.get_id(os.path.join(args.result_dir, img_name)) | |||
| hist_perframe += fast_hist(ids1.flatten(), ids2.flatten(), CS.class_num) | |||
| mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe) | |||
| print(f"mean_pixel_acc: {mean_pixel_acc}, mean_class_acc: {mean_class_acc}, mean_class_iou: {mean_class_iou}") | |||
| with open('./evaluation_results.txt', 'w') as f: | |||
| f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc) | |||
| f.write('Mean class accuracy: %f\n' % mean_class_acc) | |||
| f.write('Mean class IoU: %f\n' % mean_class_iou) | |||
| f.write('************ Per class numbers below ************\n') | |||
| for i, cl in enumerate(CS.classes): | |||
| while len(cl) < 15: | |||
| cl = cl + ' ' | |||
| f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i])) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """export file.""" | |||
| import numpy as np | |||
| from mindspore import context, Tensor | |||
| from mindspore.train.serialization import export | |||
| from src.models import get_generator | |||
| from src.utils import get_args, load_ckpt | |||
| args = get_args("export") | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.platform) | |||
| if __name__ == '__main__': | |||
| G_A = get_generator(args) | |||
| G_B = get_generator(args) | |||
| # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d | |||
| # Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d | |||
| G_A.set_train(True) | |||
| G_B.set_train(True) | |||
| load_ckpt(args, G_A, G_B) | |||
| input_shp = [1, 3, args.image_size, args.image_size] | |||
| input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32)) | |||
| # G_A maps domain A to B and G_B maps B to A (see predict.py), so name the files accordingly | |||
| G_A_file = f"{args.file_name}_AtoB" | |||
| export(G_A, input_array, file_name=G_A_file, file_format=args.file_format) | |||
| G_B_file = f"{args.file_name}_BtoA" | |||
| export(G_B, input_array, file_name=G_B_file, file_format=args.file_format) | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config.""" | |||
| from src.models import get_generator | |||
| def create_network(name, *args, **kwargs): | |||
| if name == "cyclegan": | |||
| G_A = get_generator(*args, **kwargs) | |||
| G_B = get_generator(*args, **kwargs) | |||
| # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d | |||
| # Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d | |||
| G_A.set_train(True) | |||
| G_B.set_train(True) | |||
| return G_A, G_B | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cycle GAN predict.""" | |||
| import os | |||
| from mindspore import Tensor | |||
| from src.models import get_generator | |||
| from src.utils import get_args, load_ckpt, save_image, Reporter | |||
| from src.dataset import create_dataset | |||
| def predict(): | |||
| """Predict function.""" | |||
| args = get_args("predict") | |||
| G_A = get_generator(args) | |||
| G_B = get_generator(args) | |||
| # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d | |||
| # Use real mean and variance rather than moving_mean and moving_variance in BatchNorm2d | |||
| G_A.set_train(True) | |||
| G_B.set_train(True) | |||
| load_ckpt(args, G_A, G_B) | |||
| imgs_out = os.path.join(args.outputs_dir, "predict") | |||
| if not os.path.exists(imgs_out): | |||
| os.makedirs(imgs_out) | |||
| if not os.path.exists(os.path.join(imgs_out, "fake_A")): | |||
| os.makedirs(os.path.join(imgs_out, "fake_A")) | |||
| if not os.path.exists(os.path.join(imgs_out, "fake_B")): | |||
| os.makedirs(os.path.join(imgs_out, "fake_B")) | |||
| args.data_dir = 'testA' | |||
| ds = create_dataset(args) | |||
| reporter = Reporter(args) | |||
| reporter.start_predict("A to B") | |||
| for data in ds.create_dict_iterator(output_numpy=True): | |||
| img_A = Tensor(data["image"]) | |||
| path_A = str(data["image_name"][0], encoding="utf-8") | |||
| fake_B = G_A(img_A) | |||
| save_image(fake_B, os.path.join(imgs_out, "fake_B", path_A)) | |||
| reporter.info('save fake_B at %s', os.path.join(imgs_out, "fake_B", path_A)) | |||
| reporter.end_predict() | |||
| args.data_dir = 'testB' | |||
| ds = create_dataset(args) | |||
| reporter.dataset_size = args.dataset_size | |||
| reporter.start_predict("B to A") | |||
| for data in ds.create_dict_iterator(output_numpy=True): | |||
| img_B = Tensor(data["image"]) | |||
| path_B = str(data["image_name"][0], encoding="utf-8") | |||
| fake_A = G_B(img_B) | |||
| save_image(fake_A, os.path.join(imgs_out, "fake_A", path_B)) | |||
| reporter.info('save fake_A at %s', os.path.join(imgs_out, "fake_A", path_B)) | |||
| reporter.end_predict() | |||
| if __name__ == "__main__": | |||
| predict() | |||
| @@ -0,0 +1,17 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """init file.""" | |||
| from .datasets import UnalignedDataset, ImageFolderDataset, make_dataset | |||
| from .cyclegan_dataset import create_dataset | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cycle GAN dataset.""" | |||
| import os | |||
| import multiprocessing | |||
| import mindspore.dataset as de | |||
| import mindspore.dataset.vision.c_transforms as C | |||
| from .distributed_sampler import DistributedSampler | |||
| from .datasets import UnalignedDataset, ImageFolderDataset | |||
| def create_dataset(args, shuffle=True, max_dataset_size=float("inf")): | |||
| """Create dataset""" | |||
| dataroot = args.dataroot | |||
| phase = args.phase | |||
| batch_size = args.batch_size | |||
| device_num = args.device_num | |||
| rank = args.rank | |||
| cores = multiprocessing.cpu_count() | |||
| # keep at least one worker even when there are fewer cores than devices | |||
| num_parallel_workers = max(1, min(8, cores // device_num)) | |||
| image_size = args.image_size | |||
| mean = [0.5 * 255] * 3 | |||
| std = [0.5 * 255] * 3 | |||
| if phase == "train": | |||
| dataset = UnalignedDataset(dataroot, phase, max_dataset_size=max_dataset_size) | |||
| distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle) | |||
| ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"], | |||
| sampler=distributed_sampler, num_parallel_workers=num_parallel_workers) | |||
| trans = [ | |||
| C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)), | |||
| C.RandomHorizontalFlip(prob=0.5), | |||
| C.Normalize(mean=mean, std=std), | |||
| C.HWC2CHW() | |||
| ] | |||
| ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers) | |||
| ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers) | |||
| ds = ds.batch(batch_size, drop_remainder=True) | |||
| ds = ds.repeat(1) | |||
| else: | |||
| datadir = os.path.join(dataroot, args.data_dir) | |||
| dataset = ImageFolderDataset(datadir, max_dataset_size=max_dataset_size) | |||
| ds = de.GeneratorDataset(dataset, column_names=["image", "image_name"], | |||
| num_parallel_workers=num_parallel_workers) | |||
| trans = [ | |||
| C.Resize((image_size, image_size)), | |||
| C.Normalize(mean=mean, std=std), | |||
| C.HWC2CHW() | |||
| ] | |||
| ds = ds.map(operations=trans, input_columns=["image"], num_parallel_workers=num_parallel_workers) | |||
| ds = ds.batch(1, drop_remainder=True) | |||
| ds = ds.repeat(1) | |||
| args.dataset_size = len(dataset) | |||
| return ds | |||
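The `mean` and `std` values of `0.5 * 255` used in `create_dataset` map 8-bit pixel values from [0, 255] to [-1, 1], and `HWC2CHW` moves the channel axis first. The NumPy sketch below reproduces that arithmetic outside the MindSpore pipeline. Both function names are hypothetical, and the inverse mapping is an assumption about how outputs could be turned back into images, not a copy of `save_image`.

```python
import numpy as np

MEAN = np.array([0.5 * 255] * 3, dtype=np.float32)
STD = np.array([0.5 * 255] * 3, dtype=np.float32)


def normalize_hwc_to_chw(img):
    """Scale an HWC uint8 image to [-1, 1] and move channels first."""
    scaled = (img.astype(np.float32) - MEAN) / STD
    return scaled.transpose(2, 0, 1)


def to_uint8_image(chw):
    """Hypothetical inverse: map a CHW array in [-1, 1] back to an HWC uint8 image."""
    hwc = chw.transpose(1, 2, 0) * STD + MEAN
    return np.clip(np.round(hwc), 0, 255).astype(np.uint8)


img = np.array([[[0, 128, 255]]], dtype=np.uint8)  # a single pixel in HWC layout
x = normalize_hwc_to_chw(img)
print(x.shape, x.min(), x.max())
```

The [-1, 1] range matches the `np.random.uniform(-1.0, 1.0, ...)` dummy input used in `export.py`.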
| @@ -0,0 +1,102 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cycle GAN datasets.""" | |||
| import os | |||
| import random | |||
| import numpy as np | |||
| from PIL import Image | |||
| IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff'] | |||
| def is_image_file(filename): | |||
| """Judge whether it is a picture.""" | |||
| return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS) | |||
| def make_dataset(dir_path, max_dataset_size=float("inf")): | |||
| """Return image list in dir.""" | |||
| images = [] | |||
| assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path | |||
| for root, _, fnames in sorted(os.walk(dir_path)): | |||
| for fname in fnames: | |||
| if is_image_file(fname): | |||
| path = os.path.join(root, fname) | |||
| images.append(path) | |||
| return images[:min(max_dataset_size, len(images))] | |||
| class UnalignedDataset: | |||
| """ | |||
| This dataset class can load unaligned/unpaired datasets. | |||
| Args: | |||
| dataroot (str): Images root directory. | |||
| phase (str): Train or test. It requires two directories in dataroot, like trainA and trainB to | |||
| host training images from domain A '{dataroot}/trainA' and from domain B '{dataroot}/trainB' respectively. | |||
| max_dataset_size (int): Maximum number of return image paths. | |||
| Returns: | |||
| Two domain image path list. | |||
| """ | |||
| def __init__(self, dataroot, phase, max_dataset_size=float("inf")): | |||
| self.dir_A = os.path.join(dataroot, phase + 'A') | |||
| self.dir_B = os.path.join(dataroot, phase + 'B') | |||
| self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size)) # load images from '/path/to/data/trainA' | |||
| self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size)) # load images from '/path/to/data/trainB' | |||
| self.A_size = len(self.A_paths) # get the size of dataset A | |||
| self.B_size = len(self.B_paths) # get the size of dataset B | |||
| def __getitem__(self, index): | |||
| if index % max(self.A_size, self.B_size) == 0: | |||
| random.shuffle(self.A_paths) | |||
| A_path = self.A_paths[index % self.A_size] | |||
| index_B = random.randint(0, self.B_size - 1) | |||
| B_path = self.B_paths[index_B] | |||
| A_img = np.array(Image.open(A_path).convert('RGB')) | |||
| B_img = np.array(Image.open(B_path).convert('RGB')) | |||
| return A_img, B_img | |||
| def __len__(self): | |||
| return max(self.A_size, self.B_size) | |||
| class ImageFolderDataset: | |||
| """ | |||
| This dataset class can load images from image folder. | |||
| Args: | |||
| dataroot (str): Images root directory. | |||
| max_dataset_size (int): Maximum number of return image paths. | |||
| Returns: | |||
| Image path list. | |||
| """ | |||
| def __init__(self, dataroot, max_dataset_size=float("inf")): | |||
| self.dataroot = dataroot | |||
| self.paths = sorted(make_dataset(dataroot, max_dataset_size)) | |||
| self.size = len(self.paths) | |||
| def __getitem__(self, index): | |||
| img_path = self.paths[index % self.size] | |||
| img = np.array(Image.open(img_path).convert('RGB')) | |||
| return img, os.path.split(img_path)[1] | |||
| def __len__(self): | |||
| return self.size | |||
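The pairing rule in `UnalignedDataset.__getitem__` can be illustrated without any image files: the domain A item follows the index (wrapping around), while the domain B item is drawn at random, so the two images in a pair are unrelated. This is a simplified sketch with a hypothetical `sample_pair` helper. It works on path strings only and leaves out the periodic reshuffle of `A_paths`.

```python
import random


def sample_pair(index, a_paths, b_paths, rng=random):
    """Pick the A item by index (modulo its size) and a B item uniformly at random."""
    a_path = a_paths[index % len(a_paths)]
    b_path = b_paths[rng.randint(0, len(b_paths) - 1)]
    return a_path, b_path


a_paths = ["a0.png", "a1.png", "a2.png"]
b_paths = ["b0.png", "b1.png"]
rng = random.Random(0)  # fixed seed so the run is repeatable
for i in range(4):
    print(sample_pair(i, a_paths, b_paths, rng))
```

Because the dataset length is the larger of the two domain sizes, the smaller domain is revisited within an epoch.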
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Dataset distributed sampler.""" | |||
| from __future__ import division | |||
| import math | |||
| import numpy as np | |||
| class DistributedSampler: | |||
| """Distributed sampler.""" | |||
| def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True): | |||
| if num_replicas is None: | |||
| print("***********Setting world_size to 1 since it is not passed in ******************") | |||
| num_replicas = 1 | |||
| if rank is None: | |||
| print("***********Setting rank to 0 since it is not passed in ******************") | |||
| rank = 0 | |||
| self.dataset_size = dataset_size | |||
| self.num_replicas = num_replicas | |||
| self.rank = rank | |||
| self.epoch = 0 | |||
| self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas)) | |||
| self.total_size = self.num_samples * self.num_replicas | |||
| self.shuffle = shuffle | |||
| def __iter__(self): | |||
| # deterministically shuffle based on epoch | |||
| if self.shuffle: | |||
| indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size) | |||
| # permutation of 0..dataset_size-1 as an np.ndarray; convert it to a list of dataset indices | |||
| indices = indices.tolist() | |||
| self.epoch += 1 | |||
| else: | |||
| indices = list(range(self.dataset_size)) | |||
| # add extra samples to make it evenly divisible | |||
| indices += indices[:(self.total_size - len(indices))] | |||
| assert len(indices) == self.total_size | |||
| # subsample | |||
| indices = indices[self.rank:self.total_size:self.num_replicas] | |||
| assert len(indices) == self.num_samples | |||
| return iter(indices) | |||
| def __len__(self): | |||
| return self.num_samples | |||
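To see how the sampler splits indices across devices, the following minimal sketch reproduces the padding and strided subsampling from `__iter__` as a standalone function. The name `shard_indices` is hypothetical, and the shuffle uses the standard library's `random.Random` instead of `np.random.RandomState`, so shuffled orders will differ from those produced by the class above.

```python
import math
import random


def shard_indices(dataset_size, num_replicas=1, rank=0, shuffle=False, epoch=0):
    """Return the dataset indices assigned to `rank` (illustrative helper)."""
    num_samples = int(math.ceil(dataset_size / num_replicas))
    total_size = num_samples * num_replicas
    indices = list(range(dataset_size))
    if shuffle:
        # Seed by epoch so that every replica draws the same permutation.
        random.Random(epoch).shuffle(indices)
    # Pad by wrapping around to the start so the list divides evenly.
    indices += indices[:total_size - len(indices)]
    # Each replica takes every num_replicas-th index, offset by its rank.
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for r in range(4):
        print(r, shard_indices(10, num_replicas=4, rank=r))
```

With 10 samples and 4 replicas, each replica receives 3 indices, and the two padded slots reuse indices 0 and 1. Because the padding duplicates samples, a few items are seen twice per epoch whenever the dataset size is not a multiple of the replica count.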
| @@ -0,0 +1,18 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """init file.""" | |||
| from .cycle_gan import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD | |||
| from .losses import DiscriminatorLoss, GeneratorLoss, GANLoss | |||
| from .networks import init_weights | |||
| @@ -0,0 +1,252 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cycle GAN network.""" | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.communication.management import get_group_size | |||
| import mindspore.ops as ops | |||
| from .resnet import ResNetGenerator | |||
| from .networks import ConvNormReLU, init_weights | |||
| from .unet import UnetGenerator | |||
| def get_generator(args, teacher_net=False): | |||
| """Return generator by args.""" | |||
| if teacher_net: | |||
| if args.model == "resnet": | |||
| net = ResNetGenerator(in_planes=args.in_planes, ngf=args.t_ngf, n_layers=args.t_gl_num, | |||
| alpha=args.t_slope, norm_mode=args.t_norm_mode, dropout=False, | |||
| pad_mode=args.pad_mode) | |||
| init_weights(net, args.init_type, args.init_gain) | |||
| elif args.model == "unet": | |||
| net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.t_ngf, | |||
| n_layers=args.t_gl_num, alpha=args.t_slope, norm_mode=args.t_norm_mode, | |||
| dropout=False) | |||
| init_weights(net, args.init_type, args.init_gain) | |||
| else: | |||
| raise NotImplementedError(f'Model {args.model} not recognized.') | |||
| else: | |||
| if args.model == "resnet": | |||
| net = ResNetGenerator(in_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num, | |||
| alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout, | |||
| pad_mode=args.pad_mode) | |||
| init_weights(net, args.init_type, args.init_gain) | |||
| elif args.model == "unet": | |||
| net = UnetGenerator(in_planes=args.in_planes, out_planes=args.in_planes, ngf=args.ngf, n_layers=args.gl_num, | |||
| alpha=args.slope, norm_mode=args.norm_mode, dropout=args.need_dropout) | |||
| init_weights(net, args.init_type, args.init_gain) | |||
| else: | |||
| raise NotImplementedError(f'Model {args.model} not recognized.') | |||
| return net | |||
| def get_discriminator(args, teacher_net=False): | |||
| """Return discriminator by args.""" | |||
| net = Discriminator(in_planes=args.in_planes, ndf=args.ndf, n_layers=args.dl_num, | |||
| alpha=args.slope, norm_mode=args.norm_mode) | |||
| init_weights(net, args.init_type, args.init_gain) | |||
| return net | |||
| class Discriminator(nn.Cell): | |||
| """ | |||
| Discriminator of GAN. | |||
| Args: | |||
| in_planes (int): Input channel. | |||
| ndf (int): Number of filters in the first conv layer. | |||
| n_layers (int): The number of ConvNormReLU blocks. | |||
| alpha (float): LeakyRelu slope. Default: 0.2. | |||
| norm_mode (str): Specifies norm method. The optional values are "batch", "instance". | |||
| Returns: | |||
| Tensor, output tensor. | |||
| Examples: | |||
| >>> Discriminator(3, 64, 3) | |||
| """ | |||
| def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'): | |||
| super(Discriminator, self).__init__() | |||
| kernel_size = 4 | |||
| layers = [ | |||
| nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1), | |||
| nn.LeakyReLU(alpha) | |||
| ] | |||
| nf_mult = ndf | |||
| for i in range(1, n_layers): | |||
| nf_mult_prev = nf_mult | |||
| nf_mult = min(2 ** i, 8) * ndf | |||
| layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1)) | |||
| nf_mult_prev = nf_mult | |||
| nf_mult = min(2 ** n_layers, 8) * ndf | |||
| layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1)) | |||
| layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1)) | |||
| self.features = nn.SequentialCell(layers) | |||
| def construct(self, x): | |||
| output = self.features(x) | |||
| return output | |||
| class Generator(nn.Cell): | |||
| """ | |||
| Generator of CycleGAN, return fake_A, fake_B, rec_A, rec_B, identity_A and identity_B. | |||
| Args: | |||
| G_A (Cell): The generator network of domain A to domain B. | |||
| G_B (Cell): The generator network of domain B to domain A. | |||
| use_identity (bool): Use identity loss or not. Default: True. | |||
| Returns: | |||
| Tensors, fake_A, fake_B, rec_A, rec_B, identity_A and identity_B. | |||
| Examples: | |||
| >>> Generator(G_A, G_B) | |||
| """ | |||
| def __init__(self, G_A, G_B, use_identity=True): | |||
| super(Generator, self).__init__() | |||
| self.G_A = G_A | |||
| self.G_B = G_B | |||
| self.ones = ops.OnesLike() | |||
| self.use_identity = use_identity | |||
| def construct(self, img_A, img_B): | |||
| """If use_identity, identity loss will be used.""" | |||
| fake_A = self.G_B(img_B) | |||
| fake_B = self.G_A(img_A) | |||
| rec_A = self.G_B(fake_B) | |||
| rec_B = self.G_A(fake_A) | |||
| if self.use_identity: | |||
| identity_A = self.G_B(img_A) | |||
| identity_B = self.G_A(img_B) | |||
| else: | |||
| identity_A = self.ones(img_A) | |||
| identity_B = self.ones(img_B) | |||
| return fake_A, fake_B, rec_A, rec_B, identity_A, identity_B | |||
| class WithLossCell(nn.Cell): | |||
| """ | |||
| Wrap the network with loss function to return generator loss. | |||
| Args: | |||
| network (Cell): The target network to wrap. | |||
| """ | |||
| def __init__(self, network): | |||
| super(WithLossCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| def construct(self, img_A, img_B): | |||
| _, _, lg, _, _, _, _, _, _ = self.network(img_A, img_B) | |||
| return lg | |||
| class TrainOneStepG(nn.Cell): | |||
| """ | |||
| Encapsulation class of Cycle GAN generator network training. | |||
| Append an optimizer to the training network, after which the construct | |||
| function can be called to create the backward graph. | |||
| Args: | |||
| G (Cell): Generator with loss Cell. Note that loss function should have been added. | |||
| generator (Cell): Generator of CycleGAN. | |||
| optimizer (Optimizer): Optimizer for updating the weights. | |||
| sens (Number): Value used to fill the sensitivity tensor that scales the backpropagated gradients. Default: 1.0. | |||
| """ | |||
| def __init__(self, G, generator, optimizer, sens=1.0): | |||
| super(TrainOneStepG, self).__init__(auto_prefix=False) | |||
| self.optimizer = optimizer | |||
| self.G = G | |||
| self.G.set_grad() | |||
| self.G.set_train() | |||
| self.G.D_A.set_grad(False) | |||
| self.G.D_A.set_train(False) | |||
| self.G.D_B.set_grad(False) | |||
| self.G.D_B.set_train(False) | |||
| self.grad = ops.GradOperation(get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| self.weights = ms.ParameterTuple(generator.trainable_params()) | |||
| self.net = WithLossCell(G) | |||
| self.reducer_flag = False | |||
| self.grad_reducer = None | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("gradients_mean") | |||
| if auto_parallel_context().get_device_num_is_set(): | |||
| degree = context.get_auto_parallel_context("device_num") | |||
| else: | |||
| degree = get_group_size() | |||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, img_A, img_B): | |||
| weights = self.weights | |||
| fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib = self.G(img_A, img_B) | |||
| sens = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens) | |||
| grads_g = self.grad(self.net, weights)(img_A, img_B, sens) | |||
| if self.reducer_flag: | |||
| # apply grad reducer on grads | |||
| grads_g = self.grad_reducer(grads_g) | |||
| return fake_A, fake_B, ops.depend(lg, self.optimizer(grads_g)), lga, lgb, lca, lcb, lia, lib | |||
| class TrainOneStepD(nn.Cell): | |||
| """ | |||
| Encapsulation class of Cycle GAN discriminator network training. | |||
| Append an optimizer to the training network, after which the construct | |||
| function can be called to create the backward graph. | |||
| Args: | |||
| D (Cell): Discriminator with loss Cell. Note that loss function should have been added. | |||
| optimizer (Optimizer): Optimizer for updating the weights. | |||
| sens (Number): Value used to fill the sensitivity tensor that scales the backpropagated gradients. Default: 1.0. | |||
| """ | |||
| def __init__(self, D, optimizer, sens=1.0): | |||
| super(TrainOneStepD, self).__init__(auto_prefix=False) | |||
| self.optimizer = optimizer | |||
| self.D = D | |||
| self.D.set_grad() | |||
| self.D.set_train() | |||
| self.grad = ops.GradOperation(get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| self.weights = ms.ParameterTuple(D.trainable_params()) | |||
| self.reducer_flag = False | |||
| self.grad_reducer = None | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| self.reducer_flag = True | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("gradients_mean") | |||
| if auto_parallel_context().get_device_num_is_set(): | |||
| degree = context.get_auto_parallel_context("device_num") | |||
| else: | |||
| degree = get_group_size() | |||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, img_A, img_B, fake_A, fake_B): | |||
| weights = self.weights | |||
| ld = self.D(img_A, img_B, fake_A, fake_B) | |||
| sens_d = ops.Fill()(ops.DType()(ld), ops.Shape()(ld), self.sens) | |||
| grads_d = self.grad(self.D, weights)(img_A, img_B, fake_A, fake_B, sens_d) | |||
| if self.reducer_flag: | |||
| # apply grad reducer on grads | |||
| grads_d = self.grad_reducer(grads_d) | |||
| return ops.depend(ld, self.optimizer(grads_d)) | |||
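The `Discriminator` in this file stacks stride-2 convolutions followed by two stride-1 convolutions, so its output is a grid of per-patch scores rather than a single value. The sketch below traces channel counts and spatial sizes through that stack without MindSpore. The helper names `conv_out` and `discriminator_plan` are hypothetical, the 256x256 input size is an assumption chosen for illustration, and the size formula is the usual one for a convolution with explicit padding.

```python
def conv_out(size, kernel=4, stride=1, padding=1):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_plan(in_planes=3, ndf=64, n_layers=3, size=256):
    """List (in_channels, out_channels, stride, out_size) for each conv layer."""
    plan = []
    # First layer: plain convolution plus LeakyReLU, stride 2.
    size = conv_out(size, stride=2)
    plan.append((in_planes, ndf, 2, size))
    nf_mult = ndf
    # Middle layers: channels double (capped at 8 * ndf), stride 2.
    for i in range(1, n_layers):
        nf_prev, nf_mult = nf_mult, min(2 ** i, 8) * ndf
        size = conv_out(size, stride=2)
        plan.append((nf_prev, nf_mult, 2, size))
    # Last ConvNormReLU block uses stride 1.
    nf_prev, nf_mult = nf_mult, min(2 ** n_layers, 8) * ndf
    size = conv_out(size, stride=1)
    plan.append((nf_prev, nf_mult, 1, size))
    # Final convolution maps to a single-channel score map.
    size = conv_out(size, stride=1)
    plan.append((nf_mult, 1, 1, size))
    return plan


if __name__ == "__main__":
    for layer in discriminator_plan():
        print(layer)
```

Under these assumptions the default configuration ends in a 30x30 single-channel map, which is why the loss broadcasts a scalar real/fake target to the shape of the prediction.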
| @@ -0,0 +1,175 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cycle GAN losses""" | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| from mindspore import Tensor | |||
| from .cycle_gan import get_generator | |||
| from ..utils import load_teacher_ckpt | |||
| class BCEWithLogits(nn.Cell): | |||
| """ | |||
| BCEWithLogits creates a criterion to measure the Binary Cross Entropy between the true labels and | |||
| predicted labels with sigmoid logits. | |||
| Args: | |||
| reduction (str): Specifies the reduction to be applied to the output. | |||
| Its value must be one of 'none', 'mean', 'sum'. Default: 'mean'. | |||
| Outputs: | |||
| Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`. | |||
| Otherwise, the output is a scalar. | |||
| """ | |||
| def __init__(self, reduction='mean'): | |||
| super(BCEWithLogits, self).__init__() | |||
| if reduction is None: | |||
| reduction = 'none' | |||
| if reduction not in ('mean', 'sum', 'none'): | |||
| raise ValueError(f"reduction method for {reduction.lower()} is not supported") | |||
| self.loss = ops.SigmoidCrossEntropyWithLogits() | |||
| self.reduce = False | |||
| if reduction == 'sum': | |||
| self.reduce_mode = ops.ReduceSum() | |||
| self.reduce = True | |||
| elif reduction == 'mean': | |||
| self.reduce_mode = ops.ReduceMean() | |||
| self.reduce = True | |||
| def construct(self, predict, target): | |||
| loss = self.loss(predict, target) | |||
| if self.reduce: | |||
| loss = self.reduce_mode(loss) | |||
| return loss | |||
| class GANLoss(nn.Cell): | |||
| """ | |||
| Cycle GAN loss factory. | |||
| Args: | |||
| mode (str): The type of GAN objective. It currently supports 'vanilla', 'lsgan'. Default: 'lsgan'. | |||
| reduction (str): Specifies the reduction to be applied to the output. | |||
| Its value must be one of 'none', 'mean', 'sum'. Default: 'mean'. | |||
| Outputs: | |||
| Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `inputs`. | |||
| Otherwise, the output is a scalar. | |||
| """ | |||
| def __init__(self, mode="lsgan", reduction='mean'): | |||
| super(GANLoss, self).__init__() | |||
| self.loss = None | |||
| self.ones = ops.OnesLike() | |||
| if mode == "lsgan": | |||
| self.loss = nn.MSELoss(reduction) | |||
| elif mode == "vanilla": | |||
| self.loss = BCEWithLogits(reduction) | |||
| else: | |||
| raise NotImplementedError(f'GANLoss {mode} not recognized, we support lsgan and vanilla.') | |||
| def construct(self, predict, target): | |||
| target = ops.cast(target, ops.dtype(predict)) | |||
| target = self.ones(predict) * target | |||
| loss = self.loss(predict, target) | |||
| return loss | |||
| class GeneratorLoss(nn.Cell): | |||
| """ | |||
| Cycle GAN generator loss. | |||
| Args: | |||
| args (class): Option class. | |||
| generator (Cell): Generator of CycleGAN. | |||
| D_A (Cell): The discriminator network that scores images of domain A. | |||
| D_B (Cell): The discriminator network that scores images of domain B. | |||
| Outputs: | |||
| Tuple Tensor, the losses of generator. | |||
| """ | |||
| def __init__(self, args, generator, D_A, D_B): | |||
| super(GeneratorLoss, self).__init__() | |||
| self.lambda_A = args.lambda_A | |||
| self.lambda_B = args.lambda_B | |||
| self.lambda_idt = args.lambda_idt | |||
| self.use_identity = args.lambda_idt > 0 | |||
| self.dis_loss = GANLoss(args.gan_mode) | |||
| self.rec_loss = nn.L1Loss("mean") | |||
| self.generator = generator | |||
| self.D_A = D_A | |||
| self.D_B = D_B | |||
| self.true = Tensor(True, mstype.bool_) | |||
| self.kd = args.kd | |||
| if self.kd: | |||
| self.GT_A = get_generator(args, True) | |||
| load_teacher_ckpt(self.GT_A, args.GT_A_ckpt, "GT_A", "G_A") | |||
| self.GT_B = get_generator(args, True) | |||
| load_teacher_ckpt(self.GT_B, args.GT_B_ckpt, "GT_B", "G_B") | |||
| self.GT_A.set_train(True) | |||
| self.GT_B.set_train(True) | |||
| def construct(self, img_A, img_B): | |||
| """If use_identity, identity loss will be used.""" | |||
| fake_A, fake_B, rec_A, rec_B, identity_A, identity_B = self.generator(img_A, img_B) | |||
| loss_G_A = self.dis_loss(self.D_B(fake_B), self.true) | |||
| loss_G_B = self.dis_loss(self.D_A(fake_A), self.true) | |||
| loss_C_A = self.rec_loss(rec_A, img_A) * self.lambda_A | |||
| loss_C_B = self.rec_loss(rec_B, img_B) * self.lambda_B | |||
| if self.use_identity: | |||
| loss_idt_A = self.rec_loss(identity_A, img_A) * self.lambda_A * self.lambda_idt | |||
| loss_idt_B = self.rec_loss(identity_B, img_B) * self.lambda_B * self.lambda_idt | |||
| else: | |||
| loss_idt_A = 0 | |||
| loss_idt_B = 0 | |||
| loss_G = loss_G_A + loss_G_B + loss_C_A + loss_C_B + loss_idt_A + loss_idt_B | |||
| if self.kd: | |||
| teacher_A = self.GT_B(img_B) | |||
| teacher_B = self.GT_A(img_A) | |||
| kd_loss_A = self.rec_loss(teacher_A, fake_A) * self.lambda_A * 5 | |||
| kd_loss_B = self.rec_loss(teacher_B, fake_B) * self.lambda_B * 5 | |||
| loss_G += kd_loss_A + kd_loss_B | |||
| return (fake_A, fake_B, loss_G, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B) | |||
| class DiscriminatorLoss(nn.Cell): | |||
| """ | |||
| Cycle GAN discriminator loss. | |||
| Args: | |||
| args (class): option class. | |||
| D_A (Cell): The discriminator network that scores images of domain A. | |||
| D_B (Cell): The discriminator network that scores images of domain B. | |||
| Outputs: | |||
| Tuple Tensor, the loss of discriminator. | |||
| """ | |||
| def __init__(self, args, D_A, D_B): | |||
| super(DiscriminatorLoss, self).__init__() | |||
| self.D_A = D_A | |||
| self.D_B = D_B | |||
| self.false = Tensor(False, mstype.bool_) | |||
| self.true = Tensor(True, mstype.bool_) | |||
| self.dis_loss = GANLoss(args.gan_mode) | |||
| self.rec_loss = nn.L1Loss("mean") | |||
| def construct(self, img_A, img_B, fake_A, fake_B): | |||
| D_fake_A = self.D_A(fake_A) | |||
| D_img_A = self.D_A(img_A) | |||
| D_fake_B = self.D_B(fake_B) | |||
| D_img_B = self.D_B(img_B) | |||
| loss_D_A = self.dis_loss(D_fake_A, self.false) + self.dis_loss(D_img_A, self.true) | |||
| loss_D_B = self.dis_loss(D_fake_B, self.false) + self.dis_loss(D_img_B, self.true) | |||
| loss_D = (loss_D_A + loss_D_B) * 0.5 | |||
| return loss_D | |||
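`GANLoss` expands a scalar real/fake flag to the shape of the discriminator output and then applies either mean squared error (`lsgan`) or sigmoid cross entropy (`vanilla`). The following sketch shows the same arithmetic on plain Python lists with mean reduction. The function names are hypothetical, and `discriminator_loss` covers a single domain only, whereas `DiscriminatorLoss` above averages the two domains.

```python
import math


def lsgan_loss(logits, is_real):
    """Mean squared error against a constant 1.0 / 0.0 target."""
    target = 1.0 if is_real else 0.0
    return sum((p - target) ** 2 for p in logits) / len(logits)


def vanilla_loss(logits, is_real):
    """Mean sigmoid cross entropy against a constant 1.0 / 0.0 target."""
    target = 1.0 if is_real else 0.0
    total = 0.0
    for x in logits:
        # Numerically stable form: max(x, 0) - x * t + log(1 + exp(-|x|)).
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def discriminator_loss(d_real, d_fake, loss_fn=lsgan_loss):
    """Real patches are pushed toward 1 and generated patches toward 0."""
    return loss_fn(d_real, True) + loss_fn(d_fake, False)


if __name__ == "__main__":
    print(lsgan_loss([0.5, 0.5], False))
    print(vanilla_loss([0.0], True))
    print(discriminator_loss([1.0], [0.0]))
```

The generator side reuses the same loss with the target flipped: it scores its own fakes against the "real" target, which is what `loss_G_A` and `loss_G_B` do.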
| @@ -0,0 +1,156 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cycle GAN network.""" | |||
| import mindspore.nn as nn | |||
| from mindspore.common import initializer as init | |||
| def init_weights(net, init_type='normal', init_gain=0.02): | |||
| """ | |||
| Initialize network weights. | |||
| Parameters: | |||
| net (Cell): Network to be initialized | |||
| init_type (str): The name of an initialization method: normal | xavier. | |||
| init_gain (float): Gain factor for normal and xavier. | |||
| """ | |||
| # cells_and_names() yields (name, cell) pairs, so unpack before the type checks | |||
| for _, cell in net.cells_and_names(): | |||
| if isinstance(cell, nn.Conv2d): | |||
| if init_type == 'normal': | |||
| cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape)) | |||
| elif init_type == 'xavier': | |||
| cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape)) | |||
| else: | |||
| raise NotImplementedError('initialization method [%s] is not implemented' % init_type) | |||
| elif isinstance(cell, nn.BatchNorm2d): | |||
| cell.gamma.set_data(init.initializer('ones', cell.gamma.shape)) | |||
| cell.beta.set_data(init.initializer('zeros', cell.beta.shape)) | |||
| class ConvNormReLU(nn.Cell): | |||
| """ | |||
| Convolution fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition. | |||
| Args: | |||
| in_planes (int): Input channel. | |||
| out_planes (int): Output channel. | |||
| kernel_size (int): Input kernel size. Default: 4. | |||
| stride (int): Stride size for the first convolutional layer. Default: 2. | |||
| alpha (float): Slope of LeakyReLU. Default: 0.2. | |||
| norm_mode (str): Specifies norm method. The optional values are "batch", "instance". | |||
| pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC". | |||
| Default: "CONSTANT". | |||
| use_relu (bool): Use relu or not. Default: True. | |||
| padding (int): Pad size, if it is None, it will calculate by kernel_size. Default: None. | |||
| Returns: | |||
| Tensor, output tensor. | |||
| """ | |||
| def __init__(self, | |||
| in_planes, | |||
| out_planes, | |||
| kernel_size=4, | |||
| stride=2, | |||
| alpha=0.2, | |||
| norm_mode='batch', | |||
| pad_mode='CONSTANT', | |||
| use_relu=True, | |||
| padding=None): | |||
| super(ConvNormReLU, self).__init__() | |||
| norm = nn.BatchNorm2d(out_planes) | |||
| if norm_mode == 'instance': | |||
| # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d | |||
| norm = nn.BatchNorm2d(out_planes, affine=False) | |||
| has_bias = (norm_mode == 'instance') | |||
| if padding is None: | |||
| padding = (kernel_size - 1) // 2 | |||
| if pad_mode == 'CONSTANT': | |||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', | |||
| has_bias=has_bias, padding=padding) | |||
| layers = [conv, norm] | |||
| else: | |||
| paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding)) | |||
| pad = nn.Pad(paddings=paddings, mode=pad_mode) | |||
| conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias) | |||
| layers = [pad, conv, norm] | |||
| if use_relu: | |||
| relu = nn.ReLU() | |||
| if alpha > 0: | |||
| relu = nn.LeakyReLU(alpha) | |||
| layers.append(relu) | |||
| self.features = nn.SequentialCell(layers) | |||
| def construct(self, x): | |||
| output = self.features(x) | |||
| return output | |||
| class ConvTransposeNormReLU(nn.Cell): | |||
| """ | |||
| ConvTranspose2d fused with BatchNorm/InstanceNorm and ReLU/LeakyReLU block definition. | |||
| Args: | |||
| in_planes (int): Input channel. | |||
| out_planes (int): Output channel. | |||
| kernel_size (int): Input kernel size. Default: 4. | |||
| stride (int): Stride size for the first convolutional layer. Default: 2. | |||
| alpha (float): Slope of LeakyReLU. Default: 0.2. | |||
| norm_mode (str): Specifies norm method. The optional values are "batch", "instance". | |||
| pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC". | |||
| Default: "CONSTANT". | |||
| use_relu (bool): use relu or not. Default: True. | |||
| padding (int): pad size, if it is None, it will calculate by kernel_size. Default: None. | |||
| Returns: | |||
| Tensor, output tensor. | |||
| """ | |||
| def __init__(self, | |||
| in_planes, | |||
| out_planes, | |||
| kernel_size=4, | |||
| stride=2, | |||
| alpha=0.2, | |||
| norm_mode='batch', | |||
| pad_mode='CONSTANT', | |||
| use_relu=True, | |||
| padding=None): | |||
| super(ConvTransposeNormReLU, self).__init__() | |||
| conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride=stride, pad_mode='same') | |||
| norm = nn.BatchNorm2d(out_planes) | |||
| if norm_mode == 'instance': | |||
| # Use BatchNorm2d with batchsize=1, affine=False, training=True instead of InstanceNorm2d | |||
| norm = nn.BatchNorm2d(out_planes, affine=False) | |||
| has_bias = (norm_mode == 'instance') | |||
| if padding is None: | |||
| padding = (kernel_size - 1) // 2 | |||
| if pad_mode == 'CONSTANT': | |||
| conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride, pad_mode='same', has_bias=has_bias) | |||
| layers = [conv, norm] | |||
| else: | |||
| paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding)) | |||
| pad = nn.Pad(paddings=paddings, mode=pad_mode) | |||
| conv = nn.Conv2dTranspose(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias) | |||
| layers = [pad, conv, norm] | |||
| if use_relu: | |||
| relu = nn.ReLU() | |||
| if alpha > 0: | |||
| relu = nn.LeakyReLU(alpha) | |||
| layers.append(relu) | |||
| self.features = nn.SequentialCell(layers) | |||
| def construct(self, x): | |||
| output = self.features(x) | |||
| return output | |||
| @@ -0,0 +1,94 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ResNet Generator.""" | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| from .networks import ConvNormReLU, ConvTransposeNormReLU | |||
| class ResidualBlock(nn.Cell): | |||
| """ | |||
| ResNet residual block definition. | |||
| Args: | |||
| dim (int): Input and output channel. | |||
| norm_mode (str): Specifies norm method. The optional values are "batch", "instance". | |||
| dropout (bool): Use dropout or not. Default: False. | |||
| pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC". | |||
| Default: "CONSTANT". | |||
| Returns: | |||
| Tensor, output tensor. | |||
| """ | |||
| def __init__(self, dim, norm_mode='batch', dropout=False, pad_mode="CONSTANT"): | |||
| super(ResidualBlock, self).__init__() | |||
| self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode) | |||
| self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False) | |||
| self.dropout = dropout | |||
| if dropout: | |||
| self.dropout = nn.Dropout(0.5) | |||
| def construct(self, x): | |||
| out = self.conv1(x) | |||
| if self.dropout: | |||
| out = self.dropout(out) | |||
| out = self.conv2(out) | |||
| return x + out | |||
| class ResNetGenerator(nn.Cell): | |||
| """ | |||
| ResNet Generator of GAN. | |||
| Args: | |||
| in_planes (int): Input channel. | |||
| ngf (int): Number of filters in the first conv layer. | |||
| n_layers (int): The number of residual blocks. | |||
| alpha (float): LeakyRelu slope. Default: 0.2. | |||
| norm_mode (str): Specifies norm method. The optional values are "batch", "instance". | |||
| dropout (bool): Use dropout or not. Default: False. | |||
| pad_mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC". | |||
| Default: "CONSTANT". | |||
| Returns: | |||
| Tensor, output tensor. | |||
| """ | |||
| def __init__(self, in_planes=3, ngf=64, n_layers=9, alpha=0.2, norm_mode='batch', dropout=False, | |||
| pad_mode="CONSTANT"): | |||
| super(ResNetGenerator, self).__init__() | |||
| self.conv_in = ConvNormReLU(in_planes, ngf, 7, 1, alpha, norm_mode, pad_mode=pad_mode) | |||
| self.down_1 = ConvNormReLU(ngf, ngf * 2, 3, 2, alpha, norm_mode) | |||
| self.down_2 = ConvNormReLU(ngf * 2, ngf * 4, 3, 2, alpha, norm_mode) | |||
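| # build a separate ResidualBlock per layer; list multiplication would repeat one shared instance | |||
| layers = [ResidualBlock(ngf * 4, norm_mode, dropout=dropout, pad_mode=pad_mode) for _ in range(n_layers)] | |||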
| self.residuals = nn.SequentialCell(layers) | |||
| self.up_2 = ConvTransposeNormReLU(ngf * 4, ngf * 2, 3, 2, alpha, norm_mode) | |||
| self.up_1 = ConvTransposeNormReLU(ngf * 2, ngf, 3, 2, alpha, norm_mode) | |||
| if pad_mode == "CONSTANT": | |||
| self.conv_out = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad', padding=3) | |||
| else: | |||
| pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode) | |||
| conv = nn.Conv2d(ngf, 3, kernel_size=7, stride=1, pad_mode='pad') | |||
| self.conv_out = nn.SequentialCell([pad, conv]) | |||
| self.activate = ops.Tanh() | |||
| def construct(self, x): | |||
| x = self.conv_in(x) | |||
| x = self.down_1(x) | |||
| x = self.down_2(x) | |||
| x = self.residuals(x) | |||
| x = self.up_2(x) | |||
| x = self.up_1(x) | |||
| output = self.conv_out(x) | |||
| return self.activate(output) | |||
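The residual stack in `ResNetGenerator` is assembled from a Python list of `ResidualBlock` cells, and how that list is built determines whether the blocks are distinct objects. The sketch below illustrates the difference with a stand-in `Block` class (hypothetical, no MindSpore involved): list multiplication repeats one reference, while a comprehension constructs a new object per entry.

```python
class Block:
    """Stand-in for a layer that owns one trainable weight."""

    def __init__(self):
        self.weight = [0.0]


repeated = [Block()] * 3                   # three references to ONE object
independent = [Block() for _ in range(3)]  # three separate objects

# Update the "weight" of the first entry in each list.
repeated[0].weight[0] = 1.0
independent[0].weight[0] = 1.0

shared_ids = len({id(b) for b in repeated})
separate_ids = len({id(b) for b in independent})

if __name__ == "__main__":
    print(shared_ids, [b.weight[0] for b in repeated])
    print(separate_ids, [b.weight[0] for b in independent])
```

In the repeated list an update to one entry is visible through every entry, which for a network would mean that all residual layers share a single set of parameters.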
| @@ -0,0 +1,124 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """UNet Generator.""" | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| class UnetGenerator(nn.Cell): | |||
| """ | |||
| Unet-based generator. | |||
| Args: | |||
| in_planes (int): the number of channels in input images. | |||
| out_planes (int): the number of channels in output images. | |||
| ngf (int): the number of filters in the last conv layer. | |||
| n_layers (int): the number of downsamplings in UNet. | |||
| alpha (float): LeakyRelu slope. Default: 0.2. | |||
| norm_mode (str): Specifies norm method. The optional values are "batch", "instance". | |||
| dropout (bool): Use dropout or not. Default: False. | |||
| Returns: | |||
| Tensor, output tensor. | |||
| """ | |||
| def __init__(self, in_planes, out_planes, ngf=64, n_layers=7, alpha=0.2, norm_mode='batch', dropout=False): | |||
| super(UnetGenerator, self).__init__() | |||
| # construct unet structure | |||
| unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None, | |||
| norm_mode=norm_mode, innermost=True) | |||
| for _ in range(n_layers - 5): | |||
| unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block, | |||
| norm_mode=norm_mode, dropout=dropout) | |||
| # gradually reduce the number of filters from ngf * 8 to ngf | |||
| unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block, | |||
| norm_mode=norm_mode) | |||
| unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block, | |||
| norm_mode=norm_mode) | |||
| unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block, norm_mode=norm_mode) | |||
| self.model = UnetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block, | |||
| outermost=True, norm_mode=norm_mode) | |||
| def construct(self, x): | |||
| return self.model(x) | |||
| class UnetSkipConnectionBlock(nn.Cell): | |||
| """Unet submodule with skip connection. | |||
| Args: | |||
| outer_nc (int): The number of filters in the outer conv layer | |||
| inner_nc (int): The number of filters in the inner conv layer | |||
| in_planes (int): The number of channels in input images/features | |||
| dropout (bool): Use dropout or not. Default: False. | |||
| submodule (Cell): Previously defined submodules | |||
| outermost (bool): If this module is the outermost module | |||
| innermost (bool): If this module is the innermost module | |||
| alpha (float): LeakyRelu slope. Default: 0.2. | |||
| norm_mode (str): Specifies norm method. The optional values are "batch", "instance". | |||
| Returns: | |||
| Tensor, output tensor. | |||
| """ | |||
| def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False, | |||
| submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'): | |||
| super(UnetSkipConnectionBlock, self).__init__() | |||
| downnorm = nn.BatchNorm2d(inner_nc) | |||
| upnorm = nn.BatchNorm2d(outer_nc) | |||
| use_bias = False | |||
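| # 'instance' mode reuses BatchNorm2d without affine parameters; get_args forces batch_size to 1 in this mode. | |||
| if norm_mode == 'instance': | |||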
| downnorm = nn.BatchNorm2d(inner_nc, affine=False) | |||
| upnorm = nn.BatchNorm2d(outer_nc, affine=False) | |||
| use_bias = True | |||
| if in_planes is None: | |||
| in_planes = outer_nc | |||
| downconv = nn.Conv2d(in_planes, inner_nc, kernel_size=4, | |||
| stride=2, padding=1, has_bias=use_bias, pad_mode='pad') | |||
| downrelu = nn.LeakyReLU(alpha) | |||
| uprelu = nn.ReLU() | |||
| if outermost: | |||
| upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, | |||
| kernel_size=4, stride=2, | |||
| padding=1, pad_mode='pad') | |||
| down = [downconv] | |||
| up = [uprelu, upconv, nn.Tanh()] | |||
| model = down + [submodule] + up | |||
| elif innermost: | |||
| upconv = nn.Conv2dTranspose(inner_nc, outer_nc, | |||
| kernel_size=4, stride=2, | |||
| padding=1, has_bias=use_bias, pad_mode='pad') | |||
| down = [downrelu, downconv] | |||
| up = [uprelu, upconv, upnorm] | |||
| model = down + up | |||
| else: | |||
| upconv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, | |||
| kernel_size=4, stride=2, | |||
| padding=1, has_bias=use_bias, pad_mode='pad') | |||
| down = [downrelu, downconv, downnorm] | |||
| up = [uprelu, upconv, upnorm] | |||
| model = down + [submodule] + up | |||
| if dropout: | |||
| model.append(nn.Dropout(0.5)) | |||
| self.model = nn.SequentialCell(model) | |||
| self.skip_connections = not outermost | |||
| self.concat = ops.Concat(axis=1) | |||
| def construct(self, x): | |||
| out = self.model(x) | |||
| if self.skip_connections: | |||
| out = self.concat((out, x)) | |||
| return out | |||
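The channel and resolution bookkeeping of `UnetGenerator` can be traced without MindSpore. The sketch below is a reviewer's illustration, not part of the patch: `unet_down_trace` is a hypothetical helper that assumes `n_layers >= 5` (as the constructor above does) and reuses the kernel size 4, stride 2, padding 1 settings of `downconv`.

```python
def unet_down_trace(image_size=256, ngf=64, n_layers=7):
    """Return (channels, spatial_size) after each downsampling conv, outermost first."""
    # Widths follow the constructor: ngf, 2*ngf, 4*ngf, then 8*ngf for every deeper block.
    widths = [ngf, ngf * 2, ngf * 4] + [ngf * 8] * (n_layers - 3)
    trace = []
    size = image_size
    for width in widths:
        # Convolution output size for kernel 4, stride 2, padding 1.
        size = (size + 2 * 1 - 4) // 2 + 1
        trace.append((width, size))
    return trace


for channels, size in unet_down_trace():
    print(f"{channels:4d} channels at {size}x{size}")
```

With the default `n_layers=7`, a 256x256 input reaches a 2x2 bottleneck, so the input size should be divisible by `2 ** n_layers` for the skip connections to line up.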
| @@ -0,0 +1,19 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """init file.""" | |||
| from .args import get_args | |||
| from .reporter import Reporter | |||
| from .tools import get_lr, load_teacher_ckpt, ImagePool, load_ckpt, save_image | |||
| from .cityscapes_utils import CityScapes, fast_hist, get_scores | |||
| @@ -0,0 +1,145 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """get args.""" | |||
| import argparse | |||
| import ast | |||
| from mindspore.context import ParallelMode | |||
| from mindspore import context | |||
| from mindspore.communication.management import init, get_rank | |||
| def get_args(phase): | |||
| """Define the common options that are used in both training and test.""" | |||
| parser = argparse.ArgumentParser(description='Cycle GAN.') | |||
| # basic parameters | |||
| parser.add_argument('--model', type=str, default="resnet", choices=("resnet", "unet"), \ | |||
| help='generator model, should be in [resnet, unet].') | |||
| parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ | |||
| help='run platform, one of Ascend, GPU or CPU, default is Ascend.') | |||
| parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.") | |||
| parser.add_argument("--lr", type=float, default=0.0002, help="learning rate, default is 0.0002.") | |||
| parser.add_argument('--pool_size', type=int, default=50, \ | |||
| help='the size of image buffer that stores previously generated images, default is 50.') | |||
| parser.add_argument('--lr_policy', type=str, default='linear', choices=("linear", "constant"), \ | |||
| help='learning rate policy, default is linear') | |||
| parser.add_argument("--image_size", type=int, default=256, help="input image_size, default is 256.") | |||
| parser.add_argument('--batch_size', type=int, default=1, help='batch_size, default is 1.') | |||
| parser.add_argument('--max_epoch', type=int, default=200, help='epoch size for training, default is 200.') | |||
| parser.add_argument('--n_epochs', type=int, default=100, \ | |||
| help='number of epochs with the initial learning rate, default is 100') | |||
| parser.add_argument("--beta1", type=float, default=0.5, help="Adam beta1, default is 0.5.") | |||
| parser.add_argument('--init_type', type=str, default='normal', choices=("normal", "xavier"), \ | |||
| help='network initialization, default is normal.') | |||
| parser.add_argument('--init_gain', type=float, default=0.02, \ | |||
| help='scaling factor for normal and xavier initialization, default is 0.02.') | |||
| # model parameters | |||
| parser.add_argument('--in_planes', type=int, default=3, help='input channels, default is 3.') | |||
| parser.add_argument('--ngf', type=int, default=64, help='generator model filter numbers, default is 64.') | |||
| parser.add_argument('--gl_num', type=int, default=9, help='generator model residual block numbers, default is 9.') | |||
| parser.add_argument('--ndf', type=int, default=64, help='discriminator model filter numbers, default is 64.') | |||
| parser.add_argument('--dl_num', type=int, default=3, \ | |||
| help='discriminator model residual block numbers, default is 3.') | |||
| parser.add_argument('--slope', type=float, default=0.2, help='leakyrelu slope, default is 0.2.') | |||
| parser.add_argument('--norm_mode', type=str, default="instance", choices=("batch", "instance"), \ | |||
| help='norm mode, default is instance.') | |||
| parser.add_argument('--lambda_A', type=float, default=10.0, \ | |||
| help='weight for cycle loss (A -> B -> A), default is 10.') | |||
| parser.add_argument('--lambda_B', type=float, default=10.0, \ | |||
| help='weight for cycle loss (B -> A -> B), default is 10.') | |||
| parser.add_argument('--lambda_idt', type=float, default=0.5, \ | |||
| help='use identity mapping. Setting lambda_idt other than 0 has an effect of scaling the ' | |||
| 'weight of the identity mapping loss. For example, if the weight of the identity loss ' | |||
| 'should be 10 times smaller than the weight of the reconstruction loss, ' | |||
| 'please set lambda_idt = 0.1, default is 0.5.') | |||
| parser.add_argument('--gan_mode', type=str, default='lsgan', choices=("lsgan", "vanilla"), \ | |||
| help='the type of GAN loss, default is lsgan.') | |||
| parser.add_argument('--pad_mode', type=str, default='REFLECT', choices=("CONSTANT", "REFLECT", "SYMMETRIC"), \ | |||
| help='the type of Pad, default is REFLECT.') | |||
| parser.add_argument('--need_dropout', type=ast.literal_eval, default=True, \ | |||
| help='whether need dropout, default is True.') | |||
| # distillation learning parameters | |||
| parser.add_argument('--kd', type=ast.literal_eval, default=False, \ | |||
| help='knowledge distillation learning or not, default is False.') | |||
| parser.add_argument('--t_ngf', type=int, default=64, \ | |||
| help='teacher network generator model filter numbers when `kd` is True, default is 64.') | |||
| parser.add_argument('--t_gl_num', type=int, default=9, \ | |||
| help='teacher network generator model residual block numbers when `kd` is True, default is 9.') | |||
| parser.add_argument('--t_slope', type=float, default=0.2, \ | |||
| help='teacher network leakyrelu slope when `kd` is True, default is 0.2.') | |||
| parser.add_argument('--t_norm_mode', type=str, default="instance", choices=("batch", "instance"), \ | |||
| help='teacher network norm mode when `kd` is True, default is instance.') | |||
| parser.add_argument("--GT_A_ckpt", type=str, default=None, \ | |||
| help="teacher network pretrained checkpoint file path of G_A when `kd` is True.") | |||
| parser.add_argument("--GT_B_ckpt", type=str, default=None, \ | |||
| help="teacher network pretrained checkpoint file path of G_B when `kd` is True.") | |||
| # additional parameters | |||
| parser.add_argument('--device_num', type=int, default=1, help='device num, default is 1.') | |||
| parser.add_argument("--G_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_A.") | |||
| parser.add_argument("--G_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of G_B.") | |||
| parser.add_argument("--D_A_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_A.") | |||
| parser.add_argument("--D_B_ckpt", type=str, default=None, help="pretrained checkpoint file path of D_B.") | |||
| parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.") | |||
| parser.add_argument("--print_iter", type=int, default=100, help="log print iter, default is 100.") | |||
| parser.add_argument('--need_profiler', type=ast.literal_eval, default=False, \ | |||
| help='whether need profiler, default is False.') | |||
| parser.add_argument('--save_graphs', type=ast.literal_eval, default=False, \ | |||
| help='whether save graphs, default is False.') | |||
| parser.add_argument('--outputs_dir', type=str, default='./outputs', \ | |||
| help='models are saved here, default is ./outputs.') | |||
| parser.add_argument('--dataroot', default=None, \ | |||
| help='path of images (should have subfolders trainA, trainB, testA, testB, etc).') | |||
| parser.add_argument('--save_imgs', type=ast.literal_eval, default=True, \ | |||
| help='whether save imgs when epoch end, if True result images will generate in ' | |||
| '`outputs_dir/imgs`, default is True.') | |||
| if phase == "export": | |||
| parser.add_argument("--file_name", type=str, default="cyclegan", help="output file name prefix.") | |||
| parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', \ | |||
| help='file format') | |||
| args = parser.parse_args() | |||
| if args.device_num > 1 and args.platform != "CPU": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, save_graphs=args.save_graphs) | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | |||
| device_num=args.device_num) | |||
| init() | |||
| args.rank = get_rank() | |||
| else: | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.platform, | |||
| save_graphs=args.save_graphs, device_id=args.device_id) | |||
| args.rank = 0 | |||
| args.device_num = 1 | |||
| if args.platform != "GPU": | |||
| args.pad_mode = "CONSTANT" | |||
| if phase != "train" and (args.G_A_ckpt is None or args.G_B_ckpt is None): | |||
| raise ValueError('Must set G_A_ckpt and G_B_ckpt when phase is not train!') | |||
| if args.kd: | |||
| if args.GT_A_ckpt is None or args.GT_B_ckpt is None: | |||
| raise ValueError('Must set GT_A_ckpt, GT_B_ckpt in knowledge distillation!') | |||
| if args.norm_mode == "instance" or (args.kd and args.t_norm_mode == "instance"): | |||
| args.batch_size = 1 | |||
| if args.dataroot is None and (phase in ["train", "predict"]): | |||
| raise ValueError('Must set dataroot!') | |||
| args.n_epochs_decay = args.max_epoch - args.n_epochs | |||
| args.phase = phase | |||
| return args | |||
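Several boolean options above use `type=ast.literal_eval` instead of `type=bool`. The following standalone sketch (option names are made up for illustration) shows the difference: `bool("False")` is truthy because the string is non-empty, while `ast.literal_eval("False")` yields the Python literal `False`.

```python
import argparse
import ast

parser = argparse.ArgumentParser()
# Hypothetical option names, used only to contrast the two converters.
parser.add_argument("--with_bool", type=bool, default=False)
parser.add_argument("--with_literal_eval", type=ast.literal_eval, default=False)

parsed = parser.parse_args(["--with_bool", "False", "--with_literal_eval", "False"])
print(parsed.with_bool)          # True: any non-empty string is truthy
print(parsed.with_literal_eval)  # False: the string is parsed as a Python literal
```

Note that `ast.literal_eval` accepts any Python literal, so a value such as `--need_dropout 0` would be parsed as the integer `0` rather than rejected.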
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """cityscape utils.""" | |||
| import numpy as np | |||
| from PIL import Image | |||
| # label name and RGB color map. | |||
| label2color = { | |||
| 'unlabeled': (0, 0, 0), | |||
| 'ego vehicle': (0, 0, 0), | |||
| 'rectification border': (0, 0, 0), | |||
| 'out of roi': (0, 0, 0), | |||
| 'static': (0, 0, 0), | |||
| 'dynamic': (111, 74, 0), | |||
| 'ground': (81, 0, 81), | |||
| 'road': (128, 64, 128), | |||
| 'sidewalk': (244, 35, 232), | |||
| 'parking': (250, 170, 160), | |||
| 'rail track': (230, 150, 140), | |||
| 'building': (70, 70, 70), | |||
| 'wall': (102, 102, 156), | |||
| 'fence': (190, 153, 153), | |||
| 'guard rail': (180, 165, 180), | |||
| 'bridge': (150, 100, 100), | |||
| 'tunnel': (150, 120, 90), | |||
| 'pole': (153, 153, 153), | |||
| 'polegroup': (153, 153, 153), | |||
| 'traffic light': (250, 170, 30), | |||
| 'traffic sign': (220, 220, 0), | |||
| 'vegetation': (107, 142, 35), | |||
| 'terrain': (152, 251, 152), | |||
| 'sky': (70, 130, 180), | |||
| 'person': (220, 20, 60), | |||
| 'rider': (255, 0, 0), | |||
| 'car': (0, 0, 142), | |||
| 'truck': (0, 0, 70), | |||
| 'bus': (0, 60, 100), | |||
| 'caravan': (0, 0, 90), | |||
| 'trailer': (0, 0, 110), | |||
| 'train': (0, 80, 100), | |||
| 'motorcycle': (0, 0, 230), | |||
| 'bicycle': (119, 11, 32), | |||
| 'license plate': (0, 0, 142) | |||
| } | |||
| def fast_hist(a, b, n): | |||
| """Build an n x n confusion matrix from flattened label ids `a` and predicted ids `b`.""" | |||
| k = np.where((a >= 0) & (a < n))[0] | |||
| bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2) | |||
| if len(bc) != n**2: | |||
| # ignore this example if dimension mismatch | |||
| return 0 | |||
| return bc.reshape(n, n) | |||
| def get_scores(hist): | |||
| """Compute pixel accuracy, mean class accuracy and mean IoU from a confusion matrix.""" | |||
| # Overall pixel accuracy | |||
| acc = np.diag(hist).sum() / (hist.sum() + 1e-12) | |||
| # Per class accuracy | |||
| cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12) | |||
| # Per class IoU | |||
| iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12) | |||
| return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu | |||
| class CityScapes: | |||
| """CityScapes util class.""" | |||
| def __init__(self): | |||
| self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence', | |||
| 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', | |||
| 'sky', 'person', 'rider', 'car', 'truck', | |||
| 'bus', 'train', 'motorcycle', 'bicycle', 'unlabeled'] | |||
| self.color_list = [] | |||
| for name in self.classes: | |||
| self.color_list.append(label2color[name]) | |||
| self.class_num = len(self.classes) | |||
| def get_id(self, img_path): | |||
| """Map each pixel of a color-coded label image to the index of the nearest class color.""" | |||
| img = np.array(Image.open(img_path).convert("RGB")) | |||
| h, w, _ = img.shape | |||
| img_tile = np.tile(img, (1, 1, self.class_num)).reshape(h, w, self.class_num, 3) | |||
| diff = np.abs(img_tile - self.color_list).sum(axis=-1) | |||
| ids = diff.argmin(axis=-1) | |||
| return ids | |||
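The metrics in `fast_hist` and `get_scores` can be checked by hand on a tiny example. The sketch below is a pure-Python restatement of the same arithmetic (rows are ground-truth ids, columns are predicted ids); the function names are hypothetical, and it assumes every prediction is already a valid class id.

```python
def confusion_matrix(labels, preds, n):
    """Count (label, prediction) pairs, skipping labels outside [0, n)."""
    hist = [[0] * n for _ in range(n)]
    for a, b in zip(labels, preds):
        if 0 <= a < n:
            hist[a][b] += 1
    return hist


def scores(hist, eps=1e-12):
    """Return overall pixel accuracy, per-class accuracy and per-class IoU."""
    n = len(hist)
    diag = [hist[i][i] for i in range(n)]
    row_sum = [sum(hist[i]) for i in range(n)]
    col_sum = [sum(hist[i][j] for i in range(n)) for j in range(n)]
    total = sum(row_sum)
    pixel_acc = sum(diag) / (total + eps)
    class_acc = [diag[i] / (row_sum[i] + eps) for i in range(n)]
    iou = [diag[i] / (row_sum[i] + col_sum[i] - diag[i] + eps) for i in range(n)]
    return pixel_acc, class_acc, iou


hist = confusion_matrix([0, 0, 1, 1, 2], [0, 1, 1, 1, 2], n=3)
pixel_acc, class_acc, iou = scores(hist)
print(hist)
print(round(pixel_acc, 3), [round(v, 3) for v in iou])
```

Four of the five pixels are correct here, so the pixel accuracy is 0.8, and class 1 has IoU 2/3 because one class-0 pixel was predicted as class 1.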
| @@ -0,0 +1,84 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """prepare cityscapes dataset to cyclegan format""" | |||
| import os | |||
| import argparse | |||
| import glob | |||
| from PIL import Image | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--gtFine_dir', type=str, required=True, | |||
| help='Path to the Cityscapes gtFine directory.') | |||
| parser.add_argument('--leftImg8bit_dir', type=str, required=True, | |||
| help='Path to the Cityscapes leftImg8bit_trainvaltest directory.') | |||
| parser.add_argument('--output_dir', type=str, | |||
| default='./cityscapes', | |||
| help='Directory the output images will be written to.') | |||
| opt = parser.parse_args() | |||
| def load_resized_img(path): | |||
| """Load image with RGB and resize to (256, 256)""" | |||
| return Image.open(path).convert('RGB').resize((256, 256)) | |||
| def check_matching_pair(segmap_path, photo_path): | |||
| """Check the segment images and photo images are matched or not.""" | |||
| segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '') | |||
| photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '') | |||
| assert segmap_identifier == photo_identifier, \ | |||
| f"[{segmap_path}] and [{photo_path}] don't seem to be matching. Aborting." | |||
| def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase): | |||
| """Process Cityscapes dataset to cyclegan dataset format.""" | |||
| save_phase = 'test' if phase == 'val' else 'train' | |||
| savedir = os.path.join(output_dir, save_phase) | |||
| os.makedirs(savedir + 'A', exist_ok=True) | |||
| os.makedirs(savedir + 'B', exist_ok=True) | |||
| print(f"Directory structure prepared at {output_dir}") | |||
| segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png" | |||
| segmap_paths = glob.glob(segmap_expr) | |||
| segmap_paths = sorted(segmap_paths) | |||
| photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png" | |||
| photo_paths = glob.glob(photo_expr) | |||
| photo_paths = sorted(photo_paths) | |||
| assert len(segmap_paths) == len(photo_paths), \ | |||
| "{} images that match [{}], and {} images that match [{}]. Aborting.".format( | |||
| len(segmap_paths), segmap_expr, len(photo_paths), photo_expr) | |||
| for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)): | |||
| check_matching_pair(segmap_path, photo_path) | |||
| segmap = load_resized_img(segmap_path) | |||
| photo = load_resized_img(photo_path) | |||
| # data for cyclegan where the two images are stored at two distinct directories | |||
| savepath = os.path.join(savedir + 'A', f"{i + 1}.jpg") | |||
| photo.save(savepath) | |||
| savepath = os.path.join(savedir + 'B', f"{i + 1}.jpg") | |||
| segmap.save(savepath) | |||
| if i % max(len(segmap_paths) // 10, 1) == 0: | |||
| print("%d / %d: last image saved at %s" % (i, len(segmap_paths), savepath)) | |||
| if __name__ == '__main__': | |||
| print('Preparing Cityscapes Dataset for val phase') | |||
| process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val") | |||
| print('Preparing Cityscapes Dataset for train phase') | |||
| process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train") | |||
| print('Done') | |||
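The pairing check in `check_matching_pair` relies on the two Cityscapes file names differing only in their suffix. A small sketch of that comparison, using illustrative paths and a hypothetical `same_frame` helper that returns a boolean instead of asserting:

```python
import os


def same_frame(segmap_path, photo_path):
    """Return True when both files refer to the same Cityscapes frame."""
    segmap_id = os.path.basename(segmap_path).replace("_gtFine_color", "")
    photo_id = os.path.basename(photo_path).replace("_leftImg8bit", "")
    return segmap_id == photo_id


# Illustrative paths; the directory layout is an assumption.
print(same_frame("gtFine/val/aachen/aachen_000000_000019_gtFine_color.png",
                 "leftImg8bit/val/aachen/aachen_000000_000019_leftImg8bit.png"))
print(same_frame("gtFine/val/aachen/aachen_000000_000019_gtFine_color.png",
                 "leftImg8bit/val/aachen/aachen_000001_000019_leftImg8bit.png"))
```

Because both path lists are sorted before being zipped, a mismatch here usually means a file is missing from one of the two directories.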
| @@ -0,0 +1,144 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Reporter class.""" | |||
| import logging | |||
| import os | |||
| import time | |||
| from datetime import datetime | |||
| from mindspore.train.serialization import save_checkpoint | |||
| from .tools import save_image | |||
| class Reporter(logging.Logger): | |||
| """ | |||
| This class includes several functions that can save images/checkpoints and print/save logging information. | |||
| Args: | |||
| args (class): Option class. | |||
| """ | |||
| def __init__(self, args): | |||
| super(Reporter, self).__init__("cyclegan") | |||
| self.log_dir = os.path.join(args.outputs_dir, 'log') | |||
| self.imgs_dir = os.path.join(args.outputs_dir, "imgs") | |||
| self.ckpts_dir = os.path.join(args.outputs_dir, "ckpt") | |||
| os.makedirs(self.log_dir, exist_ok=True) | |||
| os.makedirs(self.imgs_dir, exist_ok=True) | |||
| os.makedirs(self.ckpts_dir, exist_ok=True) | |||
| self.rank = args.rank | |||
| self.save_checkpoint_epochs = args.save_checkpoint_epochs | |||
| self.save_imgs = args.save_imgs | |||
| # console handler | |||
| console = logging.StreamHandler() | |||
| console.setLevel(logging.INFO) | |||
| formatter = logging.Formatter('%(message)s') | |||
| console.setFormatter(formatter) | |||
| self.addHandler(console) | |||
| # file handler | |||
| log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(self.rank) | |||
| self.log_fn = os.path.join(self.log_dir, log_name) | |||
| fh = logging.FileHandler(self.log_fn) | |||
| fh.setLevel(logging.INFO) | |||
| fh.setFormatter(formatter) | |||
| self.addHandler(fh) | |||
| self.save_args(args) | |||
| self.step = 0 | |||
| self.epoch = 0 | |||
| self.dataset_size = args.dataset_size | |||
| self.print_iter = args.print_iter | |||
| self.G_loss = [] | |||
| self.D_loss = [] | |||
| def info(self, msg, *args, **kwargs): | |||
| if self.isEnabledFor(logging.INFO): | |||
| self._log(logging.INFO, msg, args, **kwargs) | |||
| def save_args(self, args): | |||
| self.info('Args:') | |||
| args_dict = vars(args) | |||
| for key in args_dict.keys(): | |||
| self.info('--> %s: %s', key, args_dict[key]) | |||
| self.info('') | |||
| def important_info(self, msg, *args, **kwargs): | |||
| if self.isEnabledFor(logging.INFO) and self.rank == 0: | |||
| line_width = 2 | |||
| important_msg = '\n' | |||
| important_msg += ('*'*70 + '\n')*line_width | |||
| important_msg += ('*'*line_width + '\n')*2 | |||
| important_msg += '*'*line_width + ' '*8 + msg + '\n' | |||
| important_msg += ('*'*line_width + '\n')*2 | |||
| important_msg += ('*'*70 + '\n')*line_width | |||
| self.info(important_msg, *args, **kwargs) | |||
| def epoch_start(self): | |||
| self.step_start_time = time.time() | |||
| self.epoch_start_time = time.time() | |||
| self.step = 0 | |||
| self.epoch += 1 | |||
| self.G_loss = [] | |||
| self.D_loss = [] | |||
| def step_end(self, res_G, res_D): | |||
| """print log when step end.""" | |||
| self.step += 1 | |||
| loss_D = float(res_D.asnumpy()) | |||
| res = [] | |||
| for item in res_G[2:]: | |||
| res.append(float(item.asnumpy())) | |||
| self.G_loss.append(res[0]) | |||
| self.D_loss.append(loss_D) | |||
| if self.step % self.print_iter == 0: | |||
| step_cost = (time.time() - self.step_start_time) * 1000 / self.print_iter | |||
| losses = "G_loss: {:.2f}, D_loss: {:.2f}, loss_G_A: {:.2f}, loss_G_B: {:.2f}, loss_C_A: {:.2f}, "\ | |||
| "loss_C_B: {:.2f}, loss_idt_A: {:.2f}, loss_idt_B: {:.2f}".format( | |||
| res[0], loss_D, res[1], res[2], res[3], res[4], res[5], res[6]) | |||
| self.info("Epoch[{}] [{}/{}] step cost: {:.2f} ms, {}".format( | |||
| self.epoch, self.step, self.dataset_size, step_cost, losses)) | |||
| self.step_start_time = time.time() | |||
| def epoch_end(self, net): | |||
| """print log and save checkpoints when epoch end.""" | |||
| epoch_cost = (time.time() - self.epoch_start_time) * 1000 | |||
| pre_step_time = epoch_cost / self.dataset_size | |||
| mean_loss_G = sum(self.G_loss) / self.dataset_size | |||
| mean_loss_D = sum(self.D_loss) / self.dataset_size | |||
| self.info("Epoch [{}] total cost: {:.2f} ms, per step: {:.2f} ms, G_loss: {:.2f}, D_loss: {:.2f}".format( | |||
| self.epoch, epoch_cost, pre_step_time, mean_loss_G, mean_loss_D)) | |||
| if self.epoch % self.save_checkpoint_epochs == 0 and self.rank == 0: | |||
| save_checkpoint(net.G.generator.G_A, os.path.join(self.ckpts_dir, f"G_A_{self.epoch}.ckpt")) | |||
| save_checkpoint(net.G.generator.G_B, os.path.join(self.ckpts_dir, f"G_B_{self.epoch}.ckpt")) | |||
| save_checkpoint(net.G.D_A, os.path.join(self.ckpts_dir, f"D_A_{self.epoch}.ckpt")) | |||
| save_checkpoint(net.G.D_B, os.path.join(self.ckpts_dir, f"D_B_{self.epoch}.ckpt")) | |||
| def visualizer(self, img_A, img_B, fake_A, fake_B): | |||
| if self.save_imgs and self.step % self.dataset_size == 0 and self.rank == 0: | |||
| save_image(img_A, os.path.join(self.imgs_dir, f"{self.epoch}_img_A.jpg")) | |||
| save_image(img_B, os.path.join(self.imgs_dir, f"{self.epoch}_img_B.jpg")) | |||
| save_image(fake_A, os.path.join(self.imgs_dir, f"{self.epoch}_fake_A.jpg")) | |||
| save_image(fake_B, os.path.join(self.imgs_dir, f"{self.epoch}_fake_B.jpg")) | |||
| def start_predict(self, direction): | |||
| self.predict_start_time = time.time() | |||
| self.direction = direction | |||
| self.info('==========start predict %s===============', self.direction) | |||
| def end_predict(self): | |||
| cost = (time.time() - self.predict_start_time) * 1000 | |||
| pre_step_cost = cost / self.dataset_size | |||
| self.info('total {} imgs cost {:.2f} ms, per img cost {:.2f} ms'.format(self.dataset_size, cost, pre_step_cost)) | |||
| self.info('==========end predict %s===============\n', self.direction) | |||
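`Reporter` subclasses `logging.Logger` directly, so level checks and `info` calls go through `self`. The sketch below is a trimmed-down, hypothetical `MiniReporter` that writes to an in-memory stream and only emits banner messages on rank 0; it is meant to illustrate the pattern, not to mirror the full class.

```python
import io
import logging


class MiniReporter(logging.Logger):
    """Hypothetical, reduced logger that prints banner messages on rank 0 only."""

    def __init__(self, rank, stream):
        super().__init__("mini_reporter")
        self.rank = rank
        handler = logging.StreamHandler(stream)
        handler.setLevel(logging.INFO)
        handler.setFormatter(logging.Formatter("%(message)s"))
        self.addHandler(handler)

    def important_info(self, msg):
        # The instance itself is the logger, so the level check is made on self.
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            self.info("*" * 20 + "\n" + msg + "\n" + "*" * 20)


buffer = io.StringIO()
MiniReporter(rank=0, stream=buffer).important_info("training started")
MiniReporter(rank=1, stream=buffer).important_info("skipped on rank 1")
print(buffer.getvalue())
```

A logger constructed this way is not registered with `logging.getLogger`, so it has no parent and only its own handlers receive records.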
| @@ -0,0 +1,141 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utils for cyclegan.""" | |||
| import random | |||
| import numpy as np | |||
| from PIL import Image | |||
| from mindspore import Tensor | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| class ImagePool(): | |||
| """ | |||
| This class implements an image buffer that stores previously generated images. | |||
| This buffer enables us to update discriminators using a history of generated images | |||
| rather than the ones produced by the latest generators. | |||
| """ | |||
| def __init__(self, pool_size): | |||
| """ | |||
| Initialize the ImagePool class | |||
| Args: | |||
| pool_size (int): the size of image buffer, if pool_size=0, no buffer will be created. | |||
| """ | |||
| self.pool_size = pool_size | |||
| if self.pool_size > 0: # create an empty pool | |||
| self.num_imgs = 0 | |||
| self.images = [] | |||
| def query(self, images): | |||
| """ | |||
| Return an image from the pool. | |||
| Args: | |||
| images: the latest generated images from the generator | |||
| Returns images Tensor from the buffer. | |||
| With probability 0.5, the buffer returns the input image unchanged. | |||
| With probability 0.5, the buffer returns an image previously stored in the buffer | |||
| and inserts the current image into the buffer. | |||
| """ | |||
| if isinstance(images, Tensor): | |||
| images = images.asnumpy() | |||
| if self.pool_size == 0: # if the buffer size is 0, do nothing | |||
| return Tensor(images) | |||
| return_images = [] | |||
| for image in images: | |||
| if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer | |||
| self.num_imgs = self.num_imgs + 1 | |||
| self.images.append(image) | |||
| return_images.append(image) | |||
| else: | |||
| p = random.uniform(0, 1) | |||
| if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer | |||
| random_id = random.randint(0, self.pool_size - 1) # randint is inclusive | |||
| tmp = self.images[random_id].copy() | |||
| self.images[random_id] = image | |||
| return_images.append(tmp) | |||
| else: # by another 50% chance, the buffer will return the current image | |||
| return_images.append(image) | |||
| return_images = np.array(return_images) # collect all the images and return | |||
| if len(return_images.shape) != 4: | |||
| raise ValueError("img should be 4d, but get shape {}".format(return_images.shape)) | |||
| return Tensor(return_images) | |||
| def save_image(img, img_path): | |||
| """Save a numpy image to the disk | |||
| Args: | |||
| img (numpy array / Tensor): image to save. | |||
| img_path (str): the path of the image. | |||
| """ | |||
| if isinstance(img, Tensor): | |||
| img = decode_image(img) | |||
| elif not isinstance(img, np.ndarray): | |||
| raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img))) | |||
| img_pil = Image.fromarray(img) | |||
| img_pil.save(img_path) | |||
| def decode_image(img): | |||
| """Decode a [1, C, H, W] Tensor to image numpy array.""" | |||
| mean = 0.5 * 255 | |||
| std = 0.5 * 255 | |||
| return (img.asnumpy()[0] * std + mean).astype(np.uint8).transpose((1, 2, 0)) | |||
| def get_lr(args): | |||
| """Learning rate generator.""" | |||
| if args.lr_policy == 'linear': | |||
| lrs = [args.lr] * args.dataset_size * args.n_epochs | |||
| lr_epoch = 0 | |||
| for epoch in range(args.n_epochs_decay): | |||
| lr_epoch = args.lr * (args.n_epochs_decay - epoch) / args.n_epochs_decay | |||
| lrs += [lr_epoch] * args.dataset_size | |||
| lrs += [lr_epoch] * args.dataset_size * (args.max_epoch - args.n_epochs_decay - args.n_epochs) | |||
| return Tensor(np.array(lrs).astype(np.float32)) | |||
| return args.lr | |||
| def load_ckpt(args, G_A, G_B, D_A=None, D_B=None): | |||
| """Load parameter from checkpoint.""" | |||
| if args.G_A_ckpt is not None: | |||
| param_GA = load_checkpoint(args.G_A_ckpt) | |||
| load_param_into_net(G_A, param_GA) | |||
| if args.G_B_ckpt is not None: | |||
| param_GB = load_checkpoint(args.G_B_ckpt) | |||
| load_param_into_net(G_B, param_GB) | |||
| if D_A is not None and args.D_A_ckpt is not None: | |||
| param_DA = load_checkpoint(args.D_A_ckpt) | |||
| load_param_into_net(D_A, param_DA) | |||
| if D_B is not None and args.D_B_ckpt is not None: | |||
| param_DB = load_checkpoint(args.D_B_ckpt) | |||
| load_param_into_net(D_B, param_DB) | |||
| def load_teacher_ckpt(net, ckpt_path, teacher, student): | |||
| """Rename student parameters to the teacher net's names and load them from a checkpoint.""" | |||
| param = load_checkpoint(ckpt_path) | |||
| new_param = {} | |||
| for k, v in param.items(): | |||
| new_name = k.replace(student, teacher) | |||
| new_param_name = v.name.replace(student, teacher) | |||
| v.name = new_param_name | |||
| new_param[new_name] = v | |||
| load_param_into_net(net, new_param) | |||
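The `'linear'` branch of `get_lr` above builds one learning rate per training step: constant for `n_epochs`, linearly decayed over `n_epochs_decay`, then held at the last decayed value for any remaining epochs. The following is a minimal pure-Python sketch of that shape. The function name `linear_decay_lr` and its keyword arguments are illustrative stand-ins for the `args` fields, not part of the repository, and it returns a plain list rather than a MindSpore `Tensor`.

```python
def linear_decay_lr(base_lr, steps_per_epoch, n_epochs, n_epochs_decay, max_epoch):
    """Return one learning rate per step (hypothetical helper mirroring get_lr)."""
    # Phase 1: keep the base rate for the first n_epochs epochs.
    lrs = [base_lr] * steps_per_epoch * n_epochs
    # Phase 2: shrink the rate linearly, one value per decay epoch.
    lr_epoch = 0.0
    for epoch in range(n_epochs_decay):
        lr_epoch = base_lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * steps_per_epoch
    # Phase 3: hold the last decayed rate for any epochs left up to max_epoch.
    remaining = max(max_epoch - n_epochs - n_epochs_decay, 0)
    lrs += [lr_epoch] * steps_per_epoch * remaining
    return lrs


# Toy values chosen only for illustration.
schedule = linear_decay_lr(base_lr=0.0002, steps_per_epoch=2,
                           n_epochs=2, n_epochs_decay=2, max_epoch=5)
```

Note that the first decay epoch still runs at the base rate, because the factor is `(n_epochs_decay - 0) / n_epochs_decay`, and the rate never reaches zero within the decay phase.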
| @@ -0,0 +1,74 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cycle GAN train.""" | |||
| import mindspore.nn as nn | |||
| from mindspore.common import set_seed | |||
| from src.models import get_generator, get_discriminator, Generator, TrainOneStepG, TrainOneStepD, \ | |||
| DiscriminatorLoss, GeneratorLoss | |||
| from src.utils import get_lr, get_args, Reporter, ImagePool, load_ckpt | |||
| from src.dataset import create_dataset | |||
| set_seed(1) | |||
| def train(): | |||
| """Train function.""" | |||
| args = get_args("train") | |||
| if args.need_profiler: | |||
| from mindspore.profiler.profiling import Profiler | |||
| profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True) | |||
| ds = create_dataset(args) | |||
| G_A = get_generator(args) | |||
| G_B = get_generator(args) | |||
| D_A = get_discriminator(args) | |||
| D_B = get_discriminator(args) | |||
| load_ckpt(args, G_A, G_B, D_A, D_B) | |||
| image_pool_A = ImagePool(args.pool_size) | |||
| image_pool_B = ImagePool(args.pool_size) | |||
| generator = Generator(G_A, G_B, args.lambda_idt > 0) | |||
| loss_D = DiscriminatorLoss(args, D_A, D_B) | |||
| loss_G = GeneratorLoss(args, generator, D_A, D_B) | |||
| optimizer_G = nn.Adam(generator.trainable_params(), get_lr(args), beta1=args.beta1) | |||
| optimizer_D = nn.Adam(loss_D.trainable_params(), get_lr(args), beta1=args.beta1) | |||
| net_G = TrainOneStepG(loss_G, generator, optimizer_G) | |||
| net_D = TrainOneStepD(loss_D, optimizer_D) | |||
| data_loader = ds.create_dict_iterator() | |||
| reporter = Reporter(args) | |||
| reporter.info('==========start training===============') | |||
| for _ in range(args.max_epoch): | |||
| reporter.epoch_start() | |||
| for data in data_loader: | |||
| img_A = data["image_A"] | |||
| img_B = data["image_B"] | |||
| res_G = net_G(img_A, img_B) | |||
| fake_A = res_G[0] | |||
| fake_B = res_G[1] | |||
| res_D = net_D(img_A, img_B, image_pool_A.query(fake_A), image_pool_B.query(fake_B)) | |||
| reporter.step_end(res_G, res_D) | |||
| reporter.visualizer(img_A, img_B, fake_A, fake_B) | |||
| reporter.epoch_end(net_G) | |||
| if args.need_profiler: | |||
| profiler.analyse() | |||
| break | |||
| reporter.info('==========end training===============') | |||
| if __name__ == "__main__": | |||
| train() | |||
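In the training loop above, the discriminator step does not see the freshly generated images directly. They first pass through `ImagePool.query`, which may swap each one for an older generated image. The sketch below illustrates that buffering policy in pure Python, on arbitrary items instead of image arrays. The class name `HistoryPool` and the `seed` argument are hypothetical and exist only for this example; the repository's `ImagePool` works on NumPy arrays and returns a `Tensor`.

```python
import random


class HistoryPool:
    """Fixed-size buffer of previously generated items (illustrative stand-in)."""

    def __init__(self, pool_size, seed=None):
        self.pool_size = pool_size
        self.items = []
        self.rng = random.Random(seed)  # local RNG so runs can be reproduced

    def query(self, batch):
        """Return a batch of the same length, possibly mixing in older items."""
        if self.pool_size == 0:  # buffering disabled: pass the batch through
            return list(batch)
        out = []
        for item in batch:
            if len(self.items) < self.pool_size:
                # Buffer not full yet: store the item and return it unchanged.
                self.items.append(item)
                out.append(item)
            elif self.rng.random() > 0.5:
                # About half the time: return a stored item and replace it.
                idx = self.rng.randrange(self.pool_size)
                out.append(self.items[idx])
                self.items[idx] = item
            else:
                # Otherwise: return the current item and leave the buffer alone.
                out.append(item)
        return out


pool = HistoryPool(pool_size=2, seed=0)
first = pool.query(["a", "b"])   # fills the buffer, returned unchanged
second = pool.query(["c", "d"])  # each item may be swapped for a stored one
passthrough = HistoryPool(pool_size=0).query([1, 2])
```

Feeding the discriminator a mix of current and past generator outputs is the usual motivation for this kind of buffer in CycleGAN training; the 50% swap probability follows the comments in `ImagePool.query`.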