|
|
|
@@ -176,9 +176,18 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type |
|
|
|
if (inputs_type.size() == 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
if (input_shape.size() != 4) { |
|
|
|
return false; |
|
|
|
auto inputs_format_position = iter->second.first; |
|
|
|
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable. |
|
|
|
if (inputs_format_position.size() == 0) { |
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) { |
|
|
|
inputs_format_position.push_back(input_index); |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &input_format_position : inputs_format_position) { |
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position); |
|
|
|
if (input_shape.size() != 4) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -223,7 +232,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) { |
|
|
|
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) { |
|
|
|
std::vector<std::string> inputs_format; |
|
|
|
std::vector<TypeId> inputs_type; |
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { |
|
|
|
@@ -237,7 +246,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) { |
|
|
|
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); |
|
|
|
} |
|
|
|
std::string origin_data_format = kOpFormat_DEFAULT; |
|
|
|
if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { |
|
|
|
if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) { |
|
|
|
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format); |
|
|
|
} |
|
|
|
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = |
|
|
|
|