diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
index 24c064307e..bb2f4b6dd5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
@@ -22,6 +22,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh"
 namespace mindspore {
 namespace kernel {
 template <typename T>
@@ -49,7 +50,23 @@ class TransposeGpuFwdKernel : public GpuKernel {
                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cudaMemcpyAsync input_axis failed");
     size_t size = input_size_ / sizeof(T);
-    CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    size_t *h_input_shape = &input_shape_[0];
+    size_t *h_input_axis = &input_axis_[0];
+    // nhwc->nchw: 0,3,1,2
+    if (shape_size_ == 4 && h_input_axis[0] == 0 && h_input_axis[1] == 3 && h_input_axis[2] == 1 &&
+        h_input_axis[3] == 2) {
+      CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
+                            reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else if (shape_size_ == 4 && h_input_axis[0] == 0 && h_input_axis[1] == 2 && h_input_axis[2] == 3 &&
+               h_input_axis[3] == 1) {
+      // nchw->nhwc: 0,2,3,1
+      CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
+                            reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      CalTranspose(size, input, input_shape, input_axis, shape_size_, output,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cu
new file mode 100644
index 0000000000..7ce046341b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cu
@@ -0,0 +1,298 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_runtime.h>
+#include <algorithm>
+#include "transpose_impl_opt.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+// Optimize nchw2nhwc and nhwc2nchw with tiling and shared memory.
+// Firstly, combine the H and W dims, so input and output are treated as 3-D tensors.
+// Secondly, determine whether the matrix is a large matrix or a narrow matrix,
+// which determines the chosen TileSize.
+// Reason: tiling and shared memory avoid uncoalesced global memory access.
+// There are two stages in this kernel: load-to-shm and write-to-output.
+// load-to-shm: threads in a thread block work together to load an input data tile into shared mem.
+// write-to-output: threads in a thread block work together to write the shared mem tile to the output.
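+// For example (illustrative shapes only): transposing an NCHW float tensor of shape (32, 64, 56, 56) to NHWC
+// is treated as swapping the last two dims of the combined 3-D tensor (32, 64, 3136); each thread block loads
+// one 32x32 tile into shared memory with coalesced row-wise reads, then emits coalesced row-wise writes of the
+// transposed tile, so only the shared-memory accesses are strided.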
+// Because of the shared mem usage, the access to both input and output memory can be coalesced.
+
+// SimpleTransposeKernel for small matrices
+template <typename T>
+__global__ void SimpleTransposeKernel(const size_t size, const T *input, const size_t *input_shape,
+                                      const size_t *input_axis, const size_t shape_size, T *output) {
+  size_t pos_size;
+  size_t temp_pos;
+  size_t newpos;
+  size_t newpos_size;
+  size_t pos_array[4];
+  // for example 4-D: pos = pos_array[0] * input_shape[1] * input_shape[2] * input_shape[3] +
+  //                        pos_array[1] * input_shape[2] * input_shape[3] +
+  //                        pos_array[2] * input_shape[3] +
+  //                        pos_array[3]
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    temp_pos = pos;
+    pos_size = size / input_shape[0];    // C * H * W
+    pos_array[0] = temp_pos / pos_size;  // i / (CHW)
+    for (size_t i = 1; i < shape_size; i++) {
+      temp_pos -= pos_array[i - 1] * pos_size;
+      pos_size = pos_size / input_shape[i];
+      pos_array[i] = temp_pos / pos_size;
+    }
+    newpos = pos_array[input_axis[shape_size - 1]];
+    newpos_size = 1;
+    for (int64_t j = shape_size - 2; j >= 0; j--) {
+      newpos_size *= input_shape[input_axis[j + 1]];
+      newpos += pos_array[input_axis[j]] * newpos_size;
+    }
+    output[newpos] = *(input + pos);
+  }
+  return;
+}
+
+__forceinline__ __device__ int TensorIdxToOneDimIdx(int ndims, const int *idx, const int *dims) {
+  int flat_idx = idx[0];
+  for (int i = 1; i < ndims; i++) {
+    flat_idx = flat_idx * dims[i] + idx[i];
+  }
+  return flat_idx;
+}
+__forceinline__ __device__ void OneDimIdxToTensorIdx(int ndims, int idx, const int *dims, int *out_tensor_idx) {
+  for (int i = ndims - 1; i >= 0; i--) {
+    int new_idx = idx / dims[i];
+    out_tensor_idx[i] = idx - dims[i] * new_idx;
+    idx = new_idx;
+  }
+}
+
+template <typename T>
+__global__ void Swap3DTensorLast2DimKernel_shared(const T *input, int NumThreads, int TileHeight, int TileWidth,
+                                                  int input_dims_0, int input_dims_1, int input_dims_2, T *output) {
+  extern __shared__ unsigned char sdata_uchar[];
+  // shm_tile[TileHeight][TileWidth + 1]: the extra column avoids bank conflicts in the write-to-output stage
+  T *shm_tile = reinterpret_cast<T *>(sdata_uchar);
+  int NumRowsPerLoadLoop = NumThreads / TileWidth;  // the number of shm rows that all threads can load into shm once
+  int NumColsPerWriteLoop =
+    NumThreads / TileHeight;  // the number of shm cols that all threads can write into output once
+  int load_thread_num_align = NumRowsPerLoadLoop * TileWidth;     // use aligned thread count in the load-to-shm stage
+  int write_thread_num_align = NumColsPerWriteLoop * TileHeight;  // use aligned thread count in the write-to-output stage
+  int tid = threadIdx.x;
+  int input_dims[3] = {input_dims_0, input_dims_1, input_dims_2};
+  int output_dims[3] = {input_dims[0], input_dims[2], input_dims[1]};
+  int input_dims_in_tiles[3] = {input_dims[0], (input_dims[1] + TileHeight - 1) / TileHeight,
+                                (input_dims[2] + TileWidth - 1) / TileWidth};
+  int input_tile_idx[3];
+  OneDimIdxToTensorIdx(3, blockIdx.x, input_dims_in_tiles, input_tile_idx);
+  int input_tile_origin[3] = {input_tile_idx[0], input_tile_idx[1] * TileHeight, input_tile_idx[2] * TileWidth};
+  int input_block_start_idx = TensorIdxToOneDimIdx(3, input_tile_origin, input_dims);  // input idx of this thread block
+  bool full_tile = true;
+  int tile_width = TileWidth;
+  // Only the last row or column of tiles may not have the full size
+  // boundary process
+  if (input_tile_idx[2] == input_dims_in_tiles[2] - 1) {
+    tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileWidth;
+    full_tile &= false;
+  }
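+  // For example (illustrative numbers): with input_dims[2] == 100 and TileWidth == 32 there are
+  // ceil(100 / 32) == 4 tile columns, and blocks in the last column only cover 100 - 3 * 32 == 4 elements,
+  // so tile_width becomes 4 for them.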
+  int tile_height = TileHeight;
+  if (input_tile_idx[1] == input_dims_in_tiles[1] - 1) {
+    tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileHeight;
+    full_tile &= false;
+  }
+  // load-to-shm: each block loads its input data tile into shared mem (in a loop)
+  if (tid < load_thread_num_align) {
+    // Map task blocks to thread blocks.
+    // organize threads as n*TileWidth
+    int shm_row_id = tid / TileWidth;  // shm_row_id, also the block row_id of input
+    int shm_col_id = tid % TileWidth;  // shm_col_id, also the block col_id of input
+    int input_idx = input_block_start_idx + shm_row_id * input_dims[2] + shm_col_id;  // the input idx of this thread
+    int input_step = NumRowsPerLoadLoop * input_dims[2];
+    if (full_tile) {  // this thread block is responsible for an inner tile
+#pragma unroll
+      for (int row_id = shm_row_id; row_id < (TileHeight);
+           row_id += NumRowsPerLoadLoop) {  // move to the next pass, loop
+        // shm_tile[row_id][shm_col_id]
+        shm_tile[row_id * (TileWidth + 1) + shm_col_id] =
+          input[input_idx];       // each thread loads one input element into shared mem
+        input_idx += input_step;  // calculate the next input idx this thread should load
+      }
+    } else {  // boundary process: this thread block is responsible for an edge tile
+      if (shm_col_id < tile_width) {
+        for (int row_id = shm_row_id; row_id < (tile_height); row_id += NumRowsPerLoadLoop) {
+          // shm_tile[row_id][shm_col_id]
+          shm_tile[row_id * (TileWidth + 1) + shm_col_id] = input[input_idx];
+          input_idx += input_step;
+        }
+      }
+    }
+  }
+  __syncthreads();
+  // load-to-shm: end
+
+  // write-to-output: each block writes the shared mem tile into the output (in a loop)
+  int output_tile_idx[3] = {input_tile_idx[0], input_tile_idx[2], input_tile_idx[1]};
+  int output_tile_origin[3] = {output_tile_idx[0], output_tile_idx[1] * TileWidth, output_tile_idx[2] * TileHeight};
+  int output_block_start_idx = TensorIdxToOneDimIdx(3, output_tile_origin, output_dims);
+  if (tid < write_thread_num_align) {
+    // organize threads as TileHeight*n
+    int shm_col_id = tid / TileHeight;  // shm_col_id, also the block row_id of output
+    int shm_row_id = tid % TileHeight;  // shm_row_id, also the block col_id of output
+    int output_idx = output_block_start_idx + shm_col_id * output_dims[2] + shm_row_id;
+    int output_step = NumColsPerWriteLoop * output_dims[2];
+    if (full_tile) {
+#pragma unroll
+      for (int col_id = shm_col_id; col_id < (TileWidth);
+           col_id += NumColsPerWriteLoop) {  // move to the next pass, loop
+        // shm_tile[shm_row_id][col_id]
+        output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id];  // padded column avoids bank conflicts
+        output_idx += output_step;
+      }
+    } else {
+      if (shm_row_id < tile_height) {
+        for (int col_id = shm_col_id; col_id < (tile_width); col_id += NumColsPerWriteLoop) {
+          // shm_tile[shm_row_id][col_id]
+          output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id];
+          output_idx += output_step;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void Swap3DTensorLast2Dim(const size_t size, const size_t shape_size, int *combined_dims, const T *d_input,
+                          const size_t *input_shape, const size_t *input_axis, const size_t *d_input_shape,
+                          const size_t *d_input_axis, T *d_output, cudaStream_t cuda_stream) {
+  static const int kMinDimensionToUseTiles = 16;
+  static const int kMinDimensionToUseRectTiles = 96;
+  auto short_side = std::min(combined_dims[1], combined_dims[2]);
+  auto long_side = std::max(combined_dims[1], combined_dims[2]);
+  // large matrix
+  // Both dims are greater than 16 && cuda blocks have enough shared mem.
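+  // For example (illustrative numbers): a 256x256 float matrix takes the large-matrix path, since each block
+  // needs 32 * (32 + 1) * 4 = 4224 bytes of shared memory for its tile, well below a typical 48 KB
+  // SHARED_MEM_PER_BLOCK limit.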
+  constexpr int kTileSizeLargeMat = 32;
+  constexpr int kNumThreadsLargeMat = 256;
+  auto ShmemReqLargeMat = kTileSizeLargeMat * (kTileSizeLargeMat + 1) * sizeof(T);
+  bool is_large_matrix = short_side >= kMinDimensionToUseTiles && ShmemReqLargeMat <= SHARED_MEM_PER_BLOCK;
+  // narrow matrix
+  // one dim less than 16 && the other dim greater than 96 (narrow)
+  constexpr int kTileSizeNarrowMatLongSide = 128;
+  const int kTileSizeNarrowMatShortSide = short_side;
+  constexpr int kNumThreadsNarrowMat = kTileSizeNarrowMatLongSide;
+  auto ShmemReqNarrowMat = kTileSizeNarrowMatLongSide * (kTileSizeNarrowMatShortSide + 1) * sizeof(T);
+  bool is_narrow_matrix = short_side < kMinDimensionToUseTiles && long_side >= kMinDimensionToUseRectTiles &&
+                          ShmemReqNarrowMat <= SHARED_MEM_PER_BLOCK;
+  if (is_large_matrix) {
+    int input_dims_in_tiles[3] = {combined_dims[0], (combined_dims[1] + kTileSizeLargeMat - 1) / kTileSizeLargeMat,
+                                  (combined_dims[2] + kTileSizeLargeMat - 1) / kTileSizeLargeMat};
+    int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
+    Swap3DTensorLast2DimKernel_shared<<<TotalNumTiles, kNumThreadsLargeMat, ShmemReqLargeMat, cuda_stream>>>(
+      d_input, kNumThreadsLargeMat, kTileSizeLargeMat, kTileSizeLargeMat, combined_dims[0], combined_dims[1],
+      combined_dims[2], d_output);
+  } else if (is_narrow_matrix) {
+    int input_dims_in_tiles[3] = {combined_dims[0], 1,
+                                  (long_side + kTileSizeNarrowMatLongSide - 1) / kTileSizeNarrowMatLongSide};
+    int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
+    int TileHeight, TileWidth;
+    if (long_side == combined_dims[1]) {
+      TileHeight = kTileSizeNarrowMatLongSide;
+      TileWidth = short_side;
+    } else {
+      TileHeight = short_side;
+      TileWidth = kTileSizeNarrowMatLongSide;
+    }
+    Swap3DTensorLast2DimKernel_shared<<<TotalNumTiles, kNumThreadsNarrowMat, ShmemReqNarrowMat, cuda_stream>>>(
+      d_input, kNumThreadsNarrowMat, TileHeight, TileWidth, combined_dims[0], combined_dims[1], combined_dims[2],
+      d_output);
+  } else {
+    SimpleTransposeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, d_input, d_input_shape,
+                                                                             d_input_axis, shape_size, d_output);
+  }
+  return;
+}
+// specific for NHWC -> NCHW
+template <typename T>
+void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
+                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
+                           T *d_output, cudaStream_t cuda_stream) {
+  int combined_dims[3];
+  combined_dims[0] = input_shape[0];  // N
+  combined_dims[1] = input_shape[1];  // HW
+  for (unsigned int i = 2; i < shape_size - 1; i++) {
+    combined_dims[1] *= input_shape[i];
+  }
+  combined_dims[2] = input_shape[shape_size - 1];  // C
+  Swap3DTensorLast2Dim(size, shape_size, combined_dims, d_input, input_shape, input_axis, d_input_shape, d_input_axis,
+                       d_output, cuda_stream);
+}
+// specific for NCHW -> NHWC
+template <typename T>
+void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
+                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
+                           T *d_output, cudaStream_t cuda_stream) {
+  int combined_dims[3];
+  combined_dims[0] = input_shape[0];  // N
+  combined_dims[1] = input_shape[1];  // C
+  combined_dims[2] = input_shape[2];  // HW
+  for (unsigned int i = 3; i < shape_size; ++i) {
+    combined_dims[2] *= input_shape[i];
+  }
+  Swap3DTensorLast2Dim(size, shape_size, combined_dims, d_input, input_shape, input_axis, d_input_shape, d_input_axis,
+                       d_output, cuda_stream);
+}
+
+template void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const double *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, double *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const float *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, float *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const half *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, half *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const int *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, int *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const int64_t *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output,
+                                    cudaStream_t cuda_stream);
+
+template void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const double *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, double *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const float *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, float *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const half *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, half *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const int *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, int *d_output,
+                                    cudaStream_t cuda_stream);
+template void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const int64_t *d_input,
+                                    const size_t *input_shape, const size_t *input_axis,
+                                    const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output,
+                                    cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh
new file mode 100644
index 0000000000..e7c6306d29
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_
+
+#include <cuda_runtime.h>
+
+#define TRANSPOSE_MAX_DIMENSION 100
+template <typename T>
+void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
+                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
+                           T *output, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
+                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
+                           T *output, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_
diff --git a/model_zoo/official/cv/xception/README.md b/model_zoo/official/cv/xception/README.md
index 4f9e51187e..dd5049a6fb 100644
--- a/model_zoo/official/cv/xception/README.md
+++ b/model_zoo/official/cv/xception/README.md
@@ -5,7 +5,7 @@
 - [Model architecture](#model-architecture)
 - [Dataset](#dataset)
 - [Features](#features)
-    - [Mixed Precision(Ascend)](#mixed-precisionascend)
+    - [Mixed Precision](#mixed-precision)
 - [Environment Requirements](#environment-requirements)
 - [Script description](#script-description)
     - [Script and sample code](#script-and-sample-code)
@@ -49,7 +49,7 @@ Dataset used can refer to paper.
 
 # [Features](#contents)
 
-## [Mixed Precision(Ascend)](#contents)
+## [Mixed Precision](#contents)
 
 The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
@@ -57,8 +57,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
 
 # [Environment Requirements](#contents)
 
-- Hardware(Ascend)
-    - Prepare hardware environment with Ascend.
+- Hardware(Ascend/GPU)
+    - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework - [MindSpore](https://www.mindspore.cn/install/en) - For more information, please check the resources below: @@ -74,23 +74,30 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil └─Xception ├─README.md ├─scripts - ├─run_standalone_train.sh # launch standalone training with ascend platform(1p) - ├─run_distribute_train.sh # launch distributed training with ascend platform(8p) - └─run_eval.sh # launch evaluating with ascend platform + ├─run_standalone_train.sh # launch standalone training with ascend platform(1p) + ├─run_distribute_train.sh # launch distributed training with ascend platform(8p) + ├─run_train_gpu_fp32.sh # launch standalone or distributed fp32 training with gpu platform(1p or 8p) + ├─run_train_gpu_fp16.sh # launch standalone or distributed fp16 training with gpu platform(1p or 8p) + ├─run_eval.sh # launch evaluating with ascend platform + └─run_eval_gpu.sh # launch evaluating with gpu platform ├─src - ├─config.py # parameter configuration - ├─dataset.py # data preprocessing - ├─Xception.py # network definition - ├─loss.py # Customized CrossEntropy loss function - └─lr_generator.py # learning rate generator - ├─train.py # train net - ├─export.py # export net - └─eval.py # eval net + ├─config.py # parameter configuration + ├─dataset.py # data preprocessing + ├─Xception.py # network definition + ├─loss.py # Customized CrossEntropy loss function + └─lr_generator.py # learning rate generator + ├─train.py # train net + ├─export.py # export net + └─eval.py # eval net ``` ## [Script Parameters](#contents) +Parameters for both training and evaluation can be set in config.py. + +- Config on ascend + ```python Major parameters in train.py and config.py are: 'num_classes': 1000 # dataset class numbers @@ -113,6 +120,30 @@ Major parameters in train.py and config.py are: 'lr_end': 0.00004 # min bound of learning rate ``` +- Config on gpu + +```python +Major parameters in train.py and config.py are: +'num_classes': 1000 # dataset class numbers +'batch_size': 64 # input batchsize +'loss_scale': 1024 # loss scale +'momentum': 0.9 # momentum +'weight_decay': 1e-4 # weight decay +'epoch_size': 250 # total epoch numbers +'save_checkpoint': True # save checkpoint +'save_checkpoint_epochs': 1 # save checkpoint epochs +'keep_checkpoint_max': 5 # max numbers to keep checkpoints +'save_checkpoint_path': "./gpu-ckpt" # save checkpoint path +'warmup_epochs': 1 # warmup epoch numbers +'lr_decay_mode': "linear" # lr decay mode +'use_label_smooth': True # use label smooth +'finish_epoch': 0 # finished epochs numbers +'label_smooth_factor': 0.1 # label smoothing factor +'lr_init': 0.00004 # initiate learning rate +'lr_max': 0.4 # max bound of learning rate +'lr_end': 0.00004 # min bound of learning rate +``` + ## [Training process](#contents) ### Usage @@ -128,6 +159,25 @@ sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH ``` +- GPU: + +```shell +# fp32 distributed training example(8p) +sh scripts/run_train_gpu_fp32.sh DEVICE_NUM DATASET_PATH PRETRAINED_CKPT_PATH(optional) + +# fp32 standalone training example +sh scripts/run_train_gpu_fp32.sh 1 DATASET_PATH PRETRAINED_CKPT_PATH(optional) + +# fp16 distributed training example(8p) +sh scripts/run_train_gpu_fp16.sh DEVICE_NUM DATASET_PATH PRETRAINED_CKPT_PATH(optional) + +# fp16 standalone training example +sh scripts/run_train_gpu_fp16.sh 1 DATASET_PATH PRETRAINED_CKPT_PATH(optional) + +# infer example +sh run_eval_gpu.sh DEVICE_ID DATASET_PATH 
CHECKPOINT_PATH +``` + > Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). ### Launch @@ -137,6 +187,8 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH python: Ascend: python train.py --device_target Ascend --dataset_path /dataset/train + GPU: + python train.py --device_target GPU --dataset_path /dataset/train shell: Ascend: @@ -144,11 +196,18 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH # standalone training sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH + GPU: + # fp16 training example(8p) + sh scripts/run_train_gpu_fp16.sh DEVICE_NUM DATA_PATH + # fp32 training example(8p) + sh scripts/run_train_gpu_fp32.sh DEVICE_NUM DATA_PATH ``` ### Result -Training result will be stored in the example path. Checkpoints will be stored at `. /ckpt_0` by default, and training log will be redirected to `log.txt` like following. +Training result will be stored in the example path. Checkpoints will be stored at `./ckpt_0` for Ascend and `./gpu_ckpt` for GPU by default, and training log will be redirected to `log.txt` fo Ascend and `log_gpu.txt` for GPU like following. + +- Ascend: ``` shell epoch: 1 step: 1251, loss is 4.8427444 @@ -157,73 +216,100 @@ epoch: 2 step: 1251, loss is 4.0637593 epoch time: 598591.422 ms, per step time: 478.490 ms ``` +- GPU: + +``` shell +epoch: 1 step: 20018, loss is 5.479554 +epoch time: 5664051.330 ms, per step time: 282.948 ms +epoch: 2 step: 20018, loss is 5.179064 +epoch time: 5628609.779 ms, per step time: 281.177 ms +``` + ## [Eval process](#contents) ### Usage You can start training using python or shell scripts. The usage of shell scripts as follows: +- Ascend: + ```shell sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT ``` +- GPU: + +```shell +sh scripts/run_eval_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT +``` + ### Launch ```shell # eval example python: Ascend: python eval.py --device_target Ascend --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR + GPU: python eval.py --device_target GPU --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR shell: Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT + GPU: sh scripts/run_eval_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT ``` > checkpoint can be produced in training process. ### Result -Evaluation result will be stored in the example path, you can find result like the following in `eval.log`. +Evaluation result will be stored in the example path, you can find result like the following in `eval.log` on ascend and `eval_gpu.log` on gpu. 
+ +- Evaluating with ascend ```shell result: {'Loss': 1.7797744848789312, 'Top_1_Acc': 0.7985777243589743, 'Top_5_Acc': 0.9485777243589744} ``` +- Evaluating with gpu + +```shell +result: {'Loss': 1.7846775874590903, 'Top_1_Acc': 0.798735595390525, 'Top_5_Acc': 0.9498439500640204} +``` + # [Model description](#contents) ## [Performance](#contents) ### Training Performance -| Parameters | Ascend | -| -------------------------- | ---------------------------------------------- | -| Model Version | Xception | -| Resource | HUAWEI CLOUD Modelarts | -| uploaded Date | 12/10/2020 | -| MindSpore Version | 1.1.0 | -| Dataset | 1200k images | -| Batch_size | 128 | -| Training Parameters | src/config.py | -| Optimizer | Momentum | -| Loss Function | CrossEntropySmooth | -| Loss | 1.78 | -| Accuracy (8p) | Top1[79.8%] Top5[94.8%] | -| Per step time (8p) | 479 ms/step | -| Total time (8p) | 42h | -| Params (M) | 180M | -| Scripts | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) | +| Parameters | Ascend | GPU | +| -------------------------- | ------------------------- | ------------------------- | +| Model Version | Xception | Xception | +| Resource | HUAWEI CLOUD Modelarts | HUAWEI CLOUD Modelarts | +| uploaded Date | 12/10/2020 | 02/09/2021 | +| MindSpore Version | 1.1.0 | 1.1.0 | +| Dataset | 1200k images | 1200k images | +| Batch_size | 128 | 64 | +| Training Parameters | src/config.py | src/config.py | +| Optimizer | Momentum | Momentum | +| Loss Function | CrossEntropySmooth | CrossEntropySmooth | +| Loss | 1.78 | 1.78 | +| Accuracy (8p) | Top1[79.8%] Top5[94.8%] | Top1[79.8%] Top5[94.9%] | +| Per step time (8p) | 479 ms/step | 282 ms/step | +| Total time (8p) | 42h | 51h | +| Params (M) | 180M | 180M | +| Scripts | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) | #### Inference Performance -| Parameters | Ascend | -| ------------------- | --------------------------- | -| Model Version | Xception | -| Resource | HUAWEI CLOUD Modelarts | -| Uploaded Date | 12/10/2020 | -| MindSpore Version | 1.1.0 | -| Dataset | 50k images | -| Batch_size | 128 | -| Accuracy | Top1[79.8%] Top5[94.8%] | -| Total time | 3mins | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------- | --------------------------- | +| Model Version | Xception | Xception | +| Resource | HUAWEI CLOUD Modelarts | HUAWEI CLOUD Modelarts | +| Uploaded Date | 12/10/2020 | 02/09/2021 | +| MindSpore Version | 1.1.0 | 1.1.0 | +| Dataset | 50k images | 50k images | +| Batch_size | 128 | 64 | +| Accuracy | Top1[79.8%] Top5[94.8%] | Top1[79.8%] Top5[94.9%] | +| Total time | 3mins | 4.7mins | # [Description of Random Situation](#contents) diff --git a/model_zoo/official/cv/xception/eval.py b/model_zoo/official/cv/xception/eval.py index 1618d26ed2..75abc7d108 100644 --- a/model_zoo/official/cv/xception/eval.py +++ b/model_zoo/official/cv/xception/eval.py @@ -20,7 +20,7 @@ from mindspore.common import set_seed from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.Xception import xception -from src.config import config +from src.config import config_gpu, config_ascend from src.dataset import create_dataset from src.loss import CrossEntropySmooth @@ -28,14 +28,21 @@ set_seed(1) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Image classification') - 
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') + parser.add_argument('--device_target', type=str, default='GPU', help='Device target') parser.add_argument('--device_id', type=int, default=0, help='Device id') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') + args_opt = parser.parse_args() + if args_opt.device_target == "Ascend": + config = config_ascend + elif args_opt.device_target == "GPU": + config = config_gpu + else: + raise ValueError("Unsupported device_target.") context.set_context(device_id=args_opt.device_id) - context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False) # create dataset dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=config.batch_size, device_num=1, rank=0) @@ -59,5 +66,5 @@ if __name__ == '__main__': model = Model(net, loss_fn=loss, metrics=eval_metrics) # eval model - res = model.eval(dataset, dataset_sink_mode=False) + res = model.eval(dataset, dataset_sink_mode=True) print("result:", res, "ckpt=", args_opt.checkpoint_path) diff --git a/model_zoo/official/cv/xception/export.py b/model_zoo/official/cv/xception/export.py index ee57ce8a55..cfd89310e8 100644 --- a/model_zoo/official/cv/xception/export.py +++ b/model_zoo/official/cv/xception/export.py @@ -19,7 +19,7 @@ import numpy as np from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export from src.Xception import xception -from src.config import config +from src.config import config_ascend, config_gpu parser = argparse.ArgumentParser(description="Image classification") parser.add_argument("--device_id", type=int, default=0, help="Device id") @@ -29,13 +29,19 @@ parser.add_argument("--height", type=int, default=299, help="input height") parser.add_argument("--file_name", type=str, default="xception", help="xception output file name.") parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format") -parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend", +parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="GPU", help="device target") + args = parser.parse_args() +if args.device_target == "Ascend": + config = config_ascend +elif args.device_target == "GPU": + config = config_gpu +else: + raise ValueError("Unsupported device_target.") context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) -if args.device_target == "Ascend": - context.set_context(device_id=args.device_id) +context.set_context(device_id=args.device_id) if __name__ == "__main__": # define net diff --git a/model_zoo/official/cv/xception/scripts/run_eval_gpu.sh b/model_zoo/official/cv/xception/scripts/run_eval_gpu.sh new file mode 100644 index 0000000000..9860530eae --- /dev/null +++ b/model_zoo/official/cv/xception/scripts/run_eval_gpu.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_ID=$1 +DATA_DIR=$2 +PATH_CHECKPOINT=$3 + +rm -rf eval_output +mkdir ./eval_output +cd ./eval_output || exit +echo "start evaluating model..." + +python ../eval.py \ + --device_target=GPU \ + --device_id=$DEVICE_ID \ + --checkpoint_path=$PATH_CHECKPOINT \ + --dataset_path=$DATA_DIR + #--dataset_path=$DATA_DIR > eval_gpu.log 2>&1 & +cd ../ diff --git a/model_zoo/official/cv/xception/scripts/run_train_gpu_fp16.sh b/model_zoo/official/cv/xception/scripts/run_train_gpu_fp16.sh new file mode 100644 index 0000000000..e718e55cbb --- /dev/null +++ b/model_zoo/official/cv/xception/scripts/run_train_gpu_fp16.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_NUM=$1 +export RANK_SIZE=$1 +DATA_DIR=$2 +#DATA_DIR=/gdata/ImageNet2012/train/ +PYTHON_EXEC=python + +if [ $1 -gt 1 ] +then + PATH_TRAIN="./train_distribute_gpu_fp16"$(date "+%Y%m%d%H%M%S") + if [ -d $PATH_TRAIN ]; + then + rm -rf $PATH_TRAIN + fi + mkdir $PATH_TRAIN + cd $PATH_TRAIN || exit + echo "start distributed training on $DEVICE_NUM gpus" + + mpirun -n $1 --allow-run-as-root \ + --output-filename gpu_fp16_dist_log \ + --merge-stderr-to-stdout \ + ${PYTHON_EXEC} ../train.py \ + --is_distributed \ + --device_target=GPU \ + --dataset_path=$DATA_DIR > gpu_fp16_dist_log.txt 2>&1 & +else + PATH_TRAIN="./train_standalone_gpu_fp16"$(date "+%Y%m%d%H%M%S") + if [ -d $PATH_TRAIN ]; + then + rm -rf $PATH_TRAIN + fi + mkdir $PATH_TRAIN + cd $PATH_TRAIN || exit + echo "start training standalone on gpu device $DEVICE_ID" + + ${PYTHON_EXEC} ../train.py \ + --device_target=GPU \ + --dataset_path=$DATA_DIR > gpu_fp16_standard_log.txt 2>&1 & +fi +cd ../ \ No newline at end of file diff --git a/model_zoo/official/cv/xception/scripts/run_train_gpu_fp32.sh b/model_zoo/official/cv/xception/scripts/run_train_gpu_fp32.sh new file mode 100644 index 0000000000..56fe238178 --- /dev/null +++ b/model_zoo/official/cv/xception/scripts/run_train_gpu_fp32.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +export DEVICE_NUM=$1 +export RANK_SIZE=$1 +DATA_DIR=$2 +#DATA_DIR=/gdata/ImageNet2012/train/ +PYTHON_EXEC=python + +if [ $1 -gt 1 ] +then + PATH_TRAIN="./train_distribute_gpu_fp32"$(date "+%Y%m%d%H%M%S") + if [ -d $PATH_TRAIN ]; + then + rm -rf $PATH_TRAIN + fi + mkdir $PATH_TRAIN + cd $PATH_TRAIN || exit + echo "start distributed training on $DEVICE_NUM gpus" + + mpirun -n $1 --allow-run-as-root \ + --output-filename gpu_fp32_dist_log \ + --merge-stderr-to-stdout \ + ${PYTHON_EXEC} ../train.py \ + --is_distributed \ + --device_target=GPU \ + --is_fp32 \ + --dataset_path=$DATA_DIR > gpu_fp32_dist_log.txt 2>&1 & +else + PATH_TRAIN="./train_standalone_gpu_fp32"$(date "+%Y%m%d%H%M%S") + if [ -d $PATH_TRAIN ]; + then + rm -rf $PATH_TRAIN + fi + mkdir $PATH_TRAIN + cd $PATH_TRAIN || exit + echo "start training standalone on gpu device $DEVICE_ID" + + #${PYTHON_EXEC} ../train.py \ --dataset_path=/gdata/ImageNet2012/train/ + ${PYTHON_EXEC} ../train.py \ + --device_target=GPU \ + --is_fp32 \ + --dataset_path=$DATA_DIR > gpu_fp32_standard_log.txt 2>&1 & +fi +cd ../ \ No newline at end of file diff --git a/model_zoo/official/cv/xception/src/config.py b/model_zoo/official/cv/xception/src/config.py index 2b82494104..abbe492546 100644 --- a/model_zoo/official/cv/xception/src/config.py +++ b/model_zoo/official/cv/xception/src/config.py @@ -17,8 +17,30 @@ network config setting, will be used in train.py and eval.py """ from easydict import EasyDict as ed -# config for Xception, imagenet2012. -config = ed({ +# config on GPU for Xception, imagenet2012. +config_gpu = ed({ + "class_num": 1000, + "batch_size": 64, + "loss_scale": 1024, + "momentum": 0.9, + "weight_decay": 1e-4, + "epoch_size": 250, + "save_checkpoint": True, + "save_checkpoint_epochs": 1, + "keep_checkpoint_max": 5, + "save_checkpoint_path": "./gpu-ckpt", + "warmup_epochs": 1, + "lr_decay_mode": "linear", + "use_label_smooth": True, + "finish_epoch": 0, + "label_smooth_factor": 0.1, + "lr_init": 0.00004, + "lr_max": 0.4, + "lr_end": 0.00004 +}) + +# config on Ascend for Xception, imagenet2012. 
+config_ascend = ed({ "class_num": 1000, "batch_size": 128, "loss_scale": 1024, diff --git a/model_zoo/official/cv/xception/train.py b/model_zoo/official/cv/xception/train.py index d8e212e5e8..976bae365f 100644 --- a/model_zoo/official/cv/xception/train.py +++ b/model_zoo/official/cv/xception/train.py @@ -29,7 +29,7 @@ from mindspore.common import set_seed from src.lr_generator import get_lr from src.Xception import xception -from src.config import config +from src.config import config_gpu, config_ascend from src.dataset import create_dataset from src.loss import CrossEntropySmooth @@ -38,83 +38,103 @@ set_seed(1) if __name__ == '__main__': parser = argparse.ArgumentParser(description='image classification training') parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training') - parser.add_argument('--device_target', type=str, default='Ascend', help='run platform') + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='run platform, (Default: Ascend)') parser.add_argument('--dataset_path', type=str, default=None, help='dataset path') + parser.add_argument("--is_fp32", action='store_true', default=False, help='fp32 training, add --is_fp32') parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint') - args_opt = parser.parse_args() + args_opt = parser.parse_args() if args_opt.device_target == "Ascend": - #train on Ascend - context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False) + config = config_ascend + elif args_opt.device_target == "GPU": + config = config_gpu + else: + raise ValueError("Unsupported device_target.") - # init distributed - if args_opt.is_distributed: - if os.getenv('DEVICE_ID', "not_set").isdigit(): - context.set_context(device_id=int(os.getenv('DEVICE_ID'))) - init() - rank = get_rank() - group_size = get_group_size() - parallel_mode = ParallelMode.DATA_PARALLEL - context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True) - else: - rank = 0 - group_size = 1 - if os.getenv('DEVICE_ID', "not_set").isdigit(): - context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + # init distributed + if args_opt.is_distributed: + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) + init() + rank = get_rank() + group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True) + else: + rank = 0 + group_size = 1 + context.set_context(device_id=0) + # if os.getenv('DEVICE_ID', "not_set").isdigit(): + # context.set_context(device_id=int(os.getenv('DEVICE_ID'))) - # define network - net = xception(class_num=config.class_num) + context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False) + # define network + net = xception(class_num=config.class_num) + if args_opt.device_target == "Ascend": net.to_float(mstype.float16) - # define loss - if not config.use_label_smooth: - config.label_smooth_factor = 0.0 - loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + # define loss + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) - # define dataset - dataset = create_dataset(args_opt.dataset_path, do_train=True, 
batch_size=config.batch_size, - device_num=group_size, rank=rank) - step_size = dataset.get_dataset_size() + # define dataset + dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size, + device_num=group_size, rank=rank) + step_size = dataset.get_dataset_size() - # resume - if args_opt.resume: - ckpt = load_checkpoint(args_opt.resume) - load_param_into_net(net, ckpt) + # resume + if args_opt.resume: + ckpt = load_checkpoint(args_opt.resume) + load_param_into_net(net, ckpt) - # get learning rate - loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - lr = Tensor(get_lr(lr_init=config.lr_init, - lr_end=config.lr_end, - lr_max=config.lr_max, - warmup_epochs=config.warmup_epochs, - total_epochs=config.epoch_size, - steps_per_epoch=step_size, - lr_decay_mode=config.lr_decay_mode, - global_step=config.finish_epoch * step_size)) + # get learning rate + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + lr = Tensor(get_lr(lr_init=config.lr_init, + lr_end=config.lr_end, + lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs, + total_epochs=config.epoch_size, + steps_per_epoch=step_size, + lr_decay_mode=config.lr_decay_mode, + global_step=config.finish_epoch * step_size)) - # define optimization + # define optimization and model + if args_opt.device_target == "Ascend": opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale) - - # define model model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level='O3', keep_batchnorm_fp32=True) + elif args_opt.device_target == "GPU": + if args_opt.is_fp32: + opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay) + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) + else: + opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, + amp_level='O2', keep_batchnorm_fp32=True) - # define callbacks - cb = [TimeMonitor(), LossMonitor()] - if config.save_checkpoint: + # define callbacks + cb = [TimeMonitor(), LossMonitor()] + if config.save_checkpoint: + if args_opt.device_target == "Ascend": save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') - config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck) + elif args_opt.device_target == "GPU": + if args_opt.is_fp32: + save_ckpt_path = os.path.join(config.save_checkpoint_path, 'fp32/' + 'model_' + str(rank)) + else: + save_ckpt_path = os.path.join(config.save_checkpoint_path, 'fp16/' + 'model_' + str(rank)) + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck) - # begin train - if args_opt.is_distributed: - if rank == 0: - cb += [ckpt_cb] - model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True) - else: + # begin train + print("begin train") + if args_opt.is_distributed: + if rank == 0: cb += [ckpt_cb] - model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True) - 
print("train success") + model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True) else: - raise ValueError("Unsupported device_target.") + cb += [ckpt_cb] + model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True) + print("train success")