From 54a496edbc8b80e2cef336079ea9b8ab6efccbe4 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Mon, 14 Dec 2020 17:40:17 +0800 Subject: [PATCH] fix momentum-cast fusion --- .../backend/optimizer/gpu/replace_momentum_cast_fusion.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc index 864bb026af..be3146ca00 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/replace_momentum_cast_fusion.cc @@ -39,10 +39,14 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c MS_EXCEPTION_IF_NULL(equiv); auto grad_cast = AnfAlgo::GetInputNode(utils::cast(node), 3); - auto grad = AnfAlgo::GetInputNode(utils::cast(grad_cast), 0); MS_EXCEPTION_IF_NULL(grad_cast); + auto src = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0); + // momentum only support fp32/fp16 by now, do nothing if not. + if (src != kNumberTypeFloat16 || src != kNumberTypeFloat32) { + return nullptr; + } + auto grad = AnfAlgo::GetInputNode(utils::cast(grad_cast), 0); MS_EXCEPTION_IF_NULL(grad); - auto manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); manager->Replace(utils::cast(grad_cast), utils::cast(grad));