|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "pre_activate/ascend/ascend_helper.h"
- #include <set>
- #include "common/trans.h"
- #include "common/utils.h"
- #include "pre_activate/common/helper.h"
- #include "utils/utils.h"
- #include "device/kernel_info.h"
- #include "kernel/oplib/oplib.h"
- #include "kernel/common_utils.h"
- #include "operator/ops.h"
- #include "session/anf_runtime_algorithm.h"
- #include "session/kernel_graph.h"
- #include "utils/context/ms_context.h"
-
- namespace mindspore {
- namespace opt {
- using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
- namespace {
- const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
- AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
- const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
- std::vector<AnfNodePtr> trans_inputs;
- auto prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
- trans_inputs.emplace_back(NewValueNode(prim));
- trans_inputs.emplace_back(input_node);
- auto reshape = func_graph->NewCNode(trans_inputs);
- AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get());
- AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape);
- AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape);
- reshape->set_scope(input_node->scope());
- kernel_select->SelectKernel(reshape);
- return reshape;
- }
-
- AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
- const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
- AnfNodePtr trans_node = nullptr;
- AnfNodePtr input_node = node;
- CNodePtr trans_data = nullptr;
- std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
- std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
- std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
- MS_EXCEPTION_IF_NULL(node);
- // if insert transdata for input we need to change the input
- if (is_insert_input) {
- if (!node->isa<CNode>()) {
- MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
- }
- auto cnode = node->cast<CNodePtr>();
- dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
- input_node = AnfAlgo::GetInputNode(cnode, insert_index);
- padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
- }
- bool need_padding = false;
- if (is_insert_input) {
- need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
- } else {
- need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
- }
- if (!need_padding) {
- // don't need padding insert transdata only
- trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
- trans_node = trans_data;
- } else if (is_insert_input) {
- // if need padding & is input need insert a transdata
- // reshape[padding shape] -> transdata[padding shape] -> node
- auto padding_shape =
- trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
- auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
- trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
- trans_node = trans_data;
- } else {
- // if need padding & is output need insert a transdata
- // node -> transdata[padding shape] -> reshape[ori_shape]
- trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
- auto reshape_node =
- CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
- trans_node = reshape_node;
- }
- // refresh the transdata's format to ori format & dst format
- RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis);
- return trans_node;
- }
-
- AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
- const KernelSelectPtr &kernel_select) {
- MS_EXCEPTION_IF_NULL(node);
- auto input_node = AnfAlgo::GetInputNode(node, index);
- auto node_with_index = AnfAlgo::VisitKernel(input_node, 0);
- MS_EXCEPTION_IF_NULL(node_with_index.first);
- auto real_input = node_with_index.first;
- if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) {
- input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select);
- MS_EXCEPTION_IF_NULL(input_node);
- AnfAlgo::SetNodeInput(node, input_node, index);
- }
- std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
- std::string dest_format = AnfAlgo::GetInputFormat(node, index);
- if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
- MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
- << " To DefaultFormat , index: " << index;
- return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
- }
- return input_node;
- }
-
- AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
- const KernelSelectPtr &kernel_select) {
- MS_EXCEPTION_IF_NULL(node);
- std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
- std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
- if (output_format == kOpFormat_NC1KHKWHWC0) {
- MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
- << node->DebugString();
- }
- if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
- MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
- return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
- }
- return node;
- }
-
- AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
- const KernelSelectPtr &kernel_select) {
- MS_EXCEPTION_IF_NULL(func_graph);
- MS_EXCEPTION_IF_NULL(node);
- std::vector<AnfNodePtr> make_tuple_inputs;
- make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
- for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
- std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
- if (output_format == kOpFormat_NC1KHKWHWC0) {
- MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node "
- << node->DebugString();
- }
- auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
- std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
- if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
- make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
- } else {
- // No need insert trans op.
- make_tuple_inputs.push_back(tuple_getitem);
- }
- }
- AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
- return make_tuple;
- }
- } // namespace
- void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
- const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
- MS_EXCEPTION_IF_NULL(trans_data);
- auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
- MS_EXCEPTION_IF_NULL(ori_build_info);
- auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
- builder->SetInputsFormat({input_format});
- builder->SetInputReshapeType({reshape_type});
- builder->SetOutputReshapeType({reshape_type});
- builder->SetOutputsFormat({output_format});
- AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
- }
-
- CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
- const bool need_padding, const std::string &op_name) {
- MS_EXCEPTION_IF_NULL(func_graph);
- MS_EXCEPTION_IF_NULL(input);
- std::vector<AnfNodePtr> trans_inputs;
- auto prim = std::make_shared<Primitive>(op_name);
- trans_inputs.push_back(NewValueNode(prim));
- trans_inputs.push_back(input);
- CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
- MS_EXCEPTION_IF_NULL(trans_node);
- auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
- if (need_padding) {
- // if need padding we should set the transdata node's shape to the padding shape
- AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
- {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
- trans_node.get());
- } else {
- AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
- {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
- }
- // special handle for ut
- if (trans_node->kernel_info() == nullptr) {
- auto kernel_info = std::make_shared<device::KernelInfo>();
- trans_node->set_kernel_info(kernel_info);
- }
- MS_EXCEPTION_IF_NULL(kernel_select);
- kernel_select->SelectKernel(trans_node);
- AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
- MS_EXCEPTION_IF_NULL(trans_node);
- trans_node->set_scope(input->scope());
- return trans_node;
- }
-
- AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
- const TypeId &input_type, const TypeId &output_type,
- const std::vector<size_t> &origin_shape, const TypeId &origin_type) {
- MS_EXCEPTION_IF_NULL(func_graph);
- std::string input_format = format;
- std::string output_format = format;
- std::vector<AnfNodePtr> new_cast_inputs;
- auto prim = std::make_shared<Primitive>(prim::kPrimCast->name());
- new_cast_inputs.push_back(NewValueNode(prim));
- new_cast_inputs.push_back(input);
- CNodePtr cast = func_graph->NewCNode(new_cast_inputs);
- MS_EXCEPTION_IF_NULL(cast);
- // set kernel build info
- kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
- builder.SetInputsFormat({input_format});
- builder.SetOutputsFormat({output_format});
- builder.SetInputsDeviceType({input_type});
- builder.SetOutputsDeviceType({output_type});
- builder.SetFusionType(kernel::FusionType::OPAQUE);
- builder.SetProcessor(kernel::Processor::AICORE);
- if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) {
- builder.SetKernelType(KernelType::TBE_KERNEL);
- } else {
- builder.SetKernelType(KernelType::AKG_KERNEL);
- }
- // if kernel info is null , it remarks this function is running ut
- if (cast->kernel_info() == nullptr) {
- auto kernel_info = std::make_shared<device::KernelInfo>();
- cast->set_kernel_info(kernel_info);
- }
- AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
- AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
- AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
- return cast;
- }
-
- AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
- const KernelSelectPtr &kernel_select) {
- size_t outputs_num = AnfAlgo::GetOutputTensorNum(node);
- if (outputs_num == 0) {
- return node;
- }
- // Single output
- if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
- return InsertTransOpForSingleOutput(func_graph, node, kernel_select);
- }
- // Multiple output
- return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);
- }
-
- AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
- const KernelSelectPtr &kernel_select) {
- MS_EXCEPTION_IF_NULL(node);
- auto cnode = node->cast<CNodePtr>();
- MS_EXCEPTION_IF_NULL(cnode);
- std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
- for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
- AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select);
- MS_EXCEPTION_IF_NULL(input_node);
- new_inputs.push_back(input_node);
- }
- CNodePtr new_cnode = nullptr;
- // cnode changed so make a new cnode to differ from original one.
- auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
- if (kernel_graph == nullptr) {
- new_cnode = std::make_shared<CNode>(*cnode);
- } else {
- new_cnode = kernel_graph->NewCNode(cnode);
- }
- MS_EXCEPTION_IF_NULL(new_cnode);
- new_cnode->set_inputs(new_inputs);
- return new_cnode;
- }
-
- CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
- MS_EXCEPTION_IF_NULL(cnode);
- std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
- for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
- const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
- TypeId origin_type(kTypeUnknown);
- auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
- auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
- auto real_input_node = kernel_with_index.first;
- if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
- // weight
- origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index);
- if (origin_type == kTypeUnknown) {
- origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index);
- }
- } else {
- // feature map
- origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
- }
- const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
- const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index);
- const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index);
- // In graph kernel, we check parameter,
- // the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
- if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
- new_inputs.push_back(cur_input);
- } else if (origin_type != device_type) {
- auto cast =
- AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type);
- MS_EXCEPTION_IF_NULL(cast);
- cast->set_scope(cnode->scope());
- AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast);
- new_inputs.push_back(cast);
- } else {
- new_inputs.push_back(cur_input);
- }
- }
- auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
- CNodePtr new_node = nullptr;
- if (kernel_graph == nullptr) {
- new_node = std::make_shared<CNode>(*cnode);
- } else {
- new_node = kernel_graph->NewCNode(cnode);
- }
- MS_EXCEPTION_IF_NULL(new_node);
- new_node->set_inputs(new_inputs);
- return new_node;
- }
-
- AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) {
- MS_EXCEPTION_IF_NULL(graph);
- MS_EXCEPTION_IF_NULL(node);
- auto prim = std::make_shared<Primitive>(kMemCpyAsyncOpName);
- std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node};
- auto new_node = graph->NewCNode(new_node_inputs);
- MS_EXCEPTION_IF_NULL(new_node);
- new_node->set_abstract(node->abstract());
- new_node->set_scope(node->scope());
- return new_node;
- }
- } // namespace opt
- } // namespace mindspore
|