|
|
|
@@ -26,6 +26,8 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
constexpr size_t kInputIndex = 1; |
|
|
|
|
|
|
|
bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) { |
|
|
|
if (utils::isa<AnfNodePtr>(n)) { |
|
|
|
AnfNodePtr in = utils::cast<AnfNodePtr>(n); |
|
|
|
@@ -48,10 +50,36 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
bool ApplyMomentumWeightDecayScaleFusion::IsCast(const BaseRef &n) { |
|
|
|
if (utils::isa<AnfNodePtr>(n)) { |
|
|
|
AnfNodePtr in = utils::cast<AnfNodePtr>(n); |
|
|
|
MS_EXCEPTION_IF_NULL(in); |
|
|
|
if (IsPrimitiveCNode(in, prim::kPrimCast) || |
|
|
|
(IsPrimitiveCNode(in, prim::kPrimDepend) && |
|
|
|
IsPrimitiveCNode(in->cast<CNodePtr>()->input(kInputIndex), prim::kPrimCast))) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr GetCastInput(const AnfNodePtr &node) { |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimCast)) { |
|
|
|
return node->cast<CNodePtr>()->input(kInputIndex); |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimDepend)) { |
|
|
|
auto cast_node = node->cast<CNodePtr>()->input(kInputIndex); |
|
|
|
if (IsPrimitiveCNode(cast_node, prim::kPrimCast)) { |
|
|
|
return cast_node->cast<CNodePtr>()->input(kInputIndex); |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const { |
|
|
|
VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_}); |
|
|
|
VectorRef weight = VectorRef( |
|
|
|
{prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})}); |
|
|
|
VectorRef weight = |
|
|
|
VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), cast_gradient_}); |
|
|
|
VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_}); |
|
|
|
VectorRef apply_momentum = |
|
|
|
VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_}); |
|
|
|
@@ -68,7 +96,7 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr |
|
|
|
auto variable = utils::cast<AnfNodePtr>((*equiv)[variable_]); |
|
|
|
auto accumulation = utils::cast<AnfNodePtr>((*equiv)[accumulation_]); |
|
|
|
auto learning_rate = utils::cast<AnfNodePtr>((*equiv)[learning_rate_]); |
|
|
|
auto gradient = utils::cast<AnfNodePtr>((*equiv)[gradient_]); |
|
|
|
auto cast_gradient = utils::cast<AnfNodePtr>((*equiv)[cast_gradient_]); |
|
|
|
auto momentum = utils::cast<AnfNodePtr>((*equiv)[momentum_]); |
|
|
|
auto monad_state = utils::cast<AnfNodePtr>((*equiv)[monad_state_]); |
|
|
|
|
|
|
|
@@ -77,12 +105,14 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr |
|
|
|
MS_EXCEPTION_IF_NULL(variable); |
|
|
|
MS_EXCEPTION_IF_NULL(accumulation); |
|
|
|
MS_EXCEPTION_IF_NULL(learning_rate); |
|
|
|
MS_EXCEPTION_IF_NULL(gradient); |
|
|
|
MS_EXCEPTION_IF_NULL(cast_gradient); |
|
|
|
MS_EXCEPTION_IF_NULL(momentum); |
|
|
|
MS_EXCEPTION_IF_NULL(monad_state); |
|
|
|
|
|
|
|
auto prim = std::make_shared<Primitive>(kFusedWeightScaleApplyMomentum); |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
auto gradient = GetCastInput(cast_gradient); |
|
|
|
MS_EXCEPTION_IF_NULL(gradient); |
|
|
|
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), weight_decay, scale, variable, accumulation, |
|
|
|
learning_rate, gradient, momentum, monad_state}; |
|
|
|
auto replace_node = graph->NewCNode(inputs); |
|
|
|
|