Browse Source

!5508 Add AdamApplyOneWithDecayAssign fusion pass

Merge pull request !5508 from YuJianfeng/adam_assign
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
3048240f16
5 changed files with 418 additions and 100 deletions
  1. +147
    -7
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
  2. +48
    -1
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h
  3. +1
    -0
      mindspore/ccsrc/utils/utils.h
  4. +115
    -10
      tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc
  5. +107
    -82
      tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py

+ 147
- 7
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc View File

@@ -24,7 +24,8 @@


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const {
std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv,
const AnfNodePtr &final_node) const {
MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(equiv);
auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]); auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]); auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
@@ -37,7 +38,12 @@ std::vector<AnfNodePtr> AdamApplyOneWithDecayRule::GetFusionNodeInputs(const Equ
auto mul3_x = utils::cast<AnfNodePtr>((*equiv)[mul3_x_]); auto mul3_x = utils::cast<AnfNodePtr>((*equiv)[mul3_x_]);
auto mul4_x = utils::cast<AnfNodePtr>((*equiv)[mul4_x_]); auto mul4_x = utils::cast<AnfNodePtr>((*equiv)[mul4_x_]);
auto add2_y = utils::cast<AnfNodePtr>((*equiv)[add2_y_]); auto add2_y = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
auto prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayOpName);
PrimitivePtr prim = nullptr;
if (AnfAlgo::CheckPrimitiveType(final_node, prim::kPrimDepend)) {
prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayAssignOpName);
} else {
prim = std::make_shared<Primitive>(kAdamApplyOneWithDecayOpName);
}
return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y};
} }


@@ -141,18 +147,152 @@ const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const {
return sub0; return sub0;
} }


const BaseRef AdamApplyOneWithDecayAssignRuleCond1::DefinePattern() const {
auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
auto real_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
VectorRef square0({prim::kPrimSquare, input0_});
VectorRef add0({add0_var_, mul0, mul1});
VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, input4_, add3});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0});
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1});
return VectorRef({prim::kPrimDepend, depend1, assign2});
}

const BaseRef AdamApplyOneWithDecayAssignRuleCond2::DefinePattern() const {
auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
auto real_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul0({prim::kPrimMul, input2_, mul0_x_});
VectorRef mul1({prim::kPrimMul, input0_, mul1_x_});
VectorRef square0({prim::kPrimSquare, input0_});
VectorRef add0({add0_var_, mul0, mul1});
VectorRef mul2({prim::kPrimMul, input1_, mul2_x_});
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, input3_, mul4_x_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0});
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1});
return VectorRef({prim::kPrimDepend, depend1, assign2});
}

const BaseRef AdamApplyOneWithDecayAssignRuleCond3::DefinePattern() const {
auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
auto real_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
VectorRef square0({prim::kPrimSquare, input0_});
VectorRef add0({add0_var_, mul0, mul1});
VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
VectorRef mul3({prim::kPrimMul, square0, mul3_x_});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0});
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1});
return VectorRef({prim::kPrimDepend, depend1, assign2});
}

const BaseRef AdamApplyOneWithDecayAssignRuleCond4::DefinePattern() const {
auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
auto real_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
VectorRef square0({prim::kPrimSquare, input0_});
VectorRef add0({add0_var_, mul0, mul1});
VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0});
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1});
return VectorRef({prim::kPrimDepend, depend1, assign2});
}

