From e7ea343738908c7c8b15665767867e12e93dcfd7 Mon Sep 17 00:00:00 2001
From: zuochuanyong
Date: Thu, 15 Apr 2021 20:23:14 +0800
Subject: [PATCH] add format transform pass on cpu

---
 .../cpu/insert_format_transform_op.cc         | 204 ++++++++++++++++++
 .../cpu/insert_format_transform_op.h          |  35 +++
 .../ccsrc/backend/session/cpu_session.cc      |   5 +-
 mindspore/ccsrc/common/trans.cc               |  20 +-
 .../hardware/cpu/cpu_device_context.cc        |   2 +
 mindspore/ccsrc/utils/utils.h                 |   2 +
 6 files changed, 265 insertions(+), 3 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.h

diff --git a/mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.cc b/mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.cc
new file mode 100644
index 0000000000..cb86cafb09
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.cc
@@ -0,0 +1,204 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/cpu/insert_format_transform_op.h"
+
+#include <algorithm>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/kernel_graph.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+
+constexpr int kMinDimNeedToTransform = 3;
+enum FormatTransformDir { ChannelFirst2ChannelLast = 0, ChannelLast2ChannelFirst };
+
+// Get the perm between a channel-first shape and a channel-last shape.
+// e.g. 4D channel-first => channel-last: [0,1,2,3] => [0,2,3,1];
+// e.g. 4D channel-last => channel-first: [0,1,2,3] => [0,3,1,2];
+std::vector<int> TransposeAxis(const int dim, FormatTransformDir dir) {
+  std::vector<int> axis;
+  axis.resize(dim);
+  if (dir == ChannelFirst2ChannelLast) {
+    std::iota(axis.begin() + 1, axis.end(), 2);
+    axis[dim - 1] = 1;
+  } else {
+    std::iota(axis.begin() + 2, axis.end(), 1);
+    axis[1] = dim - 1;
+  }
+  return axis;
+}
+
+CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
+                           int used_node_index, const std::vector<int> &transpose_perm) {
+  MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
+                << ", index: " << used_node_index;
+  MS_EXCEPTION_IF_NULL(graph);
+  // 1. Create a transpose node.
+  auto primitive_ptr = prim::kPrimTranspose;
+  auto transpose_prim = std::make_shared<Primitive>(primitive_ptr->name());
+  MS_EXCEPTION_IF_NULL(transpose_prim);
+  // 2. Set the input of transpose.
+  std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
+  auto transpose_op = graph->NewCNode(transpose_input);
+  // 3. Set the output info of transpose.
+  std::vector<TypeId> transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
+  std::vector<std::vector<size_t>> transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
+  AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
+  AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
+  // 4. Set the new edge of transpose op.
+  FuncGraphManagerPtr manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->SetEdge(used_node, used_node_index + 1, transpose_op);
+  return transpose_op;
+}
+
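+// Set the kernel build info of an inserted transpose op: the device-side input/output
+// formats and data types that kernel selection will use.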
+void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
+                             const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+  auto output_type = AnfAlgo::GetOutputInferDataType(node, 0);
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  builder.SetInputsFormat({input_format});
+  builder.SetInputsDeviceType({input_type});
+  builder.SetOutputsFormat({output_format});
+  builder.SetOutputsDeviceType({output_type});
+  builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
+  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
+}
+
+void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
+                         const std::vector<int> &transpose_perm, const std::string &transpose_format) {
+  auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
+  for (size_t i = 0; i < used_node_list->size(); i++) {
+    auto used_node = used_node_list->at(i).first;
+    auto used_node_index = used_node_list->at(i).second - 1;
+    if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
+      MS_LOG(EXCEPTION) << "The used node of a tuple item can't be a tuple item.";
+    }
+
+    // node->used_node: if the output format of node equals the input format of used_node,
+    // there is no need to insert a transpose between node and used_node.
+    auto used_node_in_format =
+      AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
+    if (transpose_format == used_node_in_format) {
+      continue;
+    }
+    auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
+    SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transpose_op);
+  }
+}
+
+void InsertTransformOpForInput(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &origin_format) {
+  auto inputs_format = AnfAlgo::GetAllInputFormats(node);
+  for (size_t i = 0; i < inputs_format.size(); ++i) {
+    if ((inputs_format[i] == kOpFormat_DEFAULT) || (inputs_format[i] == origin_format)) {
+      continue;
+    }
+    auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(node, i);
+    if (inputs_format[i] == prev_input_format) {
+      continue;
+    }
+    auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
+    auto dim = SizeToInt(in_shape.size());
+    if (dim < kMinDimNeedToTransform) {
+      continue;
+    }
+    auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
+    MS_EXCEPTION_IF_NULL(input_node);
+    auto transpose_perm = TransposeAxis(dim, ChannelFirst2ChannelLast);
+    auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
+    SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op);
+  }
+}
+
+// Insert an output transpose from output_format back to origin_format.
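+// e.g. a kernel whose output was selected as ChannelLast feeding a default (channel-first)
+// consumer gets a ChannelLast2ChannelFirst transpose inserted on that edge.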
+void InsertTransformOpForOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &origin_format) {
+  auto outputs_format = AnfAlgo::GetAllOutputFormats(node);
+  for (size_t i = 0; i < outputs_format.size(); ++i) {
+    if ((outputs_format[i] == kOpFormat_DEFAULT) || (outputs_format[i] == origin_format)) {
+      continue;
+    }
+    auto out_shape = AnfAlgo::GetOutputInferShape(node, i);
+    auto dim = SizeToInt(out_shape.size());
+    if (dim < kMinDimNeedToTransform) {
+      continue;
+    }
+    auto transpose_perm = TransposeAxis(dim, ChannelLast2ChannelFirst);
+    // Find all nodes connected to this output and rewire their inputs to the transpose.
+    auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
+    for (size_t j = 0; j < used_node_list->size(); ++j) {
+      auto used_node = used_node_list->at(j).first;
+      auto used_node_index = used_node_list->at(j).second - 1;
+      if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
+        MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is a tuple item.";
+        // For a tuple item, the real used nodes have to be looked up again.
+        ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, outputs_format[i]);
+        continue;
+      }
+      // node->used_node: if the output format of node equals the input format of used_node,
+      // there is no need to insert a transpose between node and used_node.
+      auto used_node_in_format =
+        AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
+      if (outputs_format[i] == used_node_in_format) {
+        continue;
+      }
+      auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
+      SetTransposeOpBuildInfo(outputs_format[i], kOpFormat_DEFAULT, transpose_op);
+    }
+  }
+}
+
+}  // namespace
+
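+// Kernels whose CPU implementation may select a channel-last format; currently only BiasAdd.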
+const std::unordered_set<std::string> kChannelLastKernel = {prim::kPrimBiasAdd->name()};
+
+bool InsertFormatTransformOpCPU::Run(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
+
+  for (auto node : node_list) {
+    if (!AnfAlgo::IsRealCNodeKernel(node)) {
+      continue;
+    }
+
+    auto iter = kChannelLastKernel.find(AnfAlgo::GetCNodeName(node));
+    if (iter == kChannelLastKernel.end()) {
+      continue;
+    }
+    auto origin_format = AnfAlgo::GetOriginDataFormat(node);
+    if (origin_format == kOpFormat_DEFAULT) {
+      origin_format = kOpFormat_ChannelFirst;
+    }
+
+    InsertTransformOpForInput(graph, node, origin_format);
+    InsertTransformOpForOutput(graph, node, origin_format);
+  }
+
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore

diff --git a/mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.h b/mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.h
new file mode 100644
index 0000000000..90152272a9
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H
+
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+#include "ir/anf.h"
+
+namespace mindspore {
+namespace opt {
+class InsertFormatTransformOpCPU : public Pass {
+ public:
+  explicit InsertFormatTransformOpCPU(const std::string &name) : Pass("insert_format_transform_op_cpu") {}
+  ~InsertFormatTransformOpCPU() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H

diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc
index bd7f946697..9c0a83ed85 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.cc
+++ b/mindspore/ccsrc/backend/session/cpu_session.cc
@@ -28,6 +28,7 @@
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/common/pass_manager.h"
 #include "backend/optimizer/cpu/insert_cast_cpu.h"
+#include "backend/optimizer/cpu/insert_format_transform_op.h"
 #include "backend/optimizer/pass/replace_node_by_proxy.h"
 #include "backend/optimizer/pass/erase_visit_attr.h"
 #include "debug/anf_ir_dump.h"
@@ -87,8 +88,10 @@ void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   }
 #endif
   pm->AddPass(std::make_shared<opt::InsertCastCPU>());
-  pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
   MS_LOG(INFO) << "insert cast pass";
+  pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
+  pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
+
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(kernel_graph);
   kernel_graph->SetExecOrderByDefault();

diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index 497b29e278..83d0b16fcd 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -350,6 +350,21 @@ std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
   return shape;
 }
 
+// Change a channel-first shape to a channel-last shape.
+// e.g. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
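+// Registered in TransShapeToDevice below to compute the device shape for kOpFormat_ChannelLast.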
+std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
+  auto dim = shape.size();
+  std::vector<int> axis;
+  axis.resize(dim);
+  std::iota(axis.begin() + 1, axis.end(), 2);
+  axis[dim - 1] = 1;
+
+  std::vector<size_t> device_shape;
+  std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape), [&shape](int n) { return shape[n]; });
+
+  return device_shape;
+}
+
 std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
   std::vector<size_t> shape_4d(kNchwDims, 1);
   switch (shape.size()) {
@@ -381,7 +396,7 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
   if (shape_size == 0) {
     return false;
   }
-  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
+  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ || format == kOpFormat_ChannelLast) {
     return false;
   } else if (shape_size < kNchwDims) {
     return true;
@@ -555,6 +570,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
     {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
     {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
     {kOpFormat_NCDHW, NcdhwDeviceShape},
+    {kOpFormat_ChannelLast, ChannelLastDeviceShape},
     {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
     {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};
 
@@ -592,7 +608,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
     device_shape.push_back(kCubeSize);
     return device_shape;
   }
-  if (shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
+  if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
     MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
     temp_shape = PaddingShapeTo4dByDefault(shape);
   }

diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
index f404e79385..9b6acba623 100644
--- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
+++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc
@@ -24,6 +24,7 @@
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/common/pass_manager.h"
 #include "backend/optimizer/cpu/insert_cast_cpu.h"
+#include "backend/optimizer/cpu/insert_format_transform_op.h"
 #include "backend/optimizer/pass/replace_node_by_proxy.h"
 #include "backend/optimizer/pass/erase_visit_attr.h"
 
@@ -89,6 +90,7 @@ void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared<opt::InsertCastCPU>());
+  pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
   pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
   optimizer->AddPassManager(pm);
   (void)optimizer->Optimize(graph);

diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 0437ce3255..4f64ff6bfc 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -471,6 +471,8 @@ constexpr auto kLoadRealInput = 1;
 constexpr auto kLoadStateInput = 2;
 // format
 constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
+constexpr auto kOpFormat_ChannelFirst = "ChannelFirst";
+constexpr auto kOpFormat_ChannelLast = "ChannelLast";
 constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
 constexpr auto kOpFormat_ND = "ND";
 constexpr auto kOpFormat_NCHW = "NCHW";