From: @liubuyu Reviewed-by: @kisnwang,@jjfeing Signed-off-by: @jjfeingpull/14583/MERGE
| @@ -144,6 +144,5 @@ MS_REG_GPU_KERNEL_THREE(RandomCategorical, | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt64), | |||
| RandomCategoricalGpuKernel, double, int64_t, int64_t) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -68,7 +68,7 @@ size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } | |||
| std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const { | |||
| if (input_reshape_type_.empty()) { | |||
| return {}; | |||
| return ""; | |||
| } | |||
| if (input_index >= input_reshape_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " | |||
| @@ -79,7 +79,7 @@ std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const { | |||
| std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { | |||
| if (output_reshape_type_.empty()) { | |||
| return {}; | |||
| return ""; | |||
| } | |||
| if (output_index >= output_reshape_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " | |||
| @@ -218,6 +218,9 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> | |||
| if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { | |||
| def_format = kOpFormat_NCDHW; | |||
| } | |||
| if (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) { | |||
| format = kOpFormat_NCDHW; | |||
| } | |||
| if (ori_shape.empty()) { | |||
| ori_shape.emplace_back(1); | |||
| } | |||
| @@ -446,6 +449,9 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod | |||
| std::vector<int64_t> ori_shape; | |||
| AnfAlgo::GetRealDynamicShape(AnfAlgo::GetOutputInferShape(anf_node, *output_idx), NOT_NULL(&ori_shape)); | |||
| // std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); | |||
| if (def_format == kOpFormat_NCDHW && k3DFormatSet.find(format) == k3DFormatSet.end()) { | |||
| format = kOpFormat_NCDHW; | |||
| } | |||
| if (ori_shape.empty()) { | |||
| ori_shape.emplace_back(1); | |||
| } | |||
| @@ -626,7 +632,7 @@ void TbeKernelJsonCreator::ParseAttrDefaultValue(const std::string &type, const | |||
| } else if (type == kVTypeStr) { | |||
| (*attr_obj)[kJValue] = value; | |||
| } else if (type == kVTypeBool) { | |||
| bool attr_value; | |||
| bool attr_value = false; | |||
| std::istringstream(value) >> std::boolalpha >> attr_value; | |||
| (*attr_obj)[kJValue] = attr_value; | |||
| } else if (type == kVTypeFloat) { | |||
| @@ -384,7 +384,7 @@ void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io | |||
| std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { | |||
| const std::map<std::string, std::string> kDynamicFormatMap = { | |||
| {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; | |||
| {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}, {"NCDHW", "DefaultFormat"}}; | |||
| if (op_select_json_item.empty()) { | |||
| MS_LOG(EXCEPTION) << "Op select ret item is null."; | |||
| } | |||
| @@ -67,7 +67,7 @@ | |||
| #include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" | |||
| #include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_trans_op.h" | |||
| #include "backend/optimizer/ascend/format_type/add_reformat_op.h" | |||
| #include "backend/optimizer/ascend/format_type/trans_op_format_refine.h" | |||
| #include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h" | |||
| @@ -259,6 +259,8 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<TransOpFormatRefine>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>()); | |||
| optimizer->AddPassManager(mixed_precision_pm); | |||
| @@ -387,7 +389,6 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||
| ConfigManager::GetInstance().iter_num() > 1) { | |||
| other2_pm->AddPass(std::make_shared<GetnextMemcpyElimination>()); | |||
| } | |||
| other2_pm->AddPass(std::make_shared<AddReFormatOp>()); | |||
| other2_pm->AddPass(std::make_shared<CheckConsistency>()); | |||
| optimizer2->AddPassManager(other2_pm); | |||
| (void)optimizer2->Optimize(kernel_graph); | |||
| @@ -64,33 +64,6 @@ void SetTransNodeAttr(const CNodePtr &trans_node) { | |||
| } | |||
| } | |||
| std::string InitDefaultFormat(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) { | |||
| auto primitive_ptr = GetCNodePrimitive(node); | |||
| MS_EXCEPTION_IF_NULL(primitive_ptr); | |||
| auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat); | |||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||
| int64_t data_format; | |||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | |||
| if (result && data_format == Format::NCDHW) { | |||
| return kOpFormat_NCDHW; | |||
| } | |||
| } else if (AnfAlgo::IsRealKernel(node)) { | |||
| auto formats = AnfAlgo::GetAllOutputFormats(node); | |||
| if (std::any_of(formats.begin(), formats.end(), | |||
| [](const std::string &format) { return k3DFormatSet.find(format) != k3DFormatSet.end(); })) { | |||
| return kOpFormat_NCDHW; | |||
| } | |||
| } else { | |||
| auto format = AnfAlgo::GetOutputFormat(node, 0); | |||
| if (k3DFormatSet.find(format) != k3DFormatSet.end()) { | |||
| return kOpFormat_NCDHW; | |||
| } | |||
| } | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| void ReFreshInferShape(const AnfNodePtr &trans_node, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(trans_node); | |||
| auto real_input_node = AnfAlgo::VisitKernelWithReturnType(node, 0).first; | |||
| @@ -183,7 +156,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| CNodePtr trans_data = nullptr; | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| // Init | |||
| std::string default_format = InitDefaultFormat(node); | |||
| std::string default_format = kOpFormat_DEFAULT; | |||
| AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node; | |||
| std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); | |||
| std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; | |||
| @@ -1,136 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/format_type/add_reformat_op.h" | |||
| #include <memory> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/utils.h" | |||
| #include "base/core_ops.h" | |||
| #include "runtime/device/kernel_info.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| using KernelWithIndex = std::pair<AnfNodePtr, size_t>; | |||
| namespace { | |||
| AnfNodePtr InsertReFormatOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &in_node, | |||
| size_t idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(in_node); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> reformat_inputs; | |||
| auto node_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(node); | |||
| MS_EXCEPTION_IF_NULL(node_kernel_build_info); | |||
| auto reformat_prim = std::make_shared<Primitive>(prim::kPrimReformat->name()); | |||
| reformat_inputs.push_back(NewValueNode(reformat_prim)); | |||
| reformat_inputs.push_back(in_node); | |||
| auto reformat = func_graph->NewCNode(reformat_inputs); | |||
| auto reformat_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| reformat_builder->SetInputsFormat({AnfAlgo::GetPrevNodeOutputFormat(node, idx)}); | |||
| reformat_builder->SetOutputsFormat({AnfAlgo::GetInputFormat(node, idx)}); | |||
| reformat_builder->SetInputsDeviceType({AnfAlgo::GetPrevNodeOutputDeviceDataType(node, idx)}); | |||
| reformat_builder->SetOutputsDeviceType({node_kernel_build_info->GetInputDeviceType(idx)}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(reformat_builder->Build(), reformat.get()); | |||
| reformat->set_abstract(in_node->abstract()); | |||
| AnfAlgo::SetNodeAttr("nop_op", MakeValue(true), reformat); | |||
| return reformat; | |||
| } | |||
| bool NeedInsert(const CNodePtr &cnode, const size_t input_index) { | |||
| KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, input_index); | |||
| auto real_input_node = kernel_with_index.first; | |||
| auto idx = kernel_with_index.second; | |||
| auto input_format = AnfAlgo::GetInputFormat(cnode, input_index); | |||
| auto prev_format = AnfAlgo::GetOutputFormat(real_input_node, idx); | |||
| bool flag_format = (input_format != prev_format); | |||
| if (!flag_format) { | |||
| return false; | |||
| } | |||
| bool flag_shape = true; | |||
| auto input_origin_shape = AnfAlgo::GetOutputInferShape(real_input_node, idx); | |||
| if (prev_format == kOpFormat_DEFAULT || input_format == kOpFormat_DEFAULT) { | |||
| string checking_format = (prev_format == kOpFormat_DEFAULT) ? input_format : prev_format; | |||
| // when input shape size is 1D, default format and NC1HWC0 are compatible | |||
| if (input_origin_shape.size() == 1 && checking_format == kOpFormat_NC1HWC0) { | |||
| flag_shape = false; | |||
| } | |||
| if (kDefaultCompatibleFormat.find(checking_format) != kDefaultCompatibleFormat.end()) { | |||
| flag_shape = false; | |||
| } | |||
| } | |||
| if (input_origin_shape.size() == 0) { | |||
| flag_shape = false; | |||
| } | |||
| return flag_format && flag_shape; | |||
| } | |||
| AnfNodePtr NeedInSertReformatOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto in_nums = AnfAlgo::GetInputTensorNum(cnode); | |||
| bool need_insert = false; | |||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; | |||
| for (size_t i = 0; i < in_nums; i++) { | |||
| auto input_node = AnfAlgo::GetInputNode(cnode, i); | |||
| if (NeedInsert(cnode, i)) { | |||
| need_insert = true; | |||
| auto re_format = InsertReFormatOp(func_graph, cnode, input_node, i); | |||
| new_inputs.push_back(re_format); | |||
| continue; | |||
| } | |||
| new_inputs.push_back(input_node); | |||
| } | |||
| if (need_insert) { | |||
| auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); | |||
| CNodePtr new_node = nullptr; | |||
| if (kernel_graph == nullptr) { | |||
| new_node = std::make_shared<CNode>(*cnode); | |||
| } else { | |||
| new_node = kernel_graph->NewCNode(cnode); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| new_node->set_inputs(new_inputs); | |||
| AnfAlgo::CopyNodeAttrs(cnode, new_node); | |||
| return new_node; | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace | |||
| bool AddReFormatOp::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||
| bool changed = false; | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| for (auto &node : node_list) { | |||
| auto new_node = NeedInSertReformatOp(func_graph, node); | |||
| if (new_node != nullptr) { | |||
| manager->Replace(node, new_node); | |||
| changed = true; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/format_type/trans_op_format_refine.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef TransOpFormatRefine::DefinePattern() const { | |||
| std::shared_ptr<Var> V = std::make_shared<CondVar>(UnVisited); | |||
| std::shared_ptr<Var> Vs = std::make_shared<SeqVar>(); | |||
| return VectorRef({V, Vs}); | |||
| } | |||
| const AnfNodePtr TransOpFormatRefine::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| auto op_name = AnfAlgo::GetCNodeName(node); | |||
| if (op_name == kTransDataOpName) { | |||
| auto in_format = AnfAlgo::GetInputFormat(node, 0); | |||
| auto out_format = AnfAlgo::GetOutputFormat(node, 0); | |||
| auto builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node)); | |||
| if (in_format == kOpFormat_DEFAULT && k3DFormatSet.find(out_format) != k3DFormatSet.end()) { | |||
| builder->SetInputsFormat({kOpFormat_NCDHW}); | |||
| builder->SetOutputsFormat({out_format}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(kOpFormat_NCDHW), node); | |||
| } | |||
| if (out_format == kOpFormat_DEFAULT && k3DFormatSet.find(in_format) != k3DFormatSet.end()) { | |||
| builder->SetInputsFormat({in_format}); | |||
| builder->SetOutputsFormat({kOpFormat_NCDHW}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(kOpFormat_NCDHW), node); | |||
| } | |||
| } | |||
| return node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -14,27 +14,21 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H | |||
| #include <vector> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_TRANS_OP_FORMAT_REFINE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_TRANS_OP_FORMAT_REFINE_H_ | |||
| #include <string> | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class AddReFormatOp : public Pass { | |||
| class TransOpFormatRefine : public PatternProcessPass { | |||
| public: | |||
| explicit AddReFormatOp(size_t groups = 1) : Pass("add_reformat_op"), groups_(groups) {} | |||
| ~AddReFormatOp() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| size_t groups_ = 1; | |||
| explicit TransOpFormatRefine(bool multigraph = true) : PatternProcessPass("trans_op_format_refine", multigraph) {} | |||
| ~TransOpFormatRefine() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_TRANS_OP_FORMAT_REFINE_H_ | |||
| @@ -713,7 +713,7 @@ std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, siz | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| if (build_info->IsInputDefaultPadding()) { | |||
| return {}; | |||
| return ""; | |||
| } | |||
| return build_info->GetInputReshapeType(input_idx); | |||
| } | |||
| @@ -733,7 +733,7 @@ std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, si | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| if (build_info->IsOutputDefaultPadding()) { | |||
| return {}; | |||
| return ""; | |||
| } | |||
| return build_info->GetOutputReshapeType(output_idx); | |||
| } | |||
| @@ -429,31 +429,6 @@ void KernelGraph::CheckLoop() { | |||
| } | |||
| } | |||
| void ReSetParameterValueNodeFormatAndType(const AnfNodePtr &node, const std::string &format) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (AnfAlgo::OutputAddrExist(node, 0)) { | |||
| return; | |||
| } | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_builder); | |||
| kernel_build_info_builder->SetOutputsFormat({format}); | |||
| kernel_build_info_builder->SetOutputsDeviceType({AnfAlgo::GetOutputInferDataType(node, 0)}); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get()); | |||
| } | |||
| void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &format) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t i = 0; i < input_num; i++) { | |||
| auto in_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), i); | |||
| MS_EXCEPTION_IF_NULL(in_node); | |||
| if ((in_node->isa<Parameter>() || in_node->isa<ValueNode>()) && | |||
| AnfAlgo::GetOutputInferShape(in_node, 0).size() == k5dDims) { | |||
| ReSetParameterValueNodeFormatAndType(in_node, format); | |||
| } | |||
| } | |||
| } | |||
| CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| auto cnode = FuncGraph::NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -463,17 +438,6 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); | |||
| } | |||
| SetKernelInfoForNode(cnode); | |||
| if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) { | |||
| auto primitive_ptr = GetCNodePrimitive(cnode); | |||
| MS_EXCEPTION_IF_NULL(primitive_ptr); | |||
| auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat); | |||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | |||
| int64_t data_format; | |||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | |||
| if (result && data_format == Format::NCDHW) { | |||
| ResetInFormat(cnode, kOpFormat_NCDHW); | |||
| } | |||
| } | |||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | |||
| return cnode; | |||
| } | |||
| @@ -281,7 +281,6 @@ class KernelGraph : public FuncGraph { | |||
| // remove value node form graph | |||
| bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); | |||
| void SetKernelInfoForNode(const AnfNodePtr &node) const; | |||
| void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; | |||
| AnfNodePtr MakeValueNode(const AnfNodePtr &node); | |||
| void EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue, | |||
| std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true); | |||
| @@ -57,7 +57,6 @@ void AiCoreDynamicKernel::Execute() { | |||
| args_size, l2ctrl, stream_, kernel_info.c_str())) { | |||
| MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunchWithHandle error."; | |||
| } | |||
| } else { | |||
| if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) { | |||
| MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error."; | |||
| @@ -61,7 +61,7 @@ class AiCoreDynamicKernel : public DynamicKernel { | |||
| void ParseCompileJson(); | |||
| private: | |||
| const void *stub_func_; | |||
| const void *stub_func_{nullptr}; | |||
| void *handle_{nullptr}; | |||
| uint32_t block_dim_; | |||
| void *tiling_data_ptr_; // device ptr | |||
| @@ -271,6 +271,7 @@ constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad"; | |||
| constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2"; | |||
| constexpr auto kMatMulOpName = "MatMul"; | |||
| constexpr auto kMatMulV2OpName = "MatMulV2"; | |||
| constexpr auto kBatchMatMulOpName = "BatchMatMul"; | |||
| constexpr auto kBroadcastToOpName = "BroadcastTo"; | |||
| constexpr auto kFusedAddReluV2Name = "FusedAddReluV2"; | |||
| constexpr auto kFusedAddReluGradV2Name = "FusedAddReluGradV2"; | |||
| @@ -486,7 +487,8 @@ const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_N | |||
| kOpFormat_NDC1HWC0, kOpFormat_NCDHW, | |||
| kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC, | |||
| kOpFormat_DHWCN}; | |||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; | |||
| const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN, | |||
| kOpFormat_NCDHW}; | |||
| const std::set<std::string> kOptOperatorSet = {kMomentumOpName, | |||
| kApplyMomentumOpName, | |||
| kApplyAdadeltaOpName, | |||
| @@ -44,7 +44,6 @@ class Squeeze : public PrimitiveC { | |||
| AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimSqueezePtr = std::shared_ptr<Squeeze>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -23,6 +23,7 @@ slice_op_info = TBERegOp("Slice") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("slice_d") \ | |||
| .partial_flag(True) \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .attr("begin", "required", "listInt", "all") \ | |||
| .attr("size", "required", "listInt", "all") \ | |||
| .input(0, "x", False, "required", "all") \ | |||