diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index 82cb38aa19..6fb31007a9 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -178,6 +178,21 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co } } +void TransformFormatPosition(std::vector<size_t> *format_position, size_t position_num) { + MS_EXCEPTION_IF_NULL(format_position); + if (format_position->size() == 0) { + return; + } + + // If the inserted position is kAllPositions, then insert all the positions. + if ((*format_position)[0] == kAllPositions) { + format_position->clear(); + for (size_t index = 0; index < position_num; index++) { + format_position->push_back(index); + } + } +} + bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -198,20 +213,28 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) { auto inputs_format_position = iter->second.first; - // If input position is empty, then insert all the input positions, because the input numbers of this op are variable. - if (inputs_format_position.size() == 0) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t input_index = 0; input_index < input_num; input_index++) { - inputs_format_position.push_back(input_index); - } - } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + TransformFormatPosition(&inputs_format_position, input_num); for (const auto &input_format_position : inputs_format_position) { auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position); + // Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension. 
if (input_shape.size() != 4) { return false; } } + + auto outputs_format_position = iter->second.second; + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + TransformFormatPosition(&outputs_format_position, output_num); + for (const auto &output_format_position : outputs_format_position) { + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_format_position); + // Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension. + if (output_shape.size() != 4) { + return false; + } + } return true; } @@ -226,13 +249,8 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type, MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format; auto inputs_format_position = iter->second.first; - // If input position is empty, then insert all the input positions, because the input numbers of this op are variable. - if (inputs_format_position.size() == 0) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t input_index = 0; input_index < input_num; input_index++) { - inputs_format_position.push_back(input_index); - } - } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + TransformFormatPosition(&inputs_format_position, input_num); for (const auto &input_format_position : inputs_format_position) { if (input_format_position >= inputs_format->size()) { MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size [" @@ -240,7 +258,10 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type, + auto outputs_format_position = iter->second.second; + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + TransformFormatPosition(&outputs_format_position, output_num); for (const auto &output_format_position : outputs_format_position) { if (output_format_position >= outputs_format->size()) { MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size [" diff --git 
a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h index d15de86529..7cf9db256e 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h @@ -32,8 +32,11 @@ namespace mindspore { namespace device { namespace gpu { -// map, used for getting the insert position of format transform. -// If input position is empty, then insert all the input positions, because the input numbers of this op are variable. +const size_t kAllPositions = SIZE_MAX; + +// Map, used for getting the inserted position of format transform. +// If the inserted position is kAllPositions, then insert all the positions, because the input or output numbers of +// this op are variable. static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = { // Format sensitive. {prim::kPrimConv2D->name(), {{0, 1}, {0}}}, @@ -58,8 +61,8 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>> {prim::kPrimRelu6Grad->name(), {{0, 1}, {0}}}, {kSliceOpName, {{0}, {0}}}, {kTensorAddOpName, {{0, 1}, {0}}}, - {prim::kPrimConcat->name(), {{}, {0}}}, - {prim::kPrimAddN->name(), {{}, {0}}}, + {prim::kPrimConcat->name(), {{kAllPositions}, {0}}}, + {prim::kPrimAddN->name(), {{kAllPositions}, {0}}}, }; void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);