diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc index 864bb026af..be3146ca00 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc @@ -39,10 +39,14 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c MS_EXCEPTION_IF_NULL(equiv); auto grad_cast = AnfAlgo::GetInputNode(utils::cast(node), 3); - auto grad = AnfAlgo::GetInputNode(utils::cast(grad_cast), 0); MS_EXCEPTION_IF_NULL(grad_cast); + auto src = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0); + // momentum only support fp32/fp16 by now, do nothing if not. + if (src != kNumberTypeFloat16 || src != kNumberTypeFloat32) { + return nullptr; + } + auto grad = AnfAlgo::GetInputNode(utils::cast(grad_cast), 0); MS_EXCEPTION_IF_NULL(grad); - auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); manager->Replace(utils::cast(grad_cast), utils::cast(grad));