|
|
@@ -0,0 +1,204 @@ |
|
|
|
|
|
/**
|
|
|
|
|
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
|
|
|
|
*
|
|
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
|
|
*
|
|
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
*
|
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
#include "backend/optimizer/cpu/insert_format_transform_op.h"
|
|
|
|
|
|
|
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
|
#include <numeric>
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
|
#include <string>
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
#include "backend/kernel_compiler/kernel_build_info.h"
|
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
|
#include "backend/session/kernel_graph.h"
|
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
|
namespace opt {
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
|
|
constexpr int kMinDimNeedToTransform = 3;
|
|
|
|
|
|
enum FormatTransformDir { ChannelFisrt2ChannelLast = 0, ChannelLast2ChannelFirst };
|
|
|
|
|
|
|
|
|
|
|
|
// get perm between channel-first shape and channel-last shape.
|
|
|
|
|
|
// eg. 4D channe-first => channel-last: [0,1,2,3] => [0,2,3,1];
|
|
|
|
|
|
// eg. 4D channe-last => channel-first: [0,1,2,3] => [0,3,1,2];
|
|
|
|
|
|
std::vector<int64_t> TransposeAxis(const int dim, FormatTransformDir dir) {
|
|
|
|
|
|
std::vector<int64_t> axis;
|
|
|
|
|
|
axis.resize(dim);
|
|
|
|
|
|
if (dir == ChannelFisrt2ChannelLast) {
|
|
|
|
|
|
std::iota(axis.begin() + 1, axis.end(), 2);
|
|
|
|
|
|
axis[dim - 1] = 1;
|
|
|
|
|
|
} else {
|
|
|
|
|
|
std::iota(axis.begin() + 2, axis.end(), 1);
|
|
|
|
|
|
axis[1] = dim - 1;
|
|
|
|
|
|
}
|
|
|
|
|
|
return axis;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
|
|
|
|
|
|
int used_node_index, const std::vector<int64_t> &transpose_perm) {
|
|
|
|
|
|
MS_LOG(ERROR) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
|
|
|
|
|
<< ", index: " << used_node_index;
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
// 1.Create a transpose node or a fake transpose node:reshape.
|
|
|
|
|
|
auto primitive_ptr = prim::kPrimTranspose;
|
|
|
|
|
|
auto transpose_prim = std::make_shared<Primitive>(primitive_ptr->name());
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(transpose_prim);
|
|
|
|
|
|
// 2.Set the input of transpose.
|
|
|
|
|
|
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
|
|
|
|
|
|
auto transpose_op = graph->NewCNode(transpose_input);
|
|
|
|
|
|
// 3.Set the output info of transpose.
|
|
|
|
|
|
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
|
|
|
|
|
|
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
|
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
|
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
|
|
|
|
|
|
// 4. Set the new edge of transpose op.
|
|
|
|
|
|
FuncGraphManagerPtr manager = graph->manager();
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
|
manager->SetEdge(used_node, used_node_index + 1, transpose_op);
|
|
|
|
|
|
return transpose_op;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
|
|
|
|
|
|
const AnfNodePtr &node) {
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
|
auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
|
|
|
|
|
auto output_type = AnfAlgo::GetOutputInferDataType(node, 0);
|
|
|
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
|
|
|
|
|
builder.SetInputsFormat({input_format});
|
|
|
|
|
|
builder.SetInputsDeviceType({input_type});
|
|
|
|
|
|
builder.SetOutputsFormat({output_format});
|
|
|
|
|
|
builder.SetOutputsDeviceType({output_type});
|
|
|
|
|
|
builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
|
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
|
|
|
|
|
|
const std::vector<int64_t> &transpose_perm, const std::string &transpose_format) {
|
|
|
|
|
|
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
|
|
|
|
|
|
for (size_t i = 0; i < used_node_list->size(); i++) {
|
|
|
|
|
|
auto used_node = used_node_list->at(i).first;
|
|
|
|
|
|
auto used_node_index = used_node_list->at(i).second - 1;
|
|
|
|
|
|
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "The used node of tuple item can't be tuple item.";
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// node->used_node, if output format of node equals input format of used_node,
|
|
|
|
|
|
// then no need to insert transpose between node and used_node.
|
|
|
|
|
|
auto used_node_in_format =
|
|
|
|
|
|
AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
|
|
|
|
|
|
if (transpose_format == used_node_in_format) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
|
|
|
|
|
|
SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transpose_op);
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void InsertTransformOpForInput(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &origin_format) {
|
|
|
|
|
|
auto inputs_format = AnfAlgo::GetAllInputFormats(node);
|
|
|
|
|
|
for (size_t i = 0; i < inputs_format.size(); ++i) {
|
|
|
|
|
|
if ((inputs_format[i] == kOpFormat_DEFAULT) || (inputs_format[i] == origin_format)) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(node, i);
|
|
|
|
|
|
if (inputs_format[i] == prev_input_format) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
|
|
|
|
|
|
auto dim = in_shape.size();
|
|
|
|
|
|
if (dim < kMinDimNeedToTransform) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
|
|
auto transpose_perm = TransposeAxis(dim, ChannelFisrt2ChannelLast);
|
|
|
|
|
|
auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
|
|
|
|
|
|
SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op);
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Insert output transpose from output_format to origin_format.
|
|
|
|
|
|
void InsertTransformOpForOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &origin_format) {
|
|
|
|
|
|
auto outputs_format = AnfAlgo::GetAllOutputFormats(node);
|
|
|
|
|
|
for (size_t i = 0; i < outputs_format.size(); ++i) {
|
|
|
|
|
|
if ((outputs_format[i] == kOpFormat_DEFAULT) || (outputs_format[i] == origin_format)) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto out_shape = AnfAlgo::GetOutputInferShape(node, i);
|
|
|
|
|
|
auto dim = out_shape.size();
|
|
|
|
|
|
if (dim < kMinDimNeedToTransform) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto transpose_perm = TransposeAxis(dim, ChannelLast2ChannelFirst);
|
|
|
|
|
|
// Find all nodes connected with node output, and change their inputs to transpose.
|
|
|
|
|
|
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
|
|
|
|
|
|
for (size_t j = 0; j < used_node_list->size(); ++j) {
|
|
|
|
|
|
auto used_node = used_node_list->at(j).first;
|
|
|
|
|
|
auto used_node_index = used_node_list->at(j).second - 1;
|
|
|
|
|
|
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
|
|
|
|
|
|
MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is tuple item.";
|
|
|
|
|
|
// The tuple item need get next used nodes again.
|
|
|
|
|
|
ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, outputs_format[i]);
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
// node->used_node, if output format of node equals input format of used_node,
|
|
|
|
|
|
// then no need to insert transpose between node and used_node.
|
|
|
|
|
|
auto used_node_in_format =
|
|
|
|
|
|
AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
|
|
|
|
|
|
if (outputs_format[i] == used_node_in_format) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
|
|
|
|
|
|
SetTransposeOpBuildInfo(outputs_format[i], kOpFormat_DEFAULT, transpose_op);
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
|
|
const std::unordered_set<std::string> kChannelLastKernel = {prim::kPrimBiasAdd->name()};
|
|
|
|
|
|
|
|
|
|
|
|
bool InsertFormatTransformOpCPU::Run(const FuncGraphPtr &graph) {
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
|
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
|
|
|
|
|
|
|
|
|
|
|
|
for (auto node : node_list) {
|
|
|
|
|
|
if (!AnfAlgo::IsRealCNodeKernel(node)) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto iter = kChannelLastKernel.find(AnfAlgo::GetCNodeName(node));
|
|
|
|
|
|
if (iter == kChannelLastKernel.end()) {
|
|
|
|
|
|
continue;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto origin_format = AnfAlgo::GetOriginDataFormat(node);
|
|
|
|
|
|
if (origin_format == kOpFormat_DEFAULT) {
|
|
|
|
|
|
origin_format = kOpFormat_ChannelFirst;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
InsertTransformOpForInput(graph, node, origin_format);
|
|
|
|
|
|
InsertTransformOpForOutput(graph, node, origin_format);
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
}
|
|
|
|
|
|
} // namespace opt
|
|
|
|
|
|
} // namespace mindspore
|