Browse Source

!9916 GPU fix momentum fusion

From: @VectorSL
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c98e6b19d6
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc

+ 6
- 2
mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc View File

@@ -39,10 +39,14 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
MS_EXCEPTION_IF_NULL(equiv);
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
MS_EXCEPTION_IF_NULL(grad_cast);
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
// momentum only support fp32/fp16 by now, do nothing if not.
if (src != kNumberTypeFloat16 || src != kNumberTypeFloat32) {
return nullptr;
}
auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
MS_EXCEPTION_IF_NULL(grad);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));


Loading…
Cancel
Save