From: @riesman (pull/12487/MERGE)
@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
@@ -49,7 +50,23 @@ class TransposeGpuFwdKernel : public GpuKernel {
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync input_axis failed");
    size_t size = input_size_ / sizeof(T);
    CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));
    // h_input_shape/h_input_axis point at the host-side copies; input_shape/input_axis are the device
    // buffers filled by the cudaMemcpyAsync calls above.
    size_t *h_input_shape = &input_shape_[0];
    size_t *h_input_axis = &input_axis_[0];
    // nhwc->nchw: 0,3,1,2
    if (shape_size_ == 4 && h_input_axis[0] == 0 && h_input_axis[1] == 3 && h_input_axis[2] == 1 &&
        h_input_axis[3] == 2) {
      CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
    } else if (shape_size_ == 4 && h_input_axis[0] == 0 && h_input_axis[1] == 2 && h_input_axis[2] == 3 &&
               h_input_axis[3] == 1) {
      // nchw->nhwc: 0,2,3,1
      CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      CalTranspose(size, input, input_shape, input_axis, shape_size_, output,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }
@@ -0,0 +1,298 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <cstdint>
#include <vector>
#include <limits>
#include <utility>
#include <algorithm>
#include "transpose_impl_opt.cuh"
#include "runtime/device/gpu/cuda_common.h"
// Optimize nchw2nhwc && nhwc2nchw with tiling and shared memory.
// First, combine the H and W dims, so input and output are treated as 3D tensors.
// Second, determine whether the matrix is a large matrix or a narrow matrix,
// which determines the TileSize to choose.
// Rationale: tiling and shared memory avoid uncoalesced global memory access.
// The kernel has two stages, load-to-shm and write-to-output.
// load-to-shm: threads in a thread block work together to load an input data tile into shared mem.
// write-to-output: threads in a thread block work together to write the shared-mem tile to the output.
// Because of the shared-mem staging, the accesses to both input and output memory can be coalesced.
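// Illustrative example (added note, not from the original comment): an NHWC tensor (N, H, W, C) is
// viewed as the 3D tensor (N, H*W, C); swapping its last two dims yields (N, C, H*W), i.e. NCHW layout.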
// SimpleTransposeKernel for small matrices
template <typename T>
__global__ void SimpleTransposeKernel(const size_t size, const T *input, const size_t *input_shape,
                                      const size_t *input_axis, const size_t shape_size, T *output) {
  size_t pos_size;
  size_t temp_pos;
  size_t newpos;
  size_t newpos_size;
  size_t pos_array[4];
  // for example 4-D: pos = pos_array[0] * input_shape[1] * input_shape[2] * input_shape[3] +
  //                        pos_array[1] * input_shape[2] * input_shape[3] +
  //                        pos_array[2] * input_shape[3] +
  //                        pos_array[3]
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    temp_pos = pos;
    pos_size = size / input_shape[0];    // C * H * W
    pos_array[0] = temp_pos / pos_size;  // i / (CHW)
    for (size_t i = 1; i < shape_size; i++) {
      temp_pos -= pos_array[i - 1] * pos_size;
      pos_size = pos_size / input_shape[i];
      pos_array[i] = temp_pos / pos_size;
    }
    newpos = pos_array[input_axis[shape_size - 1]];
    newpos_size = 1;
    for (int64_t j = shape_size - 2; j >= 0; j--) {
      newpos_size *= input_shape[input_axis[j + 1]];
      newpos += pos_array[input_axis[j]] * newpos_size;
    }
    output[newpos] = *(input + pos);
  }
  return;
}
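// Worked example for the index math above (illustrative values, not from the original code): with
// input_shape = {2, 3, 4, 5} and input_axis = {0, 3, 1, 2}, input pos = 37 decomposes to
// pos_array = {0, 1, 3, 2}. The second loop recombines it under the permuted layout to newpos = 31,
// i.e. output[31] = input[37], which matches writing coordinate (0, 2, 1, 3) into the transposed
// shape {2, 5, 3, 4}.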
__forceinline__ __device__ int TensorIdxToOneDimIdx(int ndims, const int *idx, const int *dims) {
  int flat_idx = idx[0];
  for (int i = 1; i < ndims; i++) {
    flat_idx = flat_idx * dims[i] + idx[i];
  }
  return flat_idx;
}
__forceinline__ __device__ void OneDimIdxToTensorIdx(int ndims, int idx, const int *dims, int *out_tensor_idx) {
  for (int i = ndims - 1; i >= 0; i--) {
    int new_idx = idx / dims[i];
    out_tensor_idx[i] = idx - dims[i] * new_idx;
    idx = new_idx;
  }
}
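// Example (illustrative): with dims = {2, 4, 5} (strides 20, 5, 1), idx = {1, 2, 3} flattens to
// 1 * 20 + 2 * 5 + 3 = 33, and OneDimIdxToTensorIdx(3, 33, dims, out) recovers {1, 2, 3};
// the two helpers are exact inverses of each other.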
template <typename T>
__global__ void Swap3DTensorLast2DimKernel_shared(const T *input, int NumThreads, int TileHeight, int TileWidth,
                                                  int input_dims_0, int input_dims_1, int input_dims_2, T *output) {
  extern __shared__ unsigned char sdata_uchar[];
  // shm_tile[TileHeight][TileWidth + 1]: the extra column avoids bank conflicts in the write-to-output stage
  T *shm_tile = reinterpret_cast<T *>(sdata_uchar);
  int NumRowsPerLoadLoop = NumThreads / TileWidth;  // the number of shm rows all threads can load into shm at once
  int NumColsPerWriteLoop =
    NumThreads / TileHeight;  // the number of shm cols all threads can write to the output at once
  int load_thread_num_align = NumRowsPerLoadLoop * TileWidth;     // number of aligned threads used in the load-to-shm stage
  int write_thread_num_align = NumColsPerWriteLoop * TileHeight;  // number of aligned threads used in the write-to-output stage
  int tid = threadIdx.x;
  int input_dims[3] = {input_dims_0, input_dims_1, input_dims_2};
  int output_dims[3] = {input_dims[0], input_dims[2], input_dims[1]};
  int input_dims_in_tiles[3] = {input_dims[0], (input_dims[1] + TileHeight - 1) / TileHeight,
                                (input_dims[2] + TileWidth - 1) / TileWidth};
  int input_tile_idx[3];
  OneDimIdxToTensorIdx(3, blockIdx.x, input_dims_in_tiles, input_tile_idx);
  int input_tile_origin[3] = {input_tile_idx[0], input_tile_idx[1] * TileHeight, input_tile_idx[2] * TileWidth};
  int input_block_start_idx = TensorIdxToOneDimIdx(3, input_tile_origin, input_dims);  // input idx of this thread block
  bool full_tile = true;
  int tile_width = TileWidth;
  // Only the last row or column of tiles may not have the full size
  // boundary handling
  if (input_tile_idx[2] == input_dims_in_tiles[2] - 1) {
    tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileWidth;
    full_tile &= false;
  }
  int tile_height = TileHeight;
  if (input_tile_idx[1] == input_dims_in_tiles[1] - 1) {
    tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileHeight;
    full_tile &= false;
  }
  // load-to-shm: each block loads its input tile into shared mem (looping as needed)
  if (tid < load_thread_num_align) {
    // Map task blocks to thread blocks:
    // organize threads as an n * TileWidth grid
    int shm_row_id = tid / TileWidth;  // shm row id, also the tile row id of the input
    int shm_col_id = tid % TileWidth;  // shm col id, also the tile col id of the input
    int input_idx = input_block_start_idx + shm_row_id * input_dims[2] + shm_col_id;  // the input idx of this thread
    int input_step = NumRowsPerLoadLoop * input_dims[2];
    if (full_tile) {  // this thread block is responsible for an inner tile
#pragma unroll
      for (int row_id = shm_row_id; row_id < (TileHeight);
           row_id += NumRowsPerLoadLoop) {  // move to the next pass of the loop
        // shm_tile[row_id][shm_col_id]
        shm_tile[row_id * (TileWidth + 1) + shm_col_id] =
          input[input_idx];       // each thread loads one input element into shared mem
        input_idx += input_step;  // calculate the next input idx this thread should load
      }
    } else {  // boundary handling: this thread block is responsible for an edge tile
      if (shm_col_id < tile_width) {
        for (int row_id = shm_row_id; row_id < (tile_height); row_id += NumRowsPerLoadLoop) {
          // shm_tile[row_id][shm_col_id]
          shm_tile[row_id * (TileWidth + 1) + shm_col_id] = input[input_idx];
          input_idx += input_step;
        }
      }
    }
  }
  __syncthreads();
  // load-to-shm: end
  // write-to-output: each block writes shared mem into the output (looping as needed)
  int output_tile_idx[3] = {input_tile_idx[0], input_tile_idx[2], input_tile_idx[1]};
  int output_tile_origin[3] = {output_tile_idx[0], output_tile_idx[1] * TileWidth, output_tile_idx[2] * TileHeight};
  int output_block_start_idx = TensorIdxToOneDimIdx(3, output_tile_origin, output_dims);
  if (tid < write_thread_num_align) {
    // organize threads as a TileHeight * n grid
    int shm_col_id = tid / TileHeight;  // shm col id, also the tile row id of the output
    int shm_row_id = tid % TileHeight;  // shm row id, also the tile col id of the output
    int output_idx = output_block_start_idx + shm_col_id * output_dims[2] + shm_row_id;
    int output_step = NumColsPerWriteLoop * output_dims[2];
    if (full_tile) {
#pragma unroll
      for (int col_id = shm_col_id; col_id < (TileWidth);
           col_id += NumColsPerWriteLoop) {  // move to the next pass of the loop
        // shm_tile[shm_row_id][col_id]
        output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id];  // padded stride avoids bank conflicts
        output_idx += output_step;
      }
    } else {
      if (shm_row_id < tile_height) {
        for (int col_id = shm_col_id; col_id < (tile_width); col_id += NumColsPerWriteLoop) {
          // shm_tile[shm_row_id][col_id]
          output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id];
          output_idx += output_step;
        }
      }
    }
  }
}
template <typename T>
void Swap3DTensorLast2Dim(const size_t size, const size_t shape_size, int *combined_dims, const T *d_input,
                          const size_t *input_shape, const size_t *input_axis, const size_t *d_input_shape,
                          const size_t *d_input_axis, T *d_output, cudaStream_t cuda_stream) {
  static const int kMinDimensionToUseTiles = 16;
  static const int kMinDimensionToUseRectTiles = 96;
  auto short_side = std::min(combined_dims[1], combined_dims[2]);
  auto long_side = std::max(combined_dims[1], combined_dims[2]);
  // large matrix:
  // both dims are at least 16 && a CUDA block has enough shared mem for the tile.
  constexpr int kTileSizeLargeMat = 32;
  constexpr int kNumThreadsLargeMat = 256;
  auto ShmemReqLargeMat = kTileSizeLargeMat * (kTileSizeLargeMat + 1) * sizeof(T);
  bool is_large_matrix = short_side >= kMinDimensionToUseTiles && ShmemReqLargeMat <= SHARED_MEM_PER_BLOCK;
  // narrow matrix:
  // one dim is less than 16 && the other is at least 96 (narrow).
  constexpr int kTileSizeNarrowMatLongSide = 128;
  const int kTileSizeNarrowMatShortSide = short_side;
  constexpr int kNumThreadsNarrowMat = kTileSizeNarrowMatLongSide;
  auto ShmemReqNarrowMat = kTileSizeNarrowMatLongSide * (kTileSizeNarrowMatShortSide + 1) * sizeof(T);
  bool is_narrow_matrix = short_side < kMinDimensionToUseTiles && long_side >= kMinDimensionToUseRectTiles &&
                          ShmemReqNarrowMat <= SHARED_MEM_PER_BLOCK;
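  // Worked example (illustrative): transposing a float NCHW tensor (32, 64, 56, 56) to NHWC gives
  // combined_dims = {32, 64, 3136}. Then short_side = 64 >= 16 and the large-matrix tile needs
  // 32 * 33 * 4 B = 4224 B of shared memory, so the tiled path below launches
  // {32, 2, 98} tiles, i.e. 6272 thread blocks.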
  if (is_large_matrix) {
    int input_dims_in_tiles[3] = {combined_dims[0], (combined_dims[1] + kTileSizeLargeMat - 1) / kTileSizeLargeMat,
                                  (combined_dims[2] + kTileSizeLargeMat - 1) / kTileSizeLargeMat};
    int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
    Swap3DTensorLast2DimKernel_shared<T><<<TotalNumTiles, kNumThreadsLargeMat, ShmemReqLargeMat, cuda_stream>>>(
      d_input, kNumThreadsLargeMat, kTileSizeLargeMat, kTileSizeLargeMat, combined_dims[0], combined_dims[1],
      combined_dims[2], d_output);
  } else if (is_narrow_matrix) {
    int input_dims_in_tiles[3] = {combined_dims[0], 1,
                                  (long_side + kTileSizeNarrowMatLongSide - 1) / kTileSizeNarrowMatLongSide};
    int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
    int TileHeight, TileWidth;
    if (long_side == combined_dims[1]) {
      TileHeight = kTileSizeNarrowMatLongSide;
      TileWidth = short_side;
    } else {
      TileHeight = short_side;
      TileWidth = kTileSizeNarrowMatLongSide;
    }
    Swap3DTensorLast2DimKernel_shared<T><<<TotalNumTiles, kNumThreadsNarrowMat, ShmemReqNarrowMat, cuda_stream>>>(
      d_input, kNumThreadsNarrowMat, TileHeight, TileWidth, combined_dims[0], combined_dims[1], combined_dims[2],
      d_output);
  } else {
    SimpleTransposeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, d_input, d_input_shape, d_input_axis,
                                                                             shape_size, d_output);
  }
  return;
}
// specific to NHWC -> NCHW
template <typename T>
void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
                           T *d_output, cudaStream_t cuda_stream) {
  int combined_dims[3];
  combined_dims[0] = input_shape[0];  // N
  combined_dims[1] = input_shape[1];  // HW
  for (unsigned int i = 2; i < shape_size - 1; i++) {
    combined_dims[1] *= input_shape[i];
  }
  combined_dims[2] = input_shape[shape_size - 1];  // C
  Swap3DTensorLast2Dim(size, shape_size, combined_dims, d_input, input_shape, input_axis, d_input_shape, d_input_axis,
                       d_output, cuda_stream);
}
// specific to NCHW -> NHWC
template <typename T>
void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
                           T *d_output, cudaStream_t cuda_stream) {
  int combined_dims[3];
  combined_dims[0] = input_shape[0];  // N
  combined_dims[1] = input_shape[1];  // C
  combined_dims[2] = input_shape[2];  // HW
  for (unsigned int i = 3; i < shape_size; ++i) {
    combined_dims[2] *= input_shape[i];
  }
  Swap3DTensorLast2Dim(size, shape_size, combined_dims, d_input, input_shape, input_axis, d_input_shape, d_input_axis,
                       d_output, cuda_stream);
}
template void CalNHWC2NCHWInterface<double>(const size_t size, const size_t shape_size, const double *d_input,
                                            const size_t *input_shape, const size_t *input_axis,
                                            const size_t *d_input_shape, const size_t *d_input_axis, double *d_output,
                                            cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<float>(const size_t size, const size_t shape_size, const float *d_input,
                                           const size_t *input_shape, const size_t *input_axis,
                                           const size_t *d_input_shape, const size_t *d_input_axis, float *d_output,
                                           cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<half>(const size_t size, const size_t shape_size, const half *d_input,
                                          const size_t *input_shape, const size_t *input_axis,
                                          const size_t *d_input_shape, const size_t *d_input_axis, half *d_output,
                                          cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<int>(const size_t size, const size_t shape_size, const int *d_input,
                                         const size_t *input_shape, const size_t *input_axis,
                                         const size_t *d_input_shape, const size_t *d_input_axis, int *d_output,
                                         cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<int64_t>(const size_t size, const size_t shape_size, const int64_t *d_input,
                                             const size_t *input_shape, const size_t *input_axis,
                                             const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output,
                                             cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<double>(const size_t size, const size_t shape_size, const double *d_input,
                                            const size_t *input_shape, const size_t *input_axis,
                                            const size_t *d_input_shape, const size_t *d_input_axis, double *d_output,
                                            cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<float>(const size_t size, const size_t shape_size, const float *d_input,
                                           const size_t *input_shape, const size_t *input_axis,
                                           const size_t *d_input_shape, const size_t *d_input_axis, float *d_output,
                                           cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<half>(const size_t size, const size_t shape_size, const half *d_input,
                                          const size_t *input_shape, const size_t *input_axis,
                                          const size_t *d_input_shape, const size_t *d_input_axis, half *d_output,
                                          cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<int>(const size_t size, const size_t shape_size, const int *d_input,
                                         const size_t *input_shape, const size_t *input_axis,
                                         const size_t *d_input_shape, const size_t *d_input_axis, int *d_output,
                                         cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<int64_t>(const size_t size, const size_t shape_size, const int64_t *d_input,
                                             const size_t *input_shape, const size_t *input_axis,
                                             const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output,
                                             cudaStream_t cuda_stream);
@@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_
#include <cuda_runtime.h>
#define TRANSPOSE_MAX_DIMENSION 100
template <typename T>
void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, T *output,
                           cudaStream_t cuda_stream);
template <typename T>
void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
                           const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, T *output,
                           cudaStream_t cuda_stream);
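// Usage sketch (illustrative; it mirrors the call site in transpose_gpu_kernel.h): for a 4D NHWC input,
//   CalNHWC2NCHWInterface(size, 4, d_in, h_shape, h_axis, d_shape, d_axis, d_out, stream);
// where h_shape/h_axis are host copies of the shape and permutation, and d_* are device buffers.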
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_
@@ -5,7 +5,7 @@
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision(Ascend)](#mixed-precisionascend)
    - [Mixed Precision](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
    - [Script and sample code](#script-and-sample-code)
@@ -49,7 +49,7 @@ Dataset used can refer to paper.
# [Features](#contents)
## [Mixed Precision(Ascend)](#contents)
## [Mixed Precision](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
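As a minimal sketch of how this is wired up in this repository (mirroring `train.py`; `net`, `loss`, `opt`, and `loss_scale` are assumed to be defined already), mixed precision is enabled through the `amp_level` argument of `Model`:

```python
from mindspore.train.model import Model

# Assumes net, loss, opt and loss_scale are defined as in train.py.
# amp_level='O2' runs the network in FP16 while keeping BatchNorm in FP32;
# this repository uses 'O3' on Ascend and 'O2' on GPU.
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
              metrics={'acc'}, amp_level='O2', keep_batchnorm_fp32=True)
```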
@@ -57,8 +57,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
# [Environment Requirements](#contents)
- Hardware(Ascend)
    - Prepare hardware environment with Ascend.
- Hardware(Ascend/GPU)
    - Prepare hardware environment with Ascend, GPU or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
@@ -74,23 +74,30 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
└─Xception
  ├─README.md
  ├─scripts
    ├─run_standalone_train.sh    # launch standalone training with ascend platform(1p)
    ├─run_distribute_train.sh    # launch distributed training with ascend platform(8p)
    └─run_eval.sh                # launch evaluating with ascend platform
    ├─run_standalone_train.sh    # launch standalone training with ascend platform(1p)
    ├─run_distribute_train.sh    # launch distributed training with ascend platform(8p)
    ├─run_train_gpu_fp32.sh      # launch standalone or distributed fp32 training with gpu platform(1p or 8p)
    ├─run_train_gpu_fp16.sh      # launch standalone or distributed fp16 training with gpu platform(1p or 8p)
    ├─run_eval.sh                # launch evaluating with ascend platform
    └─run_eval_gpu.sh            # launch evaluating with gpu platform
  ├─src
    ├─config.py                  # parameter configuration
    ├─dataset.py                 # data preprocessing
    ├─Xception.py                # network definition
    ├─loss.py                    # customized CrossEntropy loss function
    └─lr_generator.py            # learning rate generator
  ├─train.py                     # train net
  ├─export.py                    # export net
  └─eval.py                      # eval net
    ├─config.py                  # parameter configuration
    ├─dataset.py                 # data preprocessing
    ├─Xception.py                # network definition
    ├─loss.py                    # customized CrossEntropy loss function
    └─lr_generator.py            # learning rate generator
  ├─train.py                     # train net
  ├─export.py                    # export net
  └─eval.py                      # eval net
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py.
- Config on Ascend
```python
Major parameters in train.py and config.py are:
'num_classes': 1000                   # number of dataset classes
@@ -113,6 +120,30 @@ Major parameters in train.py and config.py are:
'lr_end': 0.00004                     # min bound of learning rate
```
- Config on GPU
```python
Major parameters in train.py and config.py are:
'num_classes': 1000                   # number of dataset classes
'batch_size': 64                      # input batch size
'loss_scale': 1024                    # loss scale
'momentum': 0.9                       # momentum
'weight_decay': 1e-4                  # weight decay
'epoch_size': 250                     # total number of epochs
'save_checkpoint': True               # whether to save checkpoints
'save_checkpoint_epochs': 1           # epoch interval between checkpoints
'keep_checkpoint_max': 5              # maximum number of checkpoints to keep
'save_checkpoint_path': "./gpu-ckpt"  # checkpoint save path
'warmup_epochs': 1                    # number of warmup epochs
'lr_decay_mode': "linear"             # lr decay mode
'use_label_smooth': True              # whether to use label smoothing
'finish_epoch': 0                     # number of finished epochs
'label_smooth_factor': 0.1            # label smoothing factor
'lr_init': 0.00004                    # initial learning rate
'lr_max': 0.4                         # max bound of learning rate
'lr_end': 0.00004                     # min bound of learning rate
```
## [Training process](#contents)
### Usage
@@ -128,6 +159,25 @@ sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
```
- GPU:
```shell
# fp32 distributed training example(8p)
sh scripts/run_train_gpu_fp32.sh DEVICE_NUM DATASET_PATH PRETRAINED_CKPT_PATH(optional)
# fp32 standalone training example
sh scripts/run_train_gpu_fp32.sh 1 DATASET_PATH PRETRAINED_CKPT_PATH(optional)
# fp16 distributed training example(8p)
sh scripts/run_train_gpu_fp16.sh DEVICE_NUM DATASET_PATH PRETRAINED_CKPT_PATH(optional)
# fp16 standalone training example
sh scripts/run_train_gpu_fp16.sh 1 DATASET_PATH PRETRAINED_CKPT_PATH(optional)
# infer example
sh run_eval_gpu.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH
```
> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained from [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
### Launch
@@ -137,6 +187,8 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
python:
    Ascend:
    python train.py --device_target Ascend --dataset_path /dataset/train
    GPU:
    python train.py --device_target GPU --dataset_path /dataset/train
shell:
    Ascend:
@@ -144,11 +196,18 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
    sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
    # standalone training
    sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
    GPU:
    # fp16 training example(8p)
    sh scripts/run_train_gpu_fp16.sh DEVICE_NUM DATA_PATH
    # fp32 training example(8p)
    sh scripts/run_train_gpu_fp32.sh DEVICE_NUM DATA_PATH
```
### Result
Training result will be stored in the example path. Checkpoints will be stored at `. /ckpt_0` by default, and training log will be redirected to `log.txt` like following.
Training results will be stored in the example path. Checkpoints will be stored at `./ckpt_0` for Ascend and `./gpu-ckpt` for GPU by default, and the training log will be redirected to `log.txt` for Ascend and `log_gpu.txt` for GPU, as follows.
- Ascend:
```shell
epoch: 1 step: 1251, loss is 4.8427444
@@ -157,73 +216,100 @@ epoch: 2 step: 1251, loss is 4.0637593
epoch time: 598591.422 ms, per step time: 478.490 ms
```
- GPU:
```shell
epoch: 1 step: 20018, loss is 5.479554
epoch time: 5664051.330 ms, per step time: 282.948 ms
epoch: 2 step: 20018, loss is 5.179064
epoch time: 5628609.779 ms, per step time: 281.177 ms
```
## [Eval process](#contents)
### Usage
You can start evaluating using python or shell scripts. The usage of shell scripts is as follows:
- Ascend:
```shell
sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```
- GPU:
```shell
sh scripts/run_eval_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```
### Launch
```shell
# eval example
python:
    Ascend: python eval.py --device_target Ascend --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR
    GPU: python eval.py --device_target GPU --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR
shell:
    Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
    GPU: sh scripts/run_eval_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```
> The checkpoint can be produced during the training process.
### Result
Evaluation result will be stored in the example path, you can find result like the following in `eval.log`.
Evaluation results will be stored in the example path; you can find results like the following in `eval.log` on Ascend and `eval_gpu.log` on GPU.
- Evaluating with Ascend
```shell
result: {'Loss': 1.7797744848789312, 'Top_1_Acc': 0.7985777243589743, 'Top_5_Acc': 0.9485777243589744}
```
- Evaluating with GPU
```shell
result: {'Loss': 1.7846775874590903, 'Top_1_Acc': 0.798735595390525, 'Top_5_Acc': 0.9498439500640204}
```
# [Model description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters                 | Ascend                                         |
| -------------------------- | ---------------------------------------------- |
| Model Version              | Xception                                       |
| Resource                   | HUAWEI CLOUD Modelarts                         |
| uploaded Date              | 12/10/2020                                     |
| MindSpore Version          | 1.1.0                                          |
| Dataset                    | 1200k images                                   |
| Batch_size                 | 128                                            |
| Training Parameters        | src/config.py                                  |
| Optimizer                  | Momentum                                       |
| Loss Function              | CrossEntropySmooth                             |
| Loss                       | 1.78                                           |
| Accuracy (8p)              | Top1[79.8%] Top5[94.8%]                        |
| Per step time (8p)         | 479 ms/step                                    |
| Total time (8p)            | 42h                                            |
| Params (M)                 | 180M                                           |
| Scripts                    | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) |
| Parameters                 | Ascend                    | GPU                       |
| -------------------------- | ------------------------- | ------------------------- |
| Model Version              | Xception                  | Xception                  |
| Resource                   | HUAWEI CLOUD Modelarts    | HUAWEI CLOUD Modelarts    |
| uploaded Date              | 12/10/2020                | 02/09/2021                |
| MindSpore Version          | 1.1.0                     | 1.1.0                     |
| Dataset                    | 1200k images              | 1200k images              |
| Batch_size                 | 128                       | 64                        |
| Training Parameters        | src/config.py             | src/config.py             |
| Optimizer                  | Momentum                  | Momentum                  |
| Loss Function              | CrossEntropySmooth        | CrossEntropySmooth        |
| Loss                       | 1.78                      | 1.78                      |
| Accuracy (8p)              | Top1[79.8%] Top5[94.8%]   | Top1[79.8%] Top5[94.9%]   |
| Per step time (8p)         | 479 ms/step               | 282 ms/step               |
| Total time (8p)            | 42h                       | 51h                       |
| Params (M)                 | 180M                      | 180M                      |
| Scripts                    | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) |
#### Inference Performance
| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | Xception                    |
| Resource            | HUAWEI CLOUD Modelarts      |
| Uploaded Date       | 12/10/2020                  |
| MindSpore Version   | 1.1.0                       |
| Dataset             | 50k images                  |
| Batch_size          | 128                         |
| Accuracy            | Top1[79.8%] Top5[94.8%]     |
| Total time          | 3mins                       |
| Parameters          | Ascend                      | GPU                         |
| ------------------- | --------------------------- | --------------------------- |
| Model Version       | Xception                    | Xception                    |
| Resource            | HUAWEI CLOUD Modelarts      | HUAWEI CLOUD Modelarts      |
| Uploaded Date       | 12/10/2020                  | 02/09/2021                  |
| MindSpore Version   | 1.1.0                       | 1.1.0                       |
| Dataset             | 50k images                  | 50k images                  |
| Batch_size          | 128                         | 64                          |
| Accuracy            | Top1[79.8%] Top5[94.8%]     | Top1[79.8%] Top5[94.9%]     |
| Total time          | 3mins                       | 4.7mins                     |
# [Description of Random Situation](#contents)
@@ -20,7 +20,7 @@ from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.Xception import xception
from src.config import config
from src.config import config_gpu, config_ascend
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth
@@ -28,14 +28,21 @@ set_seed(1)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Image classification')
    parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
    parser.add_argument('--device_target', type=str, default='GPU', help='Device target')
    parser.add_argument('--device_id', type=int, default=0, help='Device id')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
    parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
    args_opt = parser.parse_args()
    if args_opt.device_target == "Ascend":
        config = config_ascend
    elif args_opt.device_target == "GPU":
        config = config_gpu
    else:
        raise ValueError("Unsupported device_target.")
    context.set_context(device_id=args_opt.device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
    # create dataset
    dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=config.batch_size, device_num=1, rank=0)
@@ -59,5 +66,5 @@ if __name__ == '__main__':
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    # eval model
    res = model.eval(dataset, dataset_sink_mode=False)
    res = model.eval(dataset, dataset_sink_mode=True)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)
@@ -19,7 +19,7 @@ import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.Xception import xception
from src.config import config
from src.config import config_ascend, config_gpu
parser = argparse.ArgumentParser(description="Image classification")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
@@ -29,13 +29,19 @@ parser.add_argument("--height", type=int, default=299, help="input height")
parser.add_argument("--file_name", type=str, default="xception", help="xception output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"],
                    default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="GPU",
                    help="device target")
args = parser.parse_args()
if args.device_target == "Ascend":
    config = config_ascend
elif args.device_target == "GPU":
    config = config_gpu
else:
    raise ValueError("Unsupported device_target.")
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)
context.set_context(device_id=args.device_id)
if __name__ == "__main__":
    # define net
@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
rm -rf eval_output
mkdir ./eval_output
cd ./eval_output || exit
echo "start evaluating model..."
python ../eval.py \
    --device_target=GPU \
    --device_id=$DEVICE_ID \
    --checkpoint_path=$PATH_CHECKPOINT \
    --dataset_path=$DATA_DIR
#--dataset_path=$DATA_DIR > eval_gpu.log 2>&1 &
cd ../
@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=$1
export RANK_SIZE=$1
DATA_DIR=$2
#DATA_DIR=/gdata/ImageNet2012/train/
PYTHON_EXEC=python
if [ $1 -gt 1 ]
then
    PATH_TRAIN="./train_distribute_gpu_fp16"$(date "+%Y%m%d%H%M%S")
    if [ -d $PATH_TRAIN ];
    then
        rm -rf $PATH_TRAIN
    fi
    mkdir $PATH_TRAIN
    cd $PATH_TRAIN || exit
    echo "start distributed training on $DEVICE_NUM gpus"
    mpirun -n $1 --allow-run-as-root \
        --output-filename gpu_fp16_dist_log \
        --merge-stderr-to-stdout \
        ${PYTHON_EXEC} ../train.py \
        --is_distributed \
        --device_target=GPU \
        --dataset_path=$DATA_DIR > gpu_fp16_dist_log.txt 2>&1 &
else
    PATH_TRAIN="./train_standalone_gpu_fp16"$(date "+%Y%m%d%H%M%S")
    if [ -d $PATH_TRAIN ];
    then
        rm -rf $PATH_TRAIN
    fi
    mkdir $PATH_TRAIN
    cd $PATH_TRAIN || exit
    echo "start standalone training on 1 gpu"
    ${PYTHON_EXEC} ../train.py \
        --device_target=GPU \
        --dataset_path=$DATA_DIR > gpu_fp16_standard_log.txt 2>&1 &
fi
cd ../
@@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=$1
export RANK_SIZE=$1
DATA_DIR=$2
#DATA_DIR=/gdata/ImageNet2012/train/
PYTHON_EXEC=python
if [ $1 -gt 1 ]
then
    PATH_TRAIN="./train_distribute_gpu_fp32"$(date "+%Y%m%d%H%M%S")
    if [ -d $PATH_TRAIN ];
    then
        rm -rf $PATH_TRAIN
    fi
    mkdir $PATH_TRAIN
    cd $PATH_TRAIN || exit
    echo "start distributed training on $DEVICE_NUM gpus"
    mpirun -n $1 --allow-run-as-root \
        --output-filename gpu_fp32_dist_log \
        --merge-stderr-to-stdout \
        ${PYTHON_EXEC} ../train.py \
        --is_distributed \
        --device_target=GPU \
        --is_fp32 \
        --dataset_path=$DATA_DIR > gpu_fp32_dist_log.txt 2>&1 &
else
    PATH_TRAIN="./train_standalone_gpu_fp32"$(date "+%Y%m%d%H%M%S")
    if [ -d $PATH_TRAIN ];
    then
        rm -rf $PATH_TRAIN
    fi
    mkdir $PATH_TRAIN
    cd $PATH_TRAIN || exit
    echo "start standalone training on 1 gpu"
    #${PYTHON_EXEC} ../train.py \ --dataset_path=/gdata/ImageNet2012/train/
    ${PYTHON_EXEC} ../train.py \
        --device_target=GPU \
        --is_fp32 \
        --dataset_path=$DATA_DIR > gpu_fp32_standard_log.txt 2>&1 &
fi
cd ../
@@ -17,8 +17,30 @@ network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# config for Xception, imagenet2012.
config = ed({
# config on GPU for Xception, imagenet2012.
config_gpu = ed({
    "class_num": 1000,
    "batch_size": 64,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 250,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 5,
    "save_checkpoint_path": "./gpu-ckpt",
    "warmup_epochs": 1,
    "lr_decay_mode": "linear",
    "use_label_smooth": True,
    "finish_epoch": 0,
    "label_smooth_factor": 0.1,
    "lr_init": 0.00004,
    "lr_max": 0.4,
    "lr_end": 0.00004
})
# config on Ascend for Xception, imagenet2012.
config_ascend = ed({
    "class_num": 1000,
    "batch_size": 128,
    "loss_scale": 1024,
@@ -29,7 +29,7 @@ from mindspore.common import set_seed
from src.lr_generator import get_lr
from src.Xception import xception
from src.config import config
from src.config import config_gpu, config_ascend
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth
@@ -38,83 +38,103 @@ set_seed(1)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification training')
    parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
    parser.add_argument('--device_target', type=str, default='Ascend', help='run platform')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                        help='run platform (default: Ascend)')
    parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
    parser.add_argument("--is_fp32", action='store_true', default=False, help='fp32 training, add --is_fp32')
    parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
    args_opt = parser.parse_args()
    args_opt = parser.parse_args()
    if args_opt.device_target == "Ascend":
        # train on Ascend
        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
        config = config_ascend
    elif args_opt.device_target == "GPU":
        config = config_gpu
    else:
        raise ValueError("Unsupported device_target.")
    # init distributed
    if args_opt.is_distributed:
        if os.getenv('DEVICE_ID', "not_set").isdigit():
            context.set_context(device_id=int(os.getenv('DEVICE_ID')))
        init()
        rank = get_rank()
        group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
    else:
        rank = 0
        group_size = 1
        if os.getenv('DEVICE_ID', "not_set").isdigit():
            context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    # init distributed
    if args_opt.is_distributed:
        if os.getenv('DEVICE_ID', "not_set").isdigit():
            context.set_context(device_id=int(os.getenv('DEVICE_ID')))
        init()
        rank = get_rank()
        group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
    else:
        rank = 0
        group_size = 1
        context.set_context(device_id=0)
        # if os.getenv('DEVICE_ID', "not_set").isdigit():
        #     context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    # define network
    net = xception(class_num=config.class_num)
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
    # define network
    net = xception(class_num=config.class_num)
    if args_opt.device_target == "Ascend":
        net.to_float(mstype.float16)
    # define loss
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    # define loss
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    # define dataset
    dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size,
                             device_num=group_size, rank=rank)
    step_size = dataset.get_dataset_size()
    # define dataset
    dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size,
                             device_num=group_size, rank=rank)
    step_size = dataset.get_dataset_size()
    # resume
    if args_opt.resume:
        ckpt = load_checkpoint(args_opt.resume)
        load_param_into_net(net, ckpt)
    # resume
    if args_opt.resume:
        ckpt = load_checkpoint(args_opt.resume)
        load_param_into_net(net, ckpt)
    # get learning rate
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    lr = Tensor(get_lr(lr_init=config.lr_init,
                       lr_end=config.lr_end,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=config.epoch_size,
                       steps_per_epoch=step_size,
                       lr_decay_mode=config.lr_decay_mode,
                       global_step=config.finish_epoch * step_size))
    # get learning rate
    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    lr = Tensor(get_lr(lr_init=config.lr_init,
                       lr_end=config.lr_end,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=config.epoch_size,
                       steps_per_epoch=step_size,
                       lr_decay_mode=config.lr_decay_mode,
                       global_step=config.finish_epoch * step_size))
    # define optimization
    # define optimization and model
    if args_opt.device_target == "Ascend":
        opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale)
        # define model
        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                      amp_level='O3', keep_batchnorm_fp32=True)
    elif args_opt.device_target == "GPU":
        if args_opt.is_fp32:
            opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay)
            model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
        else:
            opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale)
            model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                          amp_level='O2', keep_batchnorm_fp32=True)
    # define callbacks
    cb = [TimeMonitor(), LossMonitor()]
    if config.save_checkpoint:
    # define callbacks
    cb = [TimeMonitor(), LossMonitor()]
    if config.save_checkpoint:
        if args_opt.device_target == "Ascend":
            save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck)
        elif args_opt.device_target == "GPU":
            if args_opt.is_fp32:
                save_ckpt_path = os.path.join(config.save_checkpoint_path, 'fp32/' + 'model_' + str(rank))
            else:
                save_ckpt_path = os.path.join(config.save_checkpoint_path, 'fp16/' + 'model_' + str(rank))
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck)
    # begin train
    if args_opt.is_distributed:
        if rank == 0:
            cb += [ckpt_cb]
        model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
    else:
    # begin train
    print("begin train")
    if args_opt.is_distributed:
        if rank == 0:
            cb += [ckpt_cb]
        model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
        print("train success")
        model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
    else:
        raise ValueError("Unsupported device_target.")
        cb += [ckpt_cb]
        model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
        print("train success")