
add format transform pass on cpu

pull/15765/head
zuochuanyong 4 years ago
parent commit e7ea343738
6 changed files with 265 additions and 3 deletions:
  1. mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.cc (+204, -0)
  2. mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.h (+35, -0)
  3. mindspore/ccsrc/backend/session/cpu_session.cc (+4, -1)
  4. mindspore/ccsrc/common/trans.cc (+18, -2)
  5. mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc (+2, -0)
  6. mindspore/ccsrc/utils/utils.h (+2, -0)

mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.cc (+204, -0)

@@ -0,0 +1,204 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/cpu/insert_format_transform_op.h"
#include <unordered_set>
#include <numeric>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "utils/utils.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kMinDimNeedToTransform = 3;
enum FormatTransformDir { ChannelFirst2ChannelLast = 0, ChannelLast2ChannelFirst };
// Get the permutation between a channel-first shape and a channel-last shape.
// e.g. 4D channel-first => channel-last: [0,1,2,3] => [0,2,3,1];
// e.g. 4D channel-last => channel-first: [0,1,2,3] => [0,3,1,2];
std::vector<int64_t> TransposeAxis(const int dim, FormatTransformDir dir) {
std::vector<int64_t> axis(dim);
if (dir == ChannelFirst2ChannelLast) {
std::iota(axis.begin() + 1, axis.end(), 2);
axis[dim - 1] = 1;
} else {
std::iota(axis.begin() + 2, axis.end(), 1);
axis[1] = dim - 1;
}
return axis;
}
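// e.g. 5D: TransposeAxis(5, ChannelFirst2ChannelLast) => [0,2,3,4,1];
// TransposeAxis(5, ChannelLast2ChannelFirst) => [0,4,1,2,3].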
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
int used_node_index, const std::vector<int64_t> &transpose_perm) {
MS_LOG(ERROR) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
<< ", index: " << used_node_index;
MS_EXCEPTION_IF_NULL(graph);
// 1.Create a transpose node or a fake transpose node:reshape.
auto primitive_ptr = prim::kPrimTranspose;
auto transpose_prim = std::make_shared<Primitive>(primitive_ptr->name());
MS_EXCEPTION_IF_NULL(transpose_prim);
// 2.Set the input of transpose.
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
auto transpose_op = graph->NewCNode(transpose_input);
// 3.Set the output info of transpose.
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
// 4. Set the new edge of transpose op.
FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
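// A CNode's input 0 is the primitive, so data input used_node_index maps to edge index used_node_index + 1.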
manager->SetEdge(used_node, used_node_index + 1, transpose_op);
return transpose_op;
}
void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
auto output_type = AnfAlgo::GetOutputInferDataType(node, 0);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetInputsDeviceType({input_type});
builder.SetOutputsFormat({output_format});
builder.SetOutputsDeviceType({output_type});
builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
}
void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
const std::vector<int64_t> &transpose_perm, const std::string &transpose_format) {
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
for (size_t i = 0; i < used_node_list->size(); ++i) {
auto used_node = used_node_list->at(i).first;
auto used_node_index = used_node_list->at(i).second - 1;
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
MS_LOG(EXCEPTION) << "The used node of tuple item can't be tuple item.";
}
// For the edge node->used_node: if node's output format already equals used_node's
// input format, there is no need to insert a transpose between them.
auto used_node_in_format =
AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
if (transpose_format == used_node_in_format) {
continue;
}
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transpose_op);
}
}
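// Insert an input transpose from origin_format to the kernel's selected input format.
// Example effect (illustrative): a DefaultFormat (channel-first) tensor feeding a
// "ChannelLast" input of BiasAdd gets Transpose(perm = [0, 2, 3, 1]) inserted on that edge.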
void InsertTransformOpForInput(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &origin_format) {
auto inputs_format = AnfAlgo::GetAllInputFormats(node);
for (size_t i = 0; i < inputs_format.size(); ++i) {
if ((inputs_format[i] == kOpFormat_DEFAULT) || (inputs_format[i] == origin_format)) {
continue;
}
auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(node, i);
if (inputs_format[i] == prev_input_format) {
continue;
}
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
auto dim = in_shape.size();
if (dim < kMinDimNeedToTransform) {
continue;
}
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
MS_EXCEPTION_IF_NULL(input_node);
auto transpose_perm = TransposeAxis(static_cast<int>(dim), ChannelFirst2ChannelLast);
auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op);
}
}
// Insert output transpose from output_format to origin_format.
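// Example effect (illustrative): a "ChannelLast" output consumed by a DefaultFormat kernel
// gets Transpose(perm = [0, 3, 1, 2]) inserted in front of that consumer.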
void InsertTransformOpForOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &origin_format) {
auto outputs_format = AnfAlgo::GetAllOutputFormats(node);
for (size_t i = 0; i < outputs_format.size(); ++i) {
if ((outputs_format[i] == kOpFormat_DEFAULT) || (outputs_format[i] == origin_format)) {
continue;
}
auto out_shape = AnfAlgo::GetOutputInferShape(node, i);
auto dim = out_shape.size();
if (dim < kMinDimNeedToTransform) {
continue;
}
auto transpose_perm = TransposeAxis(static_cast<int>(dim), ChannelLast2ChannelFirst);
// Find all nodes connected with node output, and change their inputs to transpose.
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
for (size_t j = 0; j < used_node_list->size(); ++j) {
auto used_node = used_node_list->at(j).first;
auto used_node_index = used_node_list->at(j).second - 1;
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is tuple item.";
// A tuple-get-item just forwards the output, so look through it and process its users.
ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, outputs_format[i]);
continue;
}
// For the edge node->used_node: if node's output format already equals used_node's
// input format, there is no need to insert a transpose between them.
auto used_node_in_format =
AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
if (outputs_format[i] == used_node_in_format) {
continue;
}
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
SetTransposeOpBuildInfo(outputs_format[i], kOpFormat_DEFAULT, transpose_op);
}
}
}
} // namespace
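// CPU kernels that may select a channel-last layout; currently only BiasAdd.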
const std::unordered_set<std::string> kChannelLastKernel = {prim::kPrimBiasAdd->name()};
bool InsertFormatTransformOpCPU::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
for (const auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node)) {
continue;
}
auto iter = kChannelLastKernel.find(AnfAlgo::GetCNodeName(node));
if (iter == kChannelLastKernel.end()) {
continue;
}
auto origin_format = AnfAlgo::GetOriginDataFormat(node);
if (origin_format == kOpFormat_DEFAULT) {
origin_format = kOpFormat_ChannelFirst;
}
InsertTransformOpForInput(graph, node, origin_format);
InsertTransformOpForOutput(graph, node, origin_format);
}
return true;
}
} // namespace opt
} // namespace mindspore
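For intuition, the permutation helper above can be exercised in isolation. Below is a minimal, self-contained sketch that mirrors TransposeAxis(); the main() driver and the expected outputs in the comments are illustrative additions, not part of the commit.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

enum FormatTransformDir { ChannelFirst2ChannelLast = 0, ChannelLast2ChannelFirst };

// Build the permutation that moves the channel axis (index 1) to the back,
// or moves the trailing channel axis back to index 1.
std::vector<int64_t> TransposeAxis(int dim, FormatTransformDir dir) {
  std::vector<int64_t> axis(dim);  // axis[0] stays 0: the batch axis never moves
  if (dir == ChannelFirst2ChannelLast) {
    std::iota(axis.begin() + 1, axis.end(), 2);  // [0, 2, 3, ..., dim]
    axis[dim - 1] = 1;                           // channel axis goes last
  } else {
    std::iota(axis.begin() + 2, axis.end(), 1);  // [0, 0, 1, 2, ...]
    axis[1] = dim - 1;                           // channel axis returns to index 1
  }
  return axis;
}

int main() {
  for (auto a : TransposeAxis(4, ChannelFirst2ChannelLast)) std::cout << a << ' ';  // 0 2 3 1
  std::cout << '\n';
  for (auto a : TransposeAxis(4, ChannelLast2ChannelFirst)) std::cout << a << ' ';  // 0 3 1 2
  std::cout << '\n';
  return 0;
}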

mindspore/ccsrc/backend/optimizer/cpu/insert_format_transform_op.h (+35, -0)

@@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "ir/anf.h"
namespace mindspore {
namespace opt {
class InsertFormatTransformOpCPU : public Pass {
public:
explicit InsertFormatTransformOpCPU(const std::string &name) : Pass(name) {}
~InsertFormatTransformOpCPU() override = default;
bool Run(const FuncGraphPtr &graph) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_CPU_INSERT_FORMAT_TRANSFORM_OP_H

mindspore/ccsrc/backend/session/cpu_session.cc (+4, -1)

@@ -28,6 +28,7 @@
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/cpu/insert_cast_cpu.h"
#include "backend/optimizer/cpu/insert_format_transform_op.h"
#include "backend/optimizer/pass/replace_node_by_proxy.h"
#include "backend/optimizer/pass/erase_visit_attr.h"
#include "debug/anf_ir_dump.h"
@@ -87,8 +88,10 @@ void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
}
#endif
pm->AddPass(std::make_shared<opt::InsertCastCPU>());
- pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
+ MS_LOG(INFO) << "insert cast pass";
+ pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
+ pm->AddPass(std::make_shared<opt::EraseVisitAttr>());

optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();


mindspore/ccsrc/common/trans.cc (+18, -2)

@@ -350,6 +350,21 @@ std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
return shape;
}

+ // Change a channel-first shape to a channel-last shape.
+ // e.g. [2,3,4] => [2,4,3]; [2,3,4,5] => [2,4,5,3]
+ std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
+ auto dim = shape.size();
+ if (dim < 2) {
+ // Rank 0/1 has no separate channel axis, so there is nothing to permute.
+ return shape;
+ }
+ std::vector<int64_t> axis(dim);
+ std::iota(axis.begin() + 1, axis.end(), 2);
+ axis[dim - 1] = 1;

+ std::vector<size_t> device_shape;
+ std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape), [&shape](int64_t n) { return shape[n]; });

+ return device_shape;
+ }

std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape) {
std::vector<size_t> shape_4d(kNchwDims, 1);
switch (shape.size()) {
@@ -381,7 +396,7 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
if (shape_size == 0) {
return false;
}
- if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
+ if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ || format == kOpFormat_ChannelLast) {
return false;
} else if (shape_size < kNchwDims) {
return true;
@@ -555,6 +570,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
{kOpFormat_NCDHW, NcdhwDeviceShape},
+ {kOpFormat_ChannelLast, ChannelLastDeviceShape},
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};

@@ -592,7 +608,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
device_shape.push_back(kCubeSize);
return device_shape;
}
- if (shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
+ if (format != kOpFormat_ChannelLast && shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) {
MS_LOG(WARNING) << "The shape size is less than 4; the shape will be padded to 4D by default first.";
temp_shape = PaddingShapeTo4dByDefault(shape);
}
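
As a quick standalone check of the new shape rule, this sketch mirrors ChannelLastDeviceShape() from the hunk above; the main() driver and the expected-output comment are illustrative additions, not part of the commit.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

// Permute a channel-first host shape into its channel-last device shape.
std::vector<size_t> ChannelLastDeviceShape(const std::vector<size_t> &shape) {
  auto dim = shape.size();
  if (dim < 2) {
    return shape;  // rank 0/1: no separate channel axis
  }
  std::vector<int64_t> axis(dim);
  std::iota(axis.begin() + 1, axis.end(), 2);
  axis[dim - 1] = 1;
  std::vector<size_t> device_shape;
  std::transform(axis.begin(), axis.end(), std::back_inserter(device_shape),
                 [&shape](int64_t n) { return shape[n]; });
  return device_shape;
}

int main() {
  for (auto d : ChannelLastDeviceShape({2, 3, 4, 5})) std::cout << d << ' ';  // 2 4 5 3
  std::cout << '\n';
  return 0;
}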


mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc (+2, -0)

@@ -24,6 +24,7 @@
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/cpu/insert_cast_cpu.h"
#include "backend/optimizer/cpu/insert_format_transform_op.h"
#include "backend/optimizer/pass/replace_node_by_proxy.h"
#include "backend/optimizer/pass/erase_visit_attr.h"

@@ -89,6 +90,7 @@ void CPUDeviceContext::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::InsertCastCPU>());
+ pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
pm->AddPass(std::make_shared<opt::EraseVisitAttr>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(graph);


mindspore/ccsrc/utils/utils.h (+2, -0)

@@ -471,6 +471,8 @@ constexpr auto kLoadRealInput = 1;
constexpr auto kLoadStateInput = 2;
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_ChannelFirst = "ChannelFirst";
constexpr auto kOpFormat_ChannelLast = "ChannelLast";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
constexpr auto kOpFormat_ND = "ND";
constexpr auto kOpFormat_NCHW = "NCHW";