const BaseRef AdamApplyOneWithDecayAssignRuleCond5::DefinePattern() const {
auto sqrt = std::make_shared<Primitive>(kSqrtOpName);
auto real_div = std::make_shared<Primitive>(kRealDivOpName);
VectorRef mul0({prim::kPrimMul, mul0_x_, input2_});
VectorRef mul1({prim::kPrimMul, mul1_x_, input0_});
VectorRef square0({prim::kPrimSquare, input0_});
VectorRef add0({add0_var_, mul0, mul1});
VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
VectorRef mul3({prim::kPrimMul, mul3_x_, square0});
VectorRef add1({add1_var_, mul2, mul3});
VectorRef sqrt0({sqrt, add1});
VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_});
VectorRef mul4({prim::kPrimMul, mul4_x_, input3_});
VectorRef real_div0({real_div, add0, add2});
VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0});
VectorRef mul5({prim::kPrimMul, add3, input4_});
VectorRef sub0({sub0_var_, input3_, mul5});
VectorRef assign0 = VectorRef({prim::kPrimAssign, input3_, sub0});
VectorRef depend0 = VectorRef({prim::kPrimDepend, sub0, assign0});
VectorRef assign1 = VectorRef({prim::kPrimAssign, input2_, add0});
VectorRef depend1 = VectorRef({prim::kPrimDepend, depend0, assign1});
VectorRef assign2 = VectorRef({prim::kPrimAssign, input1_, add1});
return VectorRef({prim::kPrimDepend, depend1, assign2});
}

const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const { const EquivPtr &equiv) const {
if (graph == nullptr || node == nullptr || equiv == nullptr) { if (graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr; return nullptr;
} }
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
auto sub0 = node;
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
auto iter_sub0 = (*equiv).find(sub0_var_);
if (iter_sub0 == (*equiv).end()) {
MS_LOG(EXCEPTION) << "The equiv map is expected to contains the sub0 var after matched.";
}
sub0 = utils::cast<AnfNodePtr>(iter_sub0->second);
}
MS_EXCEPTION_IF_NULL(sub0);
if (!CheckSupportDataType(sub0, kFloatDataTypeSet)) {
return nullptr; return nullptr;
} }
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv);
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv, node);
auto fusion_node = graph->NewCNode(inputs); auto fusion_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fusion_node); MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(node->scope());
fusion_node->set_scope(sub0->scope());


auto iter_add0 = (*equiv).find(add0_var_); auto iter_add0 = (*equiv).find(add0_var_);
if (iter_add0 == (*equiv).end()) { if (iter_add0 == (*equiv).end()) {
@@ -167,9 +307,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); auto add1 = utils::cast<AnfNodePtr>(iter_add1->second);
MS_EXCEPTION_IF_NULL(add1); MS_EXCEPTION_IF_NULL(add1);
auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0),
AnfAlgo::GetOutputInferDataType(node, 0)};
AnfAlgo::GetOutputInferDataType(sub0, 0)};
auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0),
AnfAlgo::GetOutputInferShape(node, 0)};
AnfAlgo::GetOutputInferShape(sub0, 0)};
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());


std::vector<AnfNodePtr> fusion_node_outputs; std::vector<AnfNodePtr> fusion_node_outputs;


+ 48
- 1
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h View File

@@ -40,13 +40,14 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
add2_y_ = std::make_shared<Var>(); add2_y_ = std::make_shared<Var>();
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name()));
} }
~AdamApplyOneWithDecayRule() override = default; ~AdamApplyOneWithDecayRule() override = default;
const BaseRef DefinePattern() const override = 0; const BaseRef DefinePattern() const override = 0;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;


protected: protected:
std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv) const;
std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv, const AnfNodePtr &final_node) const;
VarPtr input0_; VarPtr input0_;
VarPtr input1_; VarPtr input1_;
VarPtr input2_; VarPtr input2_;
@@ -60,6 +61,7 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass {
VarPtr add2_y_; VarPtr add2_y_;
VarPtr add0_var_; VarPtr add0_var_;
VarPtr add1_var_; VarPtr add1_var_;
VarPtr sub0_var_;
}; };


class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule {
@@ -106,6 +108,51 @@ class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule {
~AdamApplyOneWithDecayRuleCond5() override = default; ~AdamApplyOneWithDecayRuleCond5() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
}; };

