Browse Source

fix momentum-cast fusion

tags/v1.1.0
VectorSL 5 years ago
parent
commit
54a496edbc
1 changed file with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc

+ 6
- 2
mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc View File

@@ -39,10 +39,14 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3); auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
MS_EXCEPTION_IF_NULL(grad_cast); MS_EXCEPTION_IF_NULL(grad_cast);
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
// momentum only support fp32/fp16 by now, do nothing if not.
if (src != kNumberTypeFloat16 || src != kNumberTypeFloat32) {
return nullptr;
}
auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
MS_EXCEPTION_IF_NULL(grad); MS_EXCEPTION_IF_NULL(grad);
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad)); manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));


Loading…
Cancel
Save