|
|
|
@@ -98,7 +98,7 @@ class TestKPynative : public UT::Common { |
|
|
|
GradPynativeOp(k_pynative_cell, c_node, args, out); |
|
|
|
} |
|
|
|
} |
|
|
|
auto bprop_fg = GradPynativeCellBuildFormalBProp(k_pynative_cell, AnfNodePtrList{}, true, false); |
|
|
|
auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, true, false, false, true); |
|
|
|
return bprop_fg; |
|
|
|
} |
|
|
|
}; |
|
|
|
@@ -106,23 +106,23 @@ class TestKPynative : public UT::Common { |
|
|
|
TEST_F(TestKPynative, test_simple_add) { |
|
|
|
auto primal_fg = BuildPrimalFuncGraph("test_simple_add"); |
|
|
|
resource->manager()->KeepRoots({primal_fg}); |
|
|
|
ExportIR(primal_fg->ToString() + ".dat", "", primal_fg); |
|
|
|
ExportIR(primal_fg->ToString() + ".dat", primal_fg); |
|
|
|
|
|
|
|
auto bprop_fg = BuildBpropFuncGraph(primal_fg); |
|
|
|
resource->manager()->KeepRoots({bprop_fg}); |
|
|
|
|
|
|
|
ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg); |
|
|
|
ExportIR(bprop_fg->ToString() + ".dat", bprop_fg); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestKPynative, test_stop_gradient) { |
|
|
|
auto primal_fg = BuildStopGradient("test_stop_gradient"); |
|
|
|
resource->manager()->KeepRoots({primal_fg}); |
|
|
|
ExportIR(primal_fg->ToString() + ".dat", "", primal_fg); |
|
|
|
ExportIR(primal_fg->ToString() + ".dat", primal_fg); |
|
|
|
|
|
|
|
auto bprop_fg = BuildBpropFuncGraph(primal_fg); |
|
|
|
resource->manager()->KeepRoots({bprop_fg}); |
|
|
|
|
|
|
|
ExportIR(bprop_fg->ToString() + ".dat", "", bprop_fg); |
|
|
|
ExportIR(bprop_fg->ToString() + ".dat", bprop_fg); |
|
|
|
} |
|
|
|
} // namespace ad |
|
|
|
} // namespace mindspore |