class AdamApplyOneWithDecayAssignRuleCond1 : public AdamApplyOneWithDecayRule {
public:
explicit AdamApplyOneWithDecayAssignRuleCond1(bool multigraph = true)
: AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond1", multigraph) {}

~AdamApplyOneWithDecayAssignRuleCond1() override = default;
const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayAssignRuleCond2 : public AdamApplyOneWithDecayRule {
public:
explicit AdamApplyOneWithDecayAssignRuleCond2(bool multigraph = true)
: AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond2", multigraph) {}

~AdamApplyOneWithDecayAssignRuleCond2() override = default;
const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayAssignRuleCond3 : public AdamApplyOneWithDecayRule {
public:
explicit AdamApplyOneWithDecayAssignRuleCond3(bool multigraph = true)
: AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond3", multigraph) {}

~AdamApplyOneWithDecayAssignRuleCond3() override = default;
const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayAssignRuleCond4 : public AdamApplyOneWithDecayRule {
public:
explicit AdamApplyOneWithDecayAssignRuleCond4(bool multigraph = true)
: AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond4", multigraph) {}

~AdamApplyOneWithDecayAssignRuleCond4() override = default;
const BaseRef DefinePattern() const override;
};

class AdamApplyOneWithDecayAssignRuleCond5 : public AdamApplyOneWithDecayRule {
public:
explicit AdamApplyOneWithDecayAssignRuleCond5(bool multigraph = true)
: AdamApplyOneWithDecayRule("adam_apply_one_with_decay_assign_rule_cond5", multigraph) {}

~AdamApplyOneWithDecayAssignRuleCond5() override = default;
const BaseRef DefinePattern() const override;
};
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_

+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -122,6 +122,7 @@ constexpr auto kLayerNormBetaGammaBackpropOpName = "LayerNormBetaGammaBackprop";
constexpr auto kLambNextMVOpName = "LambNextMV"; constexpr auto kLambNextMVOpName = "LambNextMV";
constexpr auto kConfusionTransposeDOpName = "ConfusionTransposeD"; constexpr auto kConfusionTransposeDOpName = "ConfusionTransposeD";
constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay"; constexpr auto kAdamApplyOneWithDecayOpName = "AdamApplyOneWithDecay";
constexpr auto kAdamApplyOneWithDecayAssignOpName = "AdamApplyOneWithDecayAssign";
constexpr auto kBatchNormGradOpName = "BatchNormGrad"; constexpr auto kBatchNormGradOpName = "BatchNormGrad";
constexpr auto kBNInferOpName = "BNInfer"; constexpr auto kBNInferOpName = "BNInfer";
constexpr auto kAdamApplyOneOpName = "AdamApplyOne"; constexpr auto kAdamApplyOneOpName = "AdamApplyOne";


+ 115
- 10
tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc View File

@@ -31,7 +31,7 @@ class TestHWOptimizeAdamApplyOneWithDecayRule : public BackendCommon {
}; };


TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) { TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "before");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond1");


std::vector<int> shp{2, 32, 224, 224}; std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@@ -47,12 +47,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg); FuncGraphPtr new_graph = optimizer->Optimize(fg);


FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
} }


TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) { TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "before");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond2");


std::vector<int> shp{2, 32, 224, 224}; std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@@ -68,12 +68,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg); FuncGraphPtr new_graph = optimizer->Optimize(fg);


FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
} }


TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) { TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "before");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond3");


std::vector<int> shp{2, 32, 224, 224}; std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@@ -89,12 +89,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg); FuncGraphPtr new_graph = optimizer->Optimize(fg);


FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
} }


TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) { TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "before");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond4");


std::vector<int> shp{2, 32, 224, 224}; std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@@ -110,12 +110,12 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg); FuncGraphPtr new_graph = optimizer->Optimize(fg);


FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
} }


TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) { TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "before");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before_cond5");


std::vector<int> shp{2, 32, 224, 224}; std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
@@ -131,7 +131,112 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r
optimizer->AddPassManager(pm); optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg); FuncGraphPtr new_graph = optimizer->Optimize(fg);


FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond1) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond1");

std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 11; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond1>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond2) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond2");

std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 11; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond2>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond3) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond3");

std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 11; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond3>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond4) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond4");

std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 11; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_assign_rule_cond5) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "before_cond5");

std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 11; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamApplyOneWithDecayAssignRuleCond5>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_assign_rule", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
} }
} // namespace opt } // namespace opt


+ 107
- 82
tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py View File

@@ -15,6 +15,7 @@


from mindspore.ops import Primitive from mindspore.ops import Primitive
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F


mul = P.Mul() mul = P.Mul()
add = P.TensorAdd() add = P.TensorAdd()
@@ -22,9 +23,11 @@ square = P.Square()
sqrt = P.Sqrt() sqrt = P.Sqrt()
real_div = P.RealDiv() real_div = P.RealDiv()
sub = P.Sub() sub = P.Sub()
Assign = P.Assign()
make_tuple = Primitive('make_tuple') make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')
adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay') adam_apply_one_with_decay = Primitive('AdamApplyOneWithDecay')
adam_apply_one_with_decay_assign = Primitive('AdamApplyOneWithDecayAssign')




class FnDict: class FnDict:
@@ -39,11 +42,10 @@ class FnDict:




def test_adam_apply_one_with_decay_rule(tag): def test_adam_apply_one_with_decay_rule(tag):
""" test_adam_apply_one_with_decay_rule """
fns = FnDict() fns = FnDict()


@fns @fns
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2) mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0) mul1 = mul(mul1_x, input0)
square0 = square(input0) square0 = square(input0)
@@ -52,50 +54,70 @@ def test_adam_apply_one_with_decay_rule(tag):
mul3 = mul(mul3_x, square0) mul3 = mul(mul3_x, square0)
add1 = add(mul2, mul3) add1 = add(mul2, mul3)
sqrt0 = sqrt(add1) sqrt0 = sqrt(add1)
add2 = add(sqrt0, add2_y)
add2 = add(add2_y, sqrt0)
mul4 = mul(mul4_x, input3) mul4 = mul(mul4_x, input3)
real_div0 = real_div(add0, add2) real_div0 = real_div(add0, add2)
add3 = add(real_div0, mul4)
add3 = add(mul4, real_div0)
mul5 = mul(input4, add3) mul5 = mul(input4, add3)
sub0 = sub(input3, mul5) sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0) return make_tuple(add1, add0, sub0)


@fns @fns
def no_match(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(input2, mul0_x)
mul1 = mul(input0, mul1_x)
square0 = square(input0)
add0 = add(mul0, mul1)
mul2 = mul(input1, mul2_x)
mul3 = mul(mul3_x, square0)
add1 = add(mul2, mul3)
sqrt0 = sqrt(add1)
add2 = add(sqrt0, add2_y)
mul4 = mul(input3, mul4_x)
real_div0 = real_div(add0, add2)
add3 = add(mul4, real_div0)
mul5 = mul(add3, input4)
sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0)

@fns
def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2) mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0) mul1 = mul(mul1_x, input0)
square0 = square(input0) square0 = square(input0)
# diff mul from original add
add0 = mul(mul0, mul1)
add0 = add(mul0, mul1)
mul2 = mul(mul2_x, input1) mul2 = mul(mul2_x, input1)
mul3 = mul(mul3_x, square0)
mul3 = mul(square0, mul3_x)
add1 = add(mul2, mul3) add1 = add(mul2, mul3)
sqrt0 = sqrt(add1) sqrt0 = sqrt(add1)
add2 = add(sqrt0, add2_y) add2 = add(sqrt0, add2_y)
mul4 = mul(mul4_x, input3) mul4 = mul(mul4_x, input3)
real_div0 = real_div(add0, add2) real_div0 = real_div(add0, add2)
add3 = add(real_div0, mul4)
mul5 = mul(input4, add3)
add3 = add(mul4, real_div0)
mul5 = mul(add3, input4)
sub0 = sub(input3, mul5) sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0) return make_tuple(add1, add0, sub0)


