|
|
|
@@ -70,6 +70,35 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace { |
|
|
|
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { |
|
|
|
MS_EXCEPTION_IF_NULL(ir_fusion_pm); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayV1Rule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
auto optimizer = std::make_shared<GraphOptimizer>(); |
|
|
|
@@ -164,29 +193,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>()); |
|
|
|
if (context_ptr->ir_fusion_flag()) { |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayV1Rule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); |
|
|
|
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); |
|
|
|
AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); |
|
|
|
} |
|
|
|
|
|
|
|
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { |
|
|
|
|