|
|
@@ -39,10 +39,14 @@ const AnfNodePtr ReplaceMomentumCastFusion::Process(const FuncGraphPtr &graph, c |
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
MS_EXCEPTION_IF_NULL(equiv);
|
|
|
|
|
|
|
|
|
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
|
|
|
auto grad_cast = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 3);
|
|
|
auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(grad_cast);
|
|
|
MS_EXCEPTION_IF_NULL(grad_cast);
|
|
|
|
|
|
auto src = AnfAlgo::GetPrevNodeOutputInferDataType(grad_cast, 0);
|
|
|
|
|
|
// momentum only support fp32/fp16 by now, do nothing if not.
|
|
|
|
|
|
if (src != kNumberTypeFloat16 || src != kNumberTypeFloat32) {
|
|
|
|
|
|
return nullptr;
|
|
|
|
|
|
}
|
|
|
|
|
|
auto grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(grad_cast), 0);
|
|
|
MS_EXCEPTION_IF_NULL(grad);
|
|
|
MS_EXCEPTION_IF_NULL(grad);
|
|
|
|
|
|
|
|
|
auto manager = graph->manager();
|
|
|
auto manager = graph->manager();
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));
|
|
|
manager->Replace(utils::cast<CNodePtr>(grad_cast), utils::cast<CNodePtr>(grad));
|
|
|
|