| @@ -90,6 +90,7 @@ void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel | |||||
| mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); | mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); | mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | ||||
| @@ -126,6 +127,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||||
| mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); | mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); | mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "kernel/oplib/oplib.h" | #include "kernel/oplib/oplib.h" | ||||
| #include "session/anf_runtime_algorithm.h" | #include "session/anf_runtime_algorithm.h" | ||||
| #include "session/kernel_graph.h" | #include "session/kernel_graph.h" | ||||
| #include "pre_activate/common/helper.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -168,11 +169,18 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
// Builds the match pattern used by the DealRefTransAndCast pass.
// The pattern is a VectorRef of two variables: a CondVar constructed with the
// UnVisited predicate, followed by a SeqVar.
// NOTE(review): presumably this limits matching to nodes that do not yet carry
// the kAttrVisited attribute that Process() sets on each node it handles, so a
// node is processed at most once per sweep — confirm against the UnVisited helper.
| const BaseRef DealRefTransAndCast::DefinePattern() const { | |||||
| VarPtr V = std::make_shared<CondVar>(UnVisited); | |||||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||||
| return VectorRef({V, Xs}); | |||||
| } | |||||
| const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | ||||
| const EquivPtr &) const { | const EquivPtr &) const { | ||||
| if (node == nullptr || !node->isa<CNode>()) { | if (node == nullptr || !node->isa<CNode>()) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| if (!AnfAlgo::IsRealCNodeKernel(cnode)) { | if (!AnfAlgo::IsRealCNodeKernel(cnode)) { | ||||
| @@ -28,6 +28,7 @@ class DealRefTransAndCast : public PatternProcessPass { | |||||
| public: | public: | ||||
| explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} | explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} | ||||
| ~DealRefTransAndCast() override = default; | ~DealRefTransAndCast() override = default; | ||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -45,6 +45,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) { | |||||
| bool change = (new_node != nullptr); | bool change = (new_node != nullptr); | ||||
| if (new_node != nullptr && new_node != node) { | if (new_node != nullptr && new_node != node) { | ||||
| (void)manager->Replace(node, new_node); | (void)manager->Replace(node, new_node); | ||||
| (void)seen_node.erase(node); | |||||
| } else if (new_node == nullptr) { | } else if (new_node == nullptr) { | ||||
| new_node = node; | new_node = node; | ||||
| } | } | ||||
| @@ -46,11 +46,13 @@ from mindspore.ops.op_info_register import op_info_register | |||||
| "dtype": [ | "dtype": [ | ||||
| "bool", | "bool", | ||||
| "float","float","float","float","float","float","float","float","float","float", | "float","float","float","float","float","float","float","float","float","float", | ||||
| "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" | |||||
| "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16", | |||||
| "uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16" | |||||
| ], | ], | ||||
| "format": [ | "format": [ | ||||
| "DefaultFormat", | "DefaultFormat", | ||||
| "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", | "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", | ||||
| "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ", | |||||
| "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ" | "DefaultFormat","DefaultFormat","DefaultFormat","FracZ","FRACTAL_NZ","NC1HWC0","HWCN","HWCN","C1HWNCoC0","FracZ" | ||||
| ], | ], | ||||
| "name": "src", | "name": "src", | ||||
| @@ -65,11 +67,13 @@ from mindspore.ops.op_info_register import op_info_register | |||||
| "dtype": [ | "dtype": [ | ||||
| "bool", | "bool", | ||||
| "float","float","float","float","float","float","float","float","float","float", | "float","float","float","float","float","float","float","float","float","float", | ||||
| "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16" | |||||
| "float16","float16","float16","float16","float16","float16","float16","float16","float16","float16", | |||||
| "uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16","uint16" | |||||
| ], | ], | ||||
| "format": [ | "format": [ | ||||
| "NC1HWC0", | "NC1HWC0", | ||||
| "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", | "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", | ||||
| "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN", | |||||
| "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN" | "NC1HWC0","FRACTAL_NZ","FracZ","DefaultFormat","DefaultFormat","DefaultFormat","FracZ","C1HWNCoC0","HWCN","HWCN" | ||||
| ], | ], | ||||
| "name": "dst", | "name": "dst", | ||||