
!12487 Add Xception Model for GPU

From: @riesman
Reviewed-by: 
Signed-off-by:
pull/12487/MERGE
mindspore-ci-bot committed 4 years ago
commit f498c1a8c7
11 changed files with 750 additions and 116 deletions
1. +18 -1    mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
2. +298 -0   mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cu
3. +33 -0    mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh
4. +130 -44  model_zoo/official/cv/xception/README.md
5. +11 -4    model_zoo/official/cv/xception/eval.py
6. +10 -4    model_zoo/official/cv/xception/export.py
7. +32 -0    model_zoo/official/cv/xception/scripts/run_eval_gpu.sh
8. +55 -0    model_zoo/official/cv/xception/scripts/run_train_gpu_fp16.sh
9. +58 -0    model_zoo/official/cv/xception/scripts/run_train_gpu_fp32.sh
10. +24 -2   model_zoo/official/cv/xception/src/config.py
11. +81 -61  model_zoo/official/cv/xception/train.py

+18 -1  mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
@@ -49,7 +50,23 @@ class TransposeGpuFwdKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis failed");
size_t size = input_size_ / sizeof(T);
CalTranspose(size, input, input_shape, input_axis, shape_size_, output, reinterpret_cast<cudaStream_t>(stream_ptr));

size_t *h_input_shape = &input_shape_[0];
size_t *h_input_axis = &input_axis_[0];
// nhwc->nchw: 0,3,1,2
if (shape_size_ == 4 && h_input_axis[0] == 0 && h_input_axis[1] == 3 && h_input_axis[2] == 1 &&
h_input_axis[3] == 2) {
CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else if (shape_size_ == 4 && h_input_axis[0] == 0 && h_input_axis[1] == 2 && h_input_axis[2] == 3 &&
h_input_axis[3] == 1) {
// nchw->nhwc: 0,2,3,1
CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CalTranspose(size, input, input_shape, input_axis, shape_size_, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
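A quick way to exercise this new fast path from Python (a minimal sketch, assuming a GPU build of MindSpore; `ops.Transpose` is the primitive this GPU kernel implements):

```python
# Sketch: trigger the tiled NHWC -> NCHW transpose path on GPU (assumes a MindSpore GPU build).
import numpy as np
from mindspore import Tensor, context, ops

context.set_context(device_target="GPU")
x = Tensor(np.random.randn(32, 224, 224, 3).astype(np.float32))  # NHWC input
y = ops.Transpose()(x, (0, 3, 1, 2))  # perm (0,3,1,2) takes the CalNHWC2NCHWInterface branch above
print(y.shape)  # (32, 3, 224, 224)
```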



+298 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cu

@@ -0,0 +1,298 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <iostream>
#include <cstdint>
#include <vector>
#include <limits>
#include <utility>
#include <algorithm>
#include "transpose_impl_opt.cuh"
#include "runtime/device/gpu/cuda_common.h"

// Optimize nchw2nhwc && nhwc2nchw with tiling and shared memory.
// First, combine the two dims H and W, treating input and output as 3D tensors.
// Second, determine whether the matrix is a large matrix or a narrow matrix,
// which determines the TileSize to use.
// Reason: tiling and shared memory avoid uncoalesced global memory access.
// There are two stages in this kernel: load-to-shm and write-to-output.
// load-to-shm: threads in a thread block work together to load an input data tile into shared mem.
// write-to-output: threads in a thread block work together to write shared mem to an output tile.
// Because of the shared mem usage, access to both input and output memory can be coalesced.
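// Example of the dim-combining step: an NHWC input of shape (32, 224, 224, 3) with perm
// (0, 3, 1, 2) is viewed as a 3-D tensor (N=32, H*W=50176, C=3); swapping the last two dims
// of that view gives (32, 3, 50176), which is exactly the contiguous NCHW output.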

// SimpleTransposeKernel for small matrix
template <typename T>
__global__ void SimpleTransposeKernel(const size_t size, const T *input, const size_t *input_shape,
const size_t *input_axis, const size_t shape_size, T *output) {
size_t pos_size;
size_t temp_pos;
size_t newpos;
size_t newpos_size;
size_t pos_array[4];
// for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] +
// posArray[1] * input_shape[2] * input_shape[3] +
// posArray[2] * input_shape[3] +
// posArray[3]
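// e.g. input_shape = {2, 3, 4, 5} and pos = 37 give pos_array = {0, 1, 3, 2}
// (check: 0*60 + 1*20 + 3*5 + 2 = 37)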
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
temp_pos = pos;
pos_size = size / input_shape[0]; // C * H * W
pos_array[0] = temp_pos / pos_size; // i / (CHW)
for (size_t i = 1; i < shape_size; i++) {
temp_pos -= pos_array[i - 1] * pos_size;
pos_size = pos_size / input_shape[i];
pos_array[i] = temp_pos / pos_size;
}
newpos = pos_array[input_axis[shape_size - 1]];
newpos_size = 1;
for (int64_t j = shape_size - 2; j >= 0; j--) {
newpos_size *= input_shape[input_axis[j + 1]];
newpos += pos_array[input_axis[j]] * newpos_size;
}
output[newpos] = *(input + pos);
}
return;
}

__forceinline__ __device__ int TensorIdxToOneDimIdx(int ndims, const int *idx, const int *dims) {
int flat_idx = idx[0];
for (int i = 1; i < ndims; i++) {
flat_idx = flat_idx * dims[i] + idx[i];
}
return flat_idx;
}
__forceinline__ __device__ void OneDimIdxToTensorIdx(int ndims, int idx, const int *dims, int *out_tensor_idx) {
for (int i = ndims - 1; i >= 0; i--) {
int new_idx = idx / dims[i];
out_tensor_idx[i] = idx - dims[i] * new_idx;
idx = new_idx;
}
}
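// e.g. dims = {2, 3, 4}: tensor idx {1, 2, 3} <-> flat idx 23, since (1*3 + 2)*4 + 3 = 23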

template <typename T>
__global__ void Swap3DTensorLast2DimKernel_shared(const T *input, int NumThreads, int TileHeight, int TileWidth,
int input_dims_0, int input_dims_1, int input_dims_2, T *output) {
extern __shared__ unsigned char sdata_uchar[];
// shm_tile[TileHeight][TileWidth + 1]: to avoid bank conflict in write-to-output period
T *shm_tile = reinterpret_cast<T*>(sdata_uchar);
int NumRowsPerLoadLoop = NumThreads / TileWidth; // the number of shm rows that all threads can load into shm once
int NumColsPerWriteLoop =
NumThreads / TileHeight; // the number of shm cols that all threads can write into output once
int load_thread_num_align = NumRowsPerLoadLoop * TileWidth; // use align num threads in load-to-shm period
int write_thread_num_align = NumColsPerWriteLoop * TileHeight; // use align num threads in write-to-output period
int tid = threadIdx.x;
int input_dims[3] = {input_dims_0, input_dims_1, input_dims_2};
int output_dims[3] = {input_dims[0], input_dims[2], input_dims[1]};
int input_dims_in_tiles[3] = {input_dims[0], (input_dims[1] + TileHeight - 1) / TileHeight,
(input_dims[2] + TileWidth - 1) / TileWidth};
int input_tile_idx[3];
OneDimIdxToTensorIdx(3, blockIdx.x, input_dims_in_tiles, input_tile_idx);
int input_tile_origin[3] = {input_tile_idx[0], input_tile_idx[1] * TileHeight, input_tile_idx[2] * TileWidth};
int input_block_start_idx = TensorIdxToOneDimIdx(3, input_tile_origin, input_dims); // input idx of this thread block
bool full_tile = true;
int tile_width = TileWidth;
// Only the last row or column may not have the full size
// boundary process
if (input_tile_idx[2] == input_dims_in_tiles[2] - 1) {
tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileWidth;
full_tile &= false;
}
int tile_height = TileHeight;
if (input_tile_idx[1] == input_dims_in_tiles[1] - 1) {
tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileHeight;
full_tile &= false;
}
// load-to-shm: each block load input data into shared mem(loop)
if (tid < load_thread_num_align) {
// Map task blocks to thread blocks.
// organize threads to n*TileWidth
int shm_row_id = tid / TileWidth; // shm_row_id, also the block row_id of input
int shm_col_id = tid % TileWidth; // shm_col_id, also the block col_id of input
int input_idx = input_block_start_idx + shm_row_id * input_dims[2] + shm_col_id; // the input idx of this thread
int input_step = NumRowsPerLoadLoop * input_dims[2];
if (full_tile) { // thread blocks responsible for inner tiles
#pragma unroll
for (int row_id = shm_row_id; row_id < (TileHeight);
row_id += NumRowsPerLoadLoop) { // move to the next pass, loop
// shm_tile[row_id][shm_col_id]
shm_tile[row_id * (TileWidth + 1) + shm_col_id] =
input[input_idx]; // each thread load one input data into shared mem
input_idx += input_step; // calculate the next input idx this thread should load
}
} else { // boundary process: thread blocks responsible for edge tiles
if (shm_col_id < tile_width) {
for (int row_id = shm_row_id; row_id < (tile_height); row_id += NumRowsPerLoadLoop) {
// shm_tile[row_id][shm_col_id]
shm_tile[row_id * (TileWidth + 1) + shm_col_id] = input[input_idx];
input_idx += input_step;
}
}
}
}
__syncthreads();
// load-to-shm: end

// write-to-output: each block write shared mem into output(loop)
int output_tile_idx[3] = {input_tile_idx[0], input_tile_idx[2], input_tile_idx[1]};
int output_tile_origin[3] = {output_tile_idx[0], output_tile_idx[1] * TileWidth, output_tile_idx[2] * TileHeight};
int output_block_start_idx = TensorIdxToOneDimIdx(3, output_tile_origin, output_dims);
if (tid < write_thread_num_align) {
// organize threads to TileHeight*n
int shm_col_id = tid / TileHeight; // shm_col_id, also the block row_id of output
int shm_row_id = tid % TileHeight; // shm_row_id, also the block col_id of output
int output_idx = output_block_start_idx + shm_col_id * output_dims[2] + shm_row_id;
int output_step = NumColsPerWriteLoop * output_dims[2];
if (full_tile) {
#pragma unroll
for (int col_id = shm_col_id; col_id < (TileWidth);
col_id += NumColsPerWriteLoop) { // move to the next pass, loop
// shm_tile[shm_row_id][col_id]
output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id]; // avoid bank conflict
output_idx += output_step;
}
} else {
if (shm_row_id < tile_height) {
for (int col_id = shm_col_id; col_id < (tile_width); col_id += NumColsPerWriteLoop) {
// shm_tile[shm_row_id][col_id];
output[output_idx] = shm_tile[shm_row_id * (TileWidth + 1) + col_id];
output_idx += output_step;
}
}
}
}
}

template <typename T>
void Swap3DTensorLast2Dim(const size_t size, const size_t shape_size, int *combined_dims, const T *d_input,
const size_t *input_shape, const size_t *input_axis, const size_t *d_input_shape,
const size_t *d_input_axis, T *d_output, cudaStream_t cuda_stream) {
static const int kMinDimensionToUseTiles = 16;
static const int kMinDimensionToUseRectTiles = 96;
auto short_side = std::min(combined_dims[1], combined_dims[2]);
auto long_side = std::max(combined_dims[1], combined_dims[2]);
// large matrix
// Both dims are greater than 16 && cuda blocks have enough shared mem.
constexpr int kTileSizeLargeMat = 32;
constexpr int kNumThreadsLargeMat = 256;
auto ShmemReqLargeMat = kTileSizeLargeMat * (kTileSizeLargeMat + 1) * sizeof(T);
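// e.g. for T = float: 32 * (32 + 1) * 4 bytes = 4224 bytes of shared memory per block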
bool is_large_matrix = short_side >= kMinDimensionToUseTiles && ShmemReqLargeMat <= SHARED_MEM_PER_BLOCK;
// narrow matrix
// one dim less than 16 && one dim greater than 96(narrow)
constexpr int kTileSizeNarrowMatLongSide = 128;
const int kTileSizeNarrowMatShortSide = short_side;
constexpr int kNumThreadsNarrowMat = kTileSizeNarrowMatLongSide;
auto ShmemReqNarrowMat = kTileSizeNarrowMatLongSide * (kTileSizeNarrowMatShortSide + 1) * sizeof(T);
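// e.g. for a narrow matrix with short_side = 3 (RGB channels) and T = float: 128 * (3 + 1) * 4 bytes = 2048 bytes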
bool is_narrow_matrix = short_side < kMinDimensionToUseTiles && long_side >= kMinDimensionToUseRectTiles &&
ShmemReqNarrowMat <= SHARED_MEM_PER_BLOCK;
if (is_large_matrix) {
int input_dims_in_tiles[3] = {combined_dims[0], (combined_dims[1] + kTileSizeLargeMat - 1) / kTileSizeLargeMat,
(combined_dims[2] + kTileSizeLargeMat - 1) / kTileSizeLargeMat};
int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
Swap3DTensorLast2DimKernel_shared<T><<<TotalNumTiles, kNumThreadsLargeMat, ShmemReqLargeMat, cuda_stream>>>(
d_input, kNumThreadsLargeMat, kTileSizeLargeMat, kTileSizeLargeMat, combined_dims[0], combined_dims[1],
combined_dims[2], d_output);
} else if (is_narrow_matrix) {
int input_dims_in_tiles[3] = {combined_dims[0], 1,
(long_side + kTileSizeNarrowMatLongSide - 1) / kTileSizeNarrowMatLongSide};
int TotalNumTiles = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];
int TileHeight, TileWidth;
if (long_side == combined_dims[1]) {
TileHeight = kTileSizeNarrowMatLongSide;
TileWidth = short_side;
} else {
TileHeight = short_side;
TileWidth = kTileSizeNarrowMatLongSide;
}
Swap3DTensorLast2DimKernel_shared<T><<<TotalNumTiles, kNumThreadsNarrowMat, ShmemReqNarrowMat, cuda_stream>>>(
d_input, kNumThreadsNarrowMat, TileHeight, TileWidth, combined_dims[0], combined_dims[1], combined_dims[2],
d_output);
} else {
SimpleTransposeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, d_input, d_input_shape, d_input_axis,
shape_size, d_output);
}
return;
}
// specific for NHWC -> NCHW
template <typename T>
void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
T *d_output, cudaStream_t cuda_stream) {
int combined_dims[3];
combined_dims[0] = input_shape[0]; // N
combined_dims[1] = input_shape[1]; // HW
for (unsigned int i = 2; i < shape_size - 1; i++) {
combined_dims[1] *= input_shape[i];
}
combined_dims[2] = input_shape[shape_size - 1]; // C
Swap3DTensorLast2Dim(size, shape_size, combined_dims, d_input, input_shape, input_axis, d_input_shape, d_input_axis,
d_output, cuda_stream);
}
// specific for NCHW -> NHWC
template <typename T>
void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis,
T *d_output, cudaStream_t cuda_stream) {
int combined_dims[3];
combined_dims[0] = input_shape[0]; // N
combined_dims[1] = input_shape[1]; // C
combined_dims[2] = input_shape[2]; // HW
for (unsigned int i = 3; i < shape_size; ++i) {
combined_dims[2] *= input_shape[i];
}
Swap3DTensorLast2Dim(size, shape_size, combined_dims, d_input, input_shape, input_axis, d_input_shape, d_input_axis,
d_output, cuda_stream);
}

template void CalNHWC2NCHWInterface<double>(const size_t size, const size_t shape_size, const double *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, double *d_output,
cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<float>(const size_t size, const size_t shape_size, const float *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, float *d_output,
cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<half>(const size_t size, const size_t shape_size, const half *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, half *d_output,
cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<int>(const size_t size, const size_t shape_size, const int *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, int *d_output,
cudaStream_t cuda_stream);
template void CalNHWC2NCHWInterface<int64_t>(const size_t size, const size_t shape_size, const int64_t *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output,
cudaStream_t cuda_stream);

template void CalNCHW2NHWCInterface<double>(const size_t size, const size_t shape_size, const double *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, double *d_output,
cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<float>(const size_t size, const size_t shape_size, const float *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, float *d_output,
cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<half>(const size_t size, const size_t shape_size, const half *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, half *d_output,
cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<int>(const size_t size, const size_t shape_size, const int *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, int *d_output,
cudaStream_t cuda_stream);
template void CalNCHW2NHWCInterface<int64_t>(const size_t size, const size_t shape_size, const int64_t *d_input,
const size_t *input_shape, const size_t *input_axis,
const size_t *d_input_shape, const size_t *d_input_axis, int64_t *d_output,
cudaStream_t cuda_stream);

+33 -0  mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_

#include <cuda_runtime.h>

#define TRANSPOSE_MAX_DIMENSION 100
template <typename T>
void CalNHWC2NCHWInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, T *output,
cudaStream_t cuda_stream);

template <typename T>
void CalNCHW2NHWCInterface(const size_t size, const size_t shape_size, const T *d_input, const size_t *input_shape,
const size_t *input_axis, const size_t *d_input_shape, const size_t *d_input_axis, T *output,
cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TRANSPOSE_OPT_H_

+130 -44  model_zoo/official/cv/xception/README.md

@@ -5,7 +5,7 @@
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision(Ascend)](#mixed-precisionascend)
- [Mixed Precision](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
- [Script and sample code](#script-and-sample-code)
@@ -49,7 +49,7 @@ Dataset used can refer to paper.

# [Features](#contents)

## [Mixed Precision(Ascend)](#contents)
## [Mixed Precision](#contents)

The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
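As a minimal sketch of how the GPU scripts in this commit turn mixed precision on (the `nn.Dense` network and toy hyper-parameters below are placeholders; `train.py` builds the real Xception net and reads these values from `src/config.py`):

```python
from mindspore import context, nn
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = nn.Dense(16, 10)  # placeholder network; train.py builds Xception here
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9,
                  weight_decay=1e-4, loss_scale=1024)
# amp_level='O2' casts the network to FP16 while keeping BatchNorm in FP32,
# matching the GPU FP16 branch added to train.py in this commit
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
              metrics={'acc'}, amp_level='O2', keep_batchnorm_fp32=True)
```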

@@ -57,8 +57,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil

# [Environment Requirements](#contents)

- Hardware(Ascend)
- Prepare hardware environment with Ascend.
- Hardware(Ascend/GPU)
- Prepare hardware environment with Ascend, GPU or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
@@ -74,23 +74,30 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
└─Xception
├─README.md
├─scripts
├─run_standalone_train.sh # launch standalone training with ascend platform(1p)
├─run_distribute_train.sh # launch distributed training with ascend platform(8p)
└─run_eval.sh # launch evaluating with ascend platform
├─run_standalone_train.sh # launch standalone training with ascend platform(1p)
├─run_distribute_train.sh # launch distributed training with ascend platform(8p)
├─run_train_gpu_fp32.sh # launch standalone or distributed fp32 training with gpu platform(1p or 8p)
├─run_train_gpu_fp16.sh # launch standalone or distributed fp16 training with gpu platform(1p or 8p)
├─run_eval.sh # launch evaluating with ascend platform
└─run_eval_gpu.sh # launch evaluating with gpu platform
├─src
├─config.py # parameter configuration
├─dataset.py # data preprocessing
├─Xception.py # network definition
├─loss.py # Customized CrossEntropy loss function
└─lr_generator.py # learning rate generator
├─train.py # train net
├─export.py # export net
└─eval.py # eval net
├─config.py # parameter configuration
├─dataset.py # data preprocessing
├─Xception.py # network definition
├─loss.py # Customized CrossEntropy loss function
└─lr_generator.py # learning rate generator
├─train.py # train net
├─export.py # export net
└─eval.py # eval net

```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in config.py.

- Config on Ascend

```python
Major parameters in train.py and config.py are:
'num_classes': 1000 # dataset class numbers
@@ -113,6 +120,30 @@ Major parameters in train.py and config.py are:
'lr_end': 0.00004 # min bound of learning rate
```

- Config on GPU

```python
Major parameters in train.py and config.py are:
'num_classes': 1000 # dataset class numbers
'batch_size': 64 # input batchsize
'loss_scale': 1024 # loss scale
'momentum': 0.9 # momentum
'weight_decay': 1e-4 # weight decay
'epoch_size': 250 # total epoch numbers
'save_checkpoint': True # save checkpoint
'save_checkpoint_epochs': 1 # save checkpoint epochs
'keep_checkpoint_max': 5 # max numbers to keep checkpoints
'save_checkpoint_path': "./gpu-ckpt" # save checkpoint path
'warmup_epochs': 1 # warmup epoch numbers
'lr_decay_mode': "linear" # lr decay mode
'use_label_smooth': True # use label smooth
'finish_epoch': 0 # finished epochs numbers
'label_smooth_factor': 0.1 # label smoothing factor
'lr_init': 0.00004 # initiate learning rate
'lr_max': 0.4 # max bound of learning rate
'lr_end': 0.00004 # min bound of learning rate
```

## [Training process](#contents)

### Usage
@@ -128,6 +159,25 @@ sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
```

- GPU:

```shell
# fp32 distributed training example(8p)
sh scripts/run_train_gpu_fp32.sh DEVICE_NUM DATASET_PATH PRETRAINED_CKPT_PATH(optional)

# fp32 standalone training example
sh scripts/run_train_gpu_fp32.sh 1 DATASET_PATH PRETRAINED_CKPT_PATH(optional)

# fp16 distributed training example(8p)
sh scripts/run_train_gpu_fp16.sh DEVICE_NUM DATASET_PATH PRETRAINED_CKPT_PATH(optional)

# fp16 standalone training example
sh scripts/run_train_gpu_fp16.sh 1 DATASET_PATH PRETRAINED_CKPT_PATH(optional)

# infer example
sh run_eval_gpu.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH
```

> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be obtained as described in [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

### Launch
@@ -137,6 +187,8 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
python:
Ascend:
python train.py --device_target Ascend --dataset_path /dataset/train
GPU:
python train.py --device_target GPU --dataset_path /dataset/train

shell:
Ascend:
@@ -144,11 +196,18 @@ sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
# standalone training
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
GPU:
# fp16 training example(8p)
sh scripts/run_train_gpu_fp16.sh DEVICE_NUM DATA_PATH
# fp32 training example(8p)
sh scripts/run_train_gpu_fp32.sh DEVICE_NUM DATA_PATH
```

### Result

Training result will be stored in the example path. Checkpoints will be stored at `. /ckpt_0` by default, and training log will be redirected to `log.txt` like following.
Training result will be stored in the example path. Checkpoints will be stored at `./ckpt_0` for Ascend and `./gpu_ckpt` for GPU by default, and the training log will be redirected to `log.txt` for Ascend and `log_gpu.txt` for GPU, like the following.

- Ascend:

``` shell
epoch: 1 step: 1251, loss is 4.8427444
@@ -157,73 +216,100 @@ epoch: 2 step: 1251, loss is 4.0637593
epoch time: 598591.422 ms, per step time: 478.490 ms
```

- GPU:

``` shell
epoch: 1 step: 20018, loss is 5.479554
epoch time: 5664051.330 ms, per step time: 282.948 ms
epoch: 2 step: 20018, loss is 5.179064
epoch time: 5628609.779 ms, per step time: 281.177 ms
```

## [Eval process](#contents)

### Usage

You can start evaluating using python or shell scripts. The usage of shell scripts is as follows:

- Ascend:

```shell
sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```

- GPU:

```shell
sh scripts/run_eval_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```

### Launch

```shell
# eval example
python:
Ascend: python eval.py --device_target Ascend --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR
GPU: python eval.py --device_target GPU --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR

shell:
Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
GPU: sh scripts/run_eval_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```

> checkpoint can be produced in training process.

### Result

Evaluation result will be stored in the example path, you can find result like the following in `eval.log`.
Evaluation result will be stored in the example path; you can find results like the following in `eval.log` on Ascend and `eval_gpu.log` on GPU.

- Evaluating with Ascend

```shell
result: {'Loss': 1.7797744848789312, 'Top_1_Acc': 0.7985777243589743, 'Top_5_Acc': 0.9485777243589744}
```

- Evaluating with GPU

```shell
result: {'Loss': 1.7846775874590903, 'Top_1_Acc': 0.798735595390525, 'Top_5_Acc': 0.9498439500640204}
```

# [Model description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters | Ascend |
| -------------------------- | ---------------------------------------------- |
| Model Version | Xception |
| Resource | HUAWEI CLOUD Modelarts |
| uploaded Date | 12/10/2020 |
| MindSpore Version | 1.1.0 |
| Dataset | 1200k images |
| Batch_size | 128 |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | CrossEntropySmooth |
| Loss | 1.78 |
| Accuracy (8p) | Top1[79.8%] Top5[94.8%] |
| Per step time (8p) | 479 ms/step |
| Total time (8p) | 42h |
| Params (M) | 180M |
| Scripts | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) |
| Parameters | Ascend | GPU |
| -------------------------- | ------------------------- | ------------------------- |
| Model Version | Xception | Xception |
| Resource | HUAWEI CLOUD Modelarts | HUAWEI CLOUD Modelarts |
| uploaded Date | 12/10/2020 | 02/09/2021 |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | 1200k images | 1200k images |
| Batch_size | 128 | 64 |
| Training Parameters | src/config.py | src/config.py |
| Optimizer | Momentum | Momentum |
| Loss Function | CrossEntropySmooth | CrossEntropySmooth |
| Loss | 1.78 | 1.78 |
| Accuracy (8p) | Top1[79.8%] Top5[94.8%] | Top1[79.8%] Top5[94.9%] |
| Per step time (8p) | 479 ms/step | 282 ms/step |
| Total time (8p) | 42h | 51h |
| Params (M) | 180M | 180M |
| Scripts | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/xception) |

#### Inference Performance

| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | Xception |
| Resource | HUAWEI CLOUD Modelarts |
| Uploaded Date | 12/10/2020 |
| MindSpore Version | 1.1.0 |
| Dataset | 50k images |
| Batch_size | 128 |
| Accuracy | Top1[79.8%] Top5[94.8%] |
| Total time | 3mins |
| Parameters | Ascend | GPU |
| ------------------- | --------------------------- | --------------------------- |
| Model Version | Xception | Xception |
| Resource | HUAWEI CLOUD Modelarts | HUAWEI CLOUD Modelarts |
| Uploaded Date | 12/10/2020 | 02/09/2021 |
| MindSpore Version | 1.1.0 | 1.1.0 |
| Dataset | 50k images | 50k images |
| Batch_size | 128 | 64 |
| Accuracy | Top1[79.8%] Top5[94.8%] | Top1[79.8%] Top5[94.9%] |
| Total time | 3mins | 4.7mins |

# [Description of Random Situation](#contents)



+11 -4  model_zoo/official/cv/xception/eval.py

@@ -20,7 +20,7 @@ from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.Xception import xception
from src.config import config
from src.config import config_gpu, config_ascend
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth

@@ -28,14 +28,21 @@ set_seed(1)

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_target', type=str, default='GPU', help='Device target')
parser.add_argument('--device_id', type=int, default=0, help='Device id')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')

args_opt = parser.parse_args()
if args_opt.device_target == "Ascend":
config = config_ascend
elif args_opt.device_target == "GPU":
config = config_gpu
else:
raise ValueError("Unsupported device_target.")

context.set_context(device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)

# create dataset
dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=config.batch_size, device_num=1, rank=0)
@@ -59,5 +66,5 @@ if __name__ == '__main__':
model = Model(net, loss_fn=loss, metrics=eval_metrics)

# eval model
res = model.eval(dataset, dataset_sink_mode=False)
res = model.eval(dataset, dataset_sink_mode=True)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

+10 -4  model_zoo/official/cv/xception/export.py

@@ -19,7 +19,7 @@ import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export

from src.Xception import xception
from src.config import config
from src.config import config_ascend, config_gpu

parser = argparse.ArgumentParser(description="Image classification")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
@@ -29,13 +29,19 @@ parser.add_argument("--height", type=int, default=299, help="input height")
parser.add_argument("--file_name", type=str, default="xception", help="xception output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"],
default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="GPU",
help="device target")

args = parser.parse_args()
if args.device_target == "Ascend":
config = config_ascend
elif args.device_target == "GPU":
config = config_gpu
else:
raise ValueError("Unsupported device_target.")

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
context.set_context(device_id=args.device_id)

if __name__ == "__main__":
# define net


+32 -0  model_zoo/official/cv/xception/scripts/run_eval_gpu.sh

@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3

rm -rf eval_output
mkdir ./eval_output
cd ./eval_output || exit
echo "start evaluating model..."

python ../eval.py \
--device_target=GPU \
--device_id=$DEVICE_ID \
--checkpoint_path=$PATH_CHECKPOINT \
--dataset_path=$DATA_DIR
#--dataset_path=$DATA_DIR > eval_gpu.log 2>&1 &
cd ../

+55 -0  model_zoo/official/cv/xception/scripts/run_train_gpu_fp16.sh

@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_NUM=$1
export RANK_SIZE=$1
DATA_DIR=$2
#DATA_DIR=/gdata/ImageNet2012/train/
PYTHON_EXEC=python

if [ $1 -gt 1 ]
then
PATH_TRAIN="./train_distribute_gpu_fp16"$(date "+%Y%m%d%H%M%S")
if [ -d $PATH_TRAIN ];
then
rm -rf $PATH_TRAIN
fi
mkdir $PATH_TRAIN
cd $PATH_TRAIN || exit
echo "start distributed training on $DEVICE_NUM gpus"

mpirun -n $1 --allow-run-as-root \
--output-filename gpu_fp16_dist_log \
--merge-stderr-to-stdout \
${PYTHON_EXEC} ../train.py \
--is_distributed \
--device_target=GPU \
--dataset_path=$DATA_DIR > gpu_fp16_dist_log.txt 2>&1 &
else
PATH_TRAIN="./train_standalone_gpu_fp16"$(date "+%Y%m%d%H%M%S")
if [ -d $PATH_TRAIN ];
then
rm -rf $PATH_TRAIN
fi
mkdir $PATH_TRAIN
cd $PATH_TRAIN || exit
echo "start training standalone on gpu device $DEVICE_ID"
${PYTHON_EXEC} ../train.py \
--device_target=GPU \
--dataset_path=$DATA_DIR > gpu_fp16_standard_log.txt 2>&1 &
fi
cd ../

+58 -0  model_zoo/official/cv/xception/scripts/run_train_gpu_fp32.sh

@@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_NUM=$1
export RANK_SIZE=$1
DATA_DIR=$2
#DATA_DIR=/gdata/ImageNet2012/train/
PYTHON_EXEC=python

if [ $1 -gt 1 ]
then
PATH_TRAIN="./train_distribute_gpu_fp32"$(date "+%Y%m%d%H%M%S")
if [ -d $PATH_TRAIN ];
then
rm -rf $PATH_TRAIN
fi
mkdir $PATH_TRAIN
cd $PATH_TRAIN || exit
echo "start distributed training on $DEVICE_NUM gpus"

mpirun -n $1 --allow-run-as-root \
--output-filename gpu_fp32_dist_log \
--merge-stderr-to-stdout \
${PYTHON_EXEC} ../train.py \
--is_distributed \
--device_target=GPU \
--is_fp32 \
--dataset_path=$DATA_DIR > gpu_fp32_dist_log.txt 2>&1 &
else
PATH_TRAIN="./train_standalone_gpu_fp32"$(date "+%Y%m%d%H%M%S")
if [ -d $PATH_TRAIN ];
then
rm -rf $PATH_TRAIN
fi
mkdir $PATH_TRAIN
cd $PATH_TRAIN || exit
echo "start training standalone on gpu device $DEVICE_ID"
#${PYTHON_EXEC} ../train.py \ --dataset_path=/gdata/ImageNet2012/train/
${PYTHON_EXEC} ../train.py \
--device_target=GPU \
--is_fp32 \
--dataset_path=$DATA_DIR > gpu_fp32_standard_log.txt 2>&1 &
fi
cd ../

+24 -2  model_zoo/official/cv/xception/src/config.py

@@ -17,8 +17,30 @@ network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

# config for Xception, imagenet2012.
config = ed({
# config on GPU for Xception, imagenet2012.
config_gpu = ed({
"class_num": 1000,
"batch_size": 64,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 250,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 5,
"save_checkpoint_path": "./gpu-ckpt",
"warmup_epochs": 1,
"lr_decay_mode": "linear",
"use_label_smooth": True,
"finish_epoch": 0,
"label_smooth_factor": 0.1,
"lr_init": 0.00004,
"lr_max": 0.4,
"lr_end": 0.00004
})

# config on Ascend for Xception, imagenet2012.
config_ascend = ed({
"class_num": 1000,
"batch_size": 128,
"loss_scale": 1024,


+81 -61  model_zoo/official/cv/xception/train.py

@@ -29,7 +29,7 @@ from mindspore.common import set_seed

from src.lr_generator import get_lr
from src.Xception import xception
from src.config import config
from src.config import config_gpu, config_ascend
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth

@@ -38,83 +38,103 @@ set_seed(1)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='image classification training')
parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
parser.add_argument('--device_target', type=str, default='Ascend', help='run platform')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='run platform, (Default: Ascend)')
parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
parser.add_argument("--is_fp32", action='store_true', default=False, help='fp32 training, add --is_fp32')
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
args_opt = parser.parse_args()

args_opt = parser.parse_args()
if args_opt.device_target == "Ascend":
#train on Ascend
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
config = config_ascend
elif args_opt.device_target == "GPU":
config = config_gpu
else:
raise ValueError("Unsupported device_target.")

# init distributed
if args_opt.is_distributed:
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
rank = get_rank()
group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
else:
rank = 0
group_size = 1
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
# init distributed
if args_opt.is_distributed:
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init()
rank = get_rank()
group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
else:
rank = 0
group_size = 1
context.set_context(device_id=0)
# if os.getenv('DEVICE_ID', "not_set").isdigit():
# context.set_context(device_id=int(os.getenv('DEVICE_ID')))

# define network
net = xception(class_num=config.class_num)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
# define network
net = xception(class_num=config.class_num)
if args_opt.device_target == "Ascend":
net.to_float(mstype.float16)

# define loss
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
# define loss
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

# define dataset
dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size,
device_num=group_size, rank=rank)
step_size = dataset.get_dataset_size()
# define dataset
dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size,
device_num=group_size, rank=rank)
step_size = dataset.get_dataset_size()

# resume
if args_opt.resume:
ckpt = load_checkpoint(args_opt.resume)
load_param_into_net(net, ckpt)
# resume
if args_opt.resume:
ckpt = load_checkpoint(args_opt.resume)
load_param_into_net(net, ckpt)

# get learning rate
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(lr_init=config.lr_init,
lr_end=config.lr_end,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode,
global_step=config.finish_epoch * step_size))
# get learning rate
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(lr_init=config.lr_init,
lr_end=config.lr_end,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode,
global_step=config.finish_epoch * step_size))

# define optimization
# define optimization and model
if args_opt.device_target == "Ascend":
opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale)

# define model
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level='O3', keep_batchnorm_fp32=True)
elif args_opt.device_target == "GPU":
if args_opt.is_fp32:
opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
else:
opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level='O2', keep_batchnorm_fp32=True)

# define callbacks
cb = [TimeMonitor(), LossMonitor()]
if config.save_checkpoint:
# define callbacks
cb = [TimeMonitor(), LossMonitor()]
if config.save_checkpoint:
if args_opt.device_target == "Ascend":
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck)
elif args_opt.device_target == "GPU":
if args_opt.is_fp32:
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'fp32/' + 'model_' + str(rank))
else:
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'fp16/' + 'model_' + str(rank))
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck)

# begin train
if args_opt.is_distributed:
if rank == 0:
cb += [ckpt_cb]
model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
else:
# begin train
print("begin train")
if args_opt.is_distributed:
if rank == 0:
cb += [ckpt_cb]
model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
print("train success")
model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
else:
raise ValueError("Unsupported device_target.")
cb += [ckpt_cb]
model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=True)
print("train success")
