|
|
|
@@ -178,6 +178,21 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void TransformFormatPosition(std::vector<size_t> *format_position, size_t position_num) { |
|
|
|
MS_EXCEPTION_IF_NULL(format_position); |
|
|
|
if (format_position->size() == 0) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
// If the inserted position is kAllPositions, then insert all the positions. |
|
|
|
if ((*format_position)[0] == kAllPositions) { |
|
|
|
format_position->clear(); |
|
|
|
for (size_t index = 0; index < position_num; index++) { |
|
|
|
format_position->push_back(index); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) { |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
@@ -198,20 +213,28 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type |
|
|
|
if (inputs_type.size() == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto inputs_format_position = iter->second.first; |
|
|
|
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable. |
|
|
|
if (inputs_format_position.size() == 0) { |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
for (size_t input_index = 0; input_index < input_num; input_index++) { |
|
|
|
inputs_format_position.push_back(input_index); |
|
|
|
} |
|
|
|
} |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
TransformFormatPosition(&inputs_format_position, input_num); |
|
|
|
for (const auto &input_format_position : inputs_format_position) { |
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position); |
|
|
|
// Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension. |
|
|
|
if (input_shape.size() != 4) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto outputs_format_position = iter->second.second; |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); |
|
|
|
TransformFormatPosition(&outputs_format_position, output_num); |
|
|
|
for (const auto &output_format_position : outputs_format_position) { |
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, output_format_position); |
|
|
|
// Only support the transformer between NCHW and NHWC, so need the shape is 4 dimension. |
|
|
|
if (output_shape.size() != 4) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -226,13 +249,8 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI |
|
|
|
auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW; |
|
|
|
MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format; |
|
|
|
auto inputs_format_position = iter->second.first; |
|
|
|
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable. |
|
|
|
if (inputs_format_position.size() == 0) { |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
for (size_t input_index = 0; input_index < input_num; input_index++) { |
|
|
|
inputs_format_position.push_back(input_index); |
|
|
|
} |
|
|
|
} |
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); |
|
|
|
TransformFormatPosition(&inputs_format_position, input_num); |
|
|
|
for (const auto &input_format_position : inputs_format_position) { |
|
|
|
if (input_format_position >= inputs_format->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size [" |
|
|
|
@@ -240,7 +258,10 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI |
|
|
|
} |
|
|
|
(*inputs_format)[input_format_position] = cal_format; |
|
|
|
} |
|
|
|
|
|
|
|
auto outputs_format_position = iter->second.second; |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); |
|
|
|
TransformFormatPosition(&outputs_format_position, output_num); |
|
|
|
for (const auto &output_format_position : outputs_format_position) { |
|
|
|
if (output_format_position >= outputs_format->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size [" |
|
|
|
|