@fns @fns
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
add2_y)
item0 = tuple_getitem(res, 0)
item1 = tuple_getitem(res, 1)
item2 = tuple_getitem(res, 2)
return make_tuple(make_tuple(item0, item1, item2))

return fns[tag]


def test_adam_apply_one_with_decay_rule_cond1(tag):
fns = FnDict()
def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0)
square0 = square(input0)
add0 = add(mul0, mul1)
mul2 = mul(mul2_x, input1)
mul3 = mul(mul3_x, square0)
add1 = add(mul2, mul3)
sqrt0 = sqrt(add1)
add2 = add(add2_y, sqrt0)
mul4 = mul(mul4_x, input3)
real_div0 = real_div(add0, add2)
add3 = add(mul4, real_div0)
mul5 = mul(add3, input4)
sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0)


@fns @fns
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
def before_cond5(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2) mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0) mul1 = mul(mul1_x, input0)
square0 = square(input0) square0 = square(input0)
@@ -104,11 +126,11 @@ def test_adam_apply_one_with_decay_rule_cond1(tag):
mul3 = mul(mul3_x, square0) mul3 = mul(mul3_x, square0)
add1 = add(mul2, mul3) add1 = add(mul2, mul3)
sqrt0 = sqrt(add1) sqrt0 = sqrt(add1)
add2 = add(add2_y, sqrt0)
add2 = add(sqrt0, add2_y)
mul4 = mul(mul4_x, input3) mul4 = mul(mul4_x, input3)
real_div0 = real_div(add0, add2) real_div0 = real_div(add0, add2)
add3 = add(mul4, real_div0) add3 = add(mul4, real_div0)
mul5 = mul(input4, add3)
mul5 = mul(add3, input4)
sub0 = sub(input3, mul5) sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0) return make_tuple(add1, add0, sub0)


@@ -124,11 +146,35 @@ def test_adam_apply_one_with_decay_rule_cond1(tag):
return fns[tag] return fns[tag]




def test_adam_apply_one_with_decay_rule_cond2(tag):
def test_adam_apply_one_with_decay_assign_rule(tag):
fns = FnDict() fns = FnDict()


@fns @fns
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0)
square0 = square(input0)
add0 = add(mul0, mul1)
mul2 = mul(mul2_x, input1)
mul3 = mul(mul3_x, square0)
add1 = add(mul2, mul3)
sqrt0 = sqrt(add1)
add2 = add(add2_y, sqrt0)
mul4 = mul(mul4_x, input3)
real_div0 = real_div(add0, add2)
add3 = add(mul4, real_div0)
mul5 = mul(input4, add3)
sub0 = sub(input3, mul5)
assign0 = Assign(input3, sub0)
depend0 = F.depend(sub0, assign0)
assign1 = Assign(input2, add0)
depend1 = F.depend(depend0, assign1)
assign2 = Assign(input1, add1)
depend2 = F.depend(depend1, assign2)
return make_tuple(add1, add0, depend2)

@fns
def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(input2, mul0_x) mul0 = mul(input2, mul0_x)
mul1 = mul(input0, mul1_x) mul1 = mul(input0, mul1_x)
square0 = square(input0) square0 = square(input0)
@@ -143,25 +189,16 @@ def test_adam_apply_one_with_decay_rule_cond2(tag):
add3 = add(mul4, real_div0) add3 = add(mul4, real_div0)
mul5 = mul(add3, input4) mul5 = mul(add3, input4)
sub0 = sub(input3, mul5) sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0)

@fns
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
add2_y)
item0 = tuple_getitem(res, 0)
item1 = tuple_getitem(res, 1)
item2 = tuple_getitem(res, 2)
return make_tuple(make_tuple(item0, item1, item2))

return fns[tag]


def test_adam_apply_one_with_decay_rule_cond3(tag):
fns = FnDict()
assign0 = Assign(input3, sub0)
depend0 = F.depend(sub0, assign0)
assign1 = Assign(input2, add0)
depend1 = F.depend(depend0, assign1)
assign2 = Assign(input1, add1)
depend2 = F.depend(depend1, assign2)
return make_tuple(add1, add0, depend2)


@fns @fns
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2) mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0) mul1 = mul(mul1_x, input0)
square0 = square(input0) square0 = square(input0)
@@ -176,25 +213,16 @@ def test_adam_apply_one_with_decay_rule_cond3(tag):
add3 = add(mul4, real_div0) add3 = add(mul4, real_div0)
mul5 = mul(add3, input4) mul5 = mul(add3, input4)
sub0 = sub(input3, mul5) sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0)

@fns
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
add2_y)
item0 = tuple_getitem(res, 0)
item1 = tuple_getitem(res, 1)
item2 = tuple_getitem(res, 2)
return make_tuple(make_tuple(item0, item1, item2))

return fns[tag]


def test_adam_apply_one_with_decay_rule_cond4(tag):
fns = FnDict()
assign0 = Assign(input3, sub0)
depend0 = F.depend(sub0, assign0)
assign1 = Assign(input2, add0)
depend1 = F.depend(depend0, assign1)
assign2 = Assign(input1, add1)
depend2 = F.depend(depend1, assign2)
return make_tuple(add1, add0, depend2)


@fns @fns
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2) mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0) mul1 = mul(mul1_x, input0)
square0 = square(input0) square0 = square(input0)
@@ -209,25 +237,16 @@ def test_adam_apply_one_with_decay_rule_cond4(tag):
add3 = add(mul4, real_div0) add3 = add(mul4, real_div0)
mul5 = mul(add3, input4) mul5 = mul(add3, input4)
sub0 = sub(input3, mul5) sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0)

@fns
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
add2_y)
item0 = tuple_getitem(res, 0)
item1 = tuple_getitem(res, 1)
item2 = tuple_getitem(res, 2)
return make_tuple(make_tuple(item0, item1, item2))

return fns[tag]


def test_adam_apply_one_with_decay_rule_cond5(tag):
fns = FnDict()
assign0 = Assign(input3, sub0)
depend0 = F.depend(sub0, assign0)
assign1 = Assign(input2, add0)
depend1 = F.depend(depend0, assign1)
assign2 = Assign(input1, add1)
depend2 = F.depend(depend1, assign2)
return make_tuple(add1, add0, depend2)


@fns @fns
def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
def before_cond5(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
mul0 = mul(mul0_x, input2) mul0 = mul(mul0_x, input2)
mul1 = mul(mul1_x, input0) mul1 = mul(mul1_x, input0)
square0 = square(input0) square0 = square(input0)
@@ -242,12 +261,18 @@ def test_adam_apply_one_with_decay_rule_cond5(tag):
add3 = add(mul4, real_div0) add3 = add(mul4, real_div0)
mul5 = mul(add3, input4) mul5 = mul(add3, input4)
sub0 = sub(input3, mul5) sub0 = sub(input3, mul5)
return make_tuple(add1, add0, sub0)
assign0 = Assign(input3, sub0)
depend0 = F.depend(sub0, assign0)
assign1 = Assign(input2, add0)
depend1 = F.depend(depend0, assign1)
assign2 = Assign(input1, add1)
depend2 = F.depend(depend1, assign2)
return make_tuple(add1, add0, depend2)


@fns @fns
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y):
res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x,
add2_y)
res = adam_apply_one_with_decay_assign(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x,
mul4_x, add2_y)
item0 = tuple_getitem(res, 0) item0 = tuple_getitem(res, 0)
item1 = tuple_getitem(res, 1) item1 = tuple_getitem(res, 1)
item2 = tuple_getitem(res, 2) item2 = tuple_getitem(res, 2)


Loading…
Cancel
Save