From 601b0b6e4da6bd53bfb603314363b9f7e7c13a72 Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Fri, 21 Aug 2020 15:45:36 +0800 Subject: [PATCH] remove convert datatype when updateoutputs && set parameter device dtype using its infer dtype && set transdata's abstract --- .../backend/optimizer/ascend/ascend_helper.cc | 37 ++++++------------- .../format_type/rectify_do_mask_kernel_info.h | 3 +- .../ascend/ir_fission/transdata_split.cc | 1 + .../ascend/ir_fusion/remove_reshape_pair.cc | 22 ++++++----- .../optimizer/pass/eliminate_redundant_op.cc | 3 +- .../ccsrc/backend/session/session_basic.cc | 11 +++--- .../device/ascend/kernel_select_ascend.cc | 8 +--- 7 files changed, 36 insertions(+), 49 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 5a174e5966..940661b300 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -51,33 +51,19 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { AnfNodePtr trans_node = nullptr; - AnfNodePtr input_node = nullptr; CNodePtr trans_data = nullptr; - std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); - std::string dst_format = is_insert_input ? 
AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; - std::vector padding_axis; MS_EXCEPTION_IF_NULL(node); - // if insert transdata for input we need to change the input - if (is_insert_input) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; - } - auto cnode = node->cast(); - dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); - input_node = AnfAlgo::GetInputNode(cnode, insert_index); - padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); - } else { - input_node = node; - padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); - } + // Init + AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast(), insert_index) : node; + std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index); + std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT; + std::vector padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) + : AnfAlgo::GetOutputReshapeType(node, insert_index); + auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) + : AnfAlgo::GetOutputInferShape(input_node, insert_index); + bool need_padding = is_insert_input ? 
trans::IsNeedPadding(dst_format, input_node_out_shape.size()) + : trans::IsNeedPadding(input_format, input_node_out_shape.size()); - auto input_node_out_shape = AnfAlgo::GetOutputInferShape(input_node, 0); - bool need_padding = false; - if (is_insert_input) { - need_padding = (trans::IsNeedPadding(dst_format, input_node_out_shape.size())); - } else { - need_padding = (trans::IsNeedPadding(input_format, input_node_out_shape.size())); - } if (!need_padding) { // don't need padding insert transdata only trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); @@ -89,6 +75,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); trans_node = trans_data; + trans_data->set_abstract(input_node->abstract()); } else { // if need padding & is output need insert a transdata // node -> transdata[padding shape] -> reshape[ori_shape] @@ -303,7 +290,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); TypeId origin_type(kTypeUnknown); auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); - auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); + auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0); auto real_input_node = kernel_with_index.first; if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { // weight diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h index e2b55ae75e..cd676f507f 100644 --- 
a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h @@ -28,7 +28,8 @@ namespace opt { class RectifyDoMaskKernelInfo : public PatternProcessPass { public: explicit RectifyDoMaskKernelInfo(bool multigraph = true) - : PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared()) {} + : PatternProcessPass("rectify_do_mask_kernel_info", multigraph), + kernel_selecter(std::make_shared()) {} ~RectifyDoMaskKernelInfo() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc index 057cf8deed..a25ebb8cfd 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc @@ -87,6 +87,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n new_transdata_node = NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node); + new_transdata_node->set_abstract(node->abstract()); new_replace_node = new_transdata_node; } FuncGraphManagerPtr manager = func_graph->manager(); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc index e2c6143927..24010e1858 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc @@ -19,6 +19,8 @@ #include "backend/session/anf_runtime_algorithm.h" #include "utils/utils.h" #include 
"base/core_ops.h" +#include "frontend/operator/ops.h" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace opt { @@ -32,21 +34,21 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); - auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_op_1); + auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(out_reshape); // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly - if (IsUsedByOthers(func_graph, reshape_op_1)) { + if (IsUsedByOthers(func_graph, out_reshape)) { return nullptr; } - auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_op_2); - if (IsUsedByOthers(func_graph, reshape_op_2)) { + auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(in_reshape); + if (IsUsedByOthers(func_graph, in_reshape)) { return nullptr; } - auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0); - auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0); - if (input_shape == output_shape) { - auto input_node = reshape_op_2->input(1); + auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0); + auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0); + if (kernel::IsSameShape(input_shape, output_shape)) { + auto input_node = AnfAlgo::GetInputNode(in_reshape, 0); return input_node; } return nullptr; diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc index fe1b0a3023..115f09fb60 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc +++ 
b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc @@ -71,7 +71,8 @@ bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) && - AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0); + AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0) && + kernel::IsSameShape(AnfAlgo::GetInputDeviceShape(node2, 0), AnfAlgo::GetOutputDeviceShape(node1, 0)); } const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode, diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index aafbf01365..5e9863516c 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -106,12 +106,12 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; // if node is a value node, no need sync addr from device to host + if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return value_node->value(); + } if (!AnfAlgo::OutputAddrExist(node, output_index)) { - if (node->isa()) { - auto value_node = node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - return value_node->value(); - } if (node->isa()) { for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) { if (input_idx >= input_tensors.size()) { @@ -252,6 +252,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr &graph, kernel_build_info_builder->SetOutputsFormat(std::vector{device_address->format()}); kernel_build_info_builder->SetOutputsDeviceType(std::vector{device_address->type_id()}); 
kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()}); + AnfAlgo::SetOutputAddr(device_address, 0, param.get()); } AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); // construct abstract of parameter diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index ac37bb5f1a..14c065aa5b 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -481,13 +481,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co if (op_info != nullptr) { is_ref = op_info->is_ref(); } - MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); - if (MsContext::GetInstance()->execution_mode() == kPynativeMode && - AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { - continue; - } - if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown && - AnfAlgo::OutputAddrExist(real_input_node, 0)) { + if (AnfAlgo::OutputAddrExist(real_input_node, 0)) { continue; } if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {