Merge pull request !3775 from lianliguang/test-xiu-bug
@@ -51,33 +51,19 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = nullptr;
   CNodePtr trans_data = nullptr;
-  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
-  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
-  std::vector<Axis> padding_axis;
   MS_EXCEPTION_IF_NULL(node);
-  // if insert transdata for input we need to change the input
-  if (is_insert_input) {
-    if (!node->isa<CNode>()) {
-      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
-    }
-    auto cnode = node->cast<CNodePtr>();
-    dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
-    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
-    padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
-  } else {
-    input_node = node;
-    padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
-  }
-  auto input_node_out_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
-  bool need_padding = false;
-  if (is_insert_input) {
-    need_padding = (trans::IsNeedPadding(dst_format, input_node_out_shape.size()));
-  } else {
-    need_padding = (trans::IsNeedPadding(input_format, input_node_out_shape.size()));
-  }
+  // Init
+  AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
+  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index);
+  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT;
+  std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
+                                                   : AnfAlgo::GetOutputReshapeType(node, insert_index);
+  auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
+                                              : AnfAlgo::GetOutputInferShape(input_node, insert_index);
+  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
+                                      : trans::IsNeedPadding(input_format, input_node_out_shape.size());
   if (!need_padding) {
     // don't need padding, insert transdata only
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
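
A minimal, self-contained sketch of the initialization pattern the new code adopts: every variable is initialized exactly once via a conditional expression instead of being declared empty and assigned in branches. All names here (PickFormats, kDefaultFormat, kDeviceFormat) are illustrative, not MindSpore code:

#include <string>

// Illustrative only: one conditional expression per variable, so each can be
// const and there is no "declare now, assign later" window.
std::string PickFormats(bool is_insert_input) {
  const std::string kDefaultFormat = "DefaultFormat";  // stand-in for kOpFormat_DEFAULT
  const std::string kDeviceFormat = "FRACTAL_NZ";      // stand-in for a device format
  const std::string input_format = is_insert_input ? kDefaultFormat : kDeviceFormat;
  const std::string dst_format = is_insert_input ? kDeviceFormat : kDefaultFormat;
  return input_format + " -> " + dst_format;
}
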
@@ -89,6 +75,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
     trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
     trans_node = trans_data;
+    trans_data->set_abstract(input_node->abstract());
   } else {
     // if need padding & is output, insert a transdata:
     // node -> transdata[padding shape] -> reshape[ori_shape]
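
The branch above creates a TransData node and then copies the input's abstract (inferred shape and dtype) onto it, so later passes can still run shape inference through the new node. A hypothetical helper could factor out that pairing; this is only a sketch against the calls visible in this diff, not an API that exists in the tree, and the same set_abstract fix recurs in TransDataSplit::DoSplit further down:

// Hypothetical helper (sketch only, assuming the signatures shown in this diff).
CNodePtr NewTransDataWithAbstract(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                                  const KernelSelectPtr &kernel_select, bool need_padding) {
  auto trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding,
                                   prim::KPrimTransData->name());
  trans_data->set_abstract(input_node->abstract());  // keep shape/type metadata on the new node
  return trans_data;
}
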
@@ -303,7 +290,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
     const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     TypeId origin_type(kTypeUnknown);
     auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
-    auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
+    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
     auto real_input_node = kernel_with_index.first;
     if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
       // weight
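
Both VisitKernel and VisitKernelWithReturnType resolve the real producer behind an input edge and return a (node, output index) pair. The restatement below is editorial, and the claim that the WithReturnType variant also looks through wrapper nodes (for example Depend) is an assumption based on its name, not verified against the full sources:

auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);                // raw input edge
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
auto real_input_node = kernel_with_index.first;     // the producing node
auto real_output_index = kernel_with_index.second;  // which of its outputs feeds this input
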
@@ -28,7 +28,8 @@ namespace opt {
 class RectifyDoMaskKernelInfo : public PatternProcessPass {
  public:
   explicit RectifyDoMaskKernelInfo(bool multigraph = true)
-      : PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared<KernelSelect>()) {}
+      : PatternProcessPass("rectify_do_mask_kernel_info", multigraph),
+        kernel_selecter(std::make_shared<KernelSelect>()) {}
   ~RectifyDoMaskKernelInfo() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@@ -87,6 +87,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
     new_transdata_node =
       NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
     RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
+    new_transdata_node->set_abstract(node->abstract());
     new_replace_node = new_transdata_node;
   }
   FuncGraphManagerPtr manager = func_graph->manager();
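
Note the difference from the earlier set_abstract fix: here the new TransData node replaces node itself rather than being spliced in front of one of its inputs, so it takes node's abstract. A one-line annotated restatement (the comment is editorial):

new_transdata_node->set_abstract(node->abstract());  // the replacement must advertise the replaced node's shape/dtype
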
@@ -19,6 +19,8 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/utils.h"
 #include "base/core_ops.h"
+#include "frontend/operator/ops.h"
+#include "backend/kernel_compiler/common_utils.h"
 namespace mindspore {
 namespace opt {
@@ -32,21 +34,21 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
                                             const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(equiv);
-  auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
-  MS_EXCEPTION_IF_NULL(reshape_op_1);
+  auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(out_reshape);
   // If the reshape operator is used by more than one other operator, it cannot be deleted directly
-  if (IsUsedByOthers(func_graph, reshape_op_1)) {
+  if (IsUsedByOthers(func_graph, out_reshape)) {
     return nullptr;
   }
-  auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum);
-  MS_EXCEPTION_IF_NULL(reshape_op_2);
-  if (IsUsedByOthers(func_graph, reshape_op_2)) {
+  auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(in_reshape);
+  if (IsUsedByOthers(func_graph, in_reshape)) {
     return nullptr;
   }
-  auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0);
-  auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0);
-  if (input_shape == output_shape) {
-    auto input_node = reshape_op_2->input(1);
+  auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0);
+  auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0);
+  if (kernel::IsSameShape(input_shape, output_shape)) {
+    auto input_node = AnfAlgo::GetInputNode(in_reshape, 0);
     return input_node;
   }
   return nullptr;
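
The old condition read both device shapes off the middle edge of the chain (the inner reshape's output and the outer reshape's input name the same tensor), so it effectively compared an edge with itself; the fix compares the two ends of the chain. A self-contained toy of the corrected condition, not MindSpore code:

#include <cassert>
#include <vector>

using Shape = std::vector<size_t>;

// Stand-in for kernel::IsSameShape: a Reshape->Reshape chain
//   x -> in_reshape -> out_reshape
// may be dropped only when it is a no-op end to end.
bool CanEliminatePair(const Shape &in_reshape_input, const Shape &out_reshape_output) {
  return in_reshape_input == out_reshape_output;
}

int main() {
  assert(CanEliminatePair({2, 3, 4}, {2, 3, 4}));  // shape-preserving pair: eliminate
  assert(!CanEliminatePair({2, 3, 4}, {6, 4}));    // net reshape: keep
  return 0;
}
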
@@ -71,7 +71,8 @@ bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) {
 bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) {
   return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) &&
-         AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0);
+         AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0) &&
+         kernel::IsSameShape(AnfAlgo::GetInputDeviceShape(node2, 0), AnfAlgo::GetOutputDeviceShape(node1, 0));
 }
 const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode,
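
An annotated restatement of the strengthened condition; the comments are editorial, and the reading that node2 feeds node1 is inferred from how the formats are paired:

return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) &&  // formats invert one way...
       AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0) &&  // ...and the other
       // The device shapes must round-trip as well: padded formats (for
       // example FRACTAL_NZ) can change the device shape even when the
       // formats pair up, so format symmetry alone is not enough.
       kernel::IsSameShape(AnfAlgo::GetInputDeviceShape(node2, 0), AnfAlgo::GetOutputDeviceShape(node1, 0));
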
@@ -106,12 +106,12 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
   // if the node is a value node, there is no need to sync its addr from device to host
-  if (node->isa<ValueNode>()) {
-    auto value_node = node->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value_node);
-    return value_node->value();
-  }
   if (!AnfAlgo::OutputAddrExist(node, output_index)) {
+    if (node->isa<ValueNode>()) {
+      auto value_node = node->cast<ValueNodePtr>();
+      MS_EXCEPTION_IF_NULL(value_node);
+      return value_node->value();
+    }
     if (node->isa<Parameter>()) {
       for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
         if (input_idx >= input_tensors.size()) {
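
The behavioral effect of the move: a ValueNode now short-circuits to its host-side value only when it has no device address; a ValueNode that does own a device address falls through to the normal device-to-host sync path. An annotated sketch of the new control flow (editorial comments, body elided):

if (!AnfAlgo::OutputAddrExist(node, output_index)) {
  if (node->isa<ValueNode>()) {
    return node->cast<ValueNodePtr>()->value();  // no device copy exists yet, host value is authoritative
  }
  // Parameter and other fallbacks as in the hunk above ...
}
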
@@ -252,6 +252,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
     kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
     kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
     kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
+    AnfAlgo::SetOutputAddr(device_address, 0, param.get());
   }
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
   // construct abstract of parameter
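
A plausible reading of the added line (an assumption, not stated in the diff): binding the input tensor's existing device address to the freshly built parameter lets the single-op graph consume the input buffer in place instead of allocating and copying a new one:

AnfAlgo::SetOutputAddr(device_address, 0, param.get());  // reuse the tensor's device buffer for this parameter
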
@@ -481,13 +481,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
     if (op_info != nullptr) {
       is_ref = op_info->is_ref();
     }
-    MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
-    if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
-        AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
-      continue;
-    }
-    if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown &&
-        AnfAlgo::OutputAddrExist(real_input_node, 0)) {
+    if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
       continue;
     }
     if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
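
Both removed guards were variations of "this input already carries device info, skip it". The annotated restatement below assumes that an existing output address implies the device type was already selected, which is why a single test can cover graph and PyNative mode alike:

if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
  continue;  // device info already set for this input; do not overwrite it
}
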