|
|
|
@@ -43,7 +43,7 @@ class TestKPynative : public UT::Common { |
|
|
|
return abstract; |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr BuildPrimalFuncGraph(const std::string& testCase) { |
|
|
|
FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) { |
|
|
|
auto g = std::make_shared<FuncGraph>(); |
|
|
|
auto x = g->add_parameter(); |
|
|
|
auto y = g->add_parameter(); |
|
|
|
@@ -73,14 +73,19 @@ class TestKPynative : public UT::Common { |
|
|
|
b_node->set_abstract(BuildArg()); |
|
|
|
auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y}); |
|
|
|
c_node->set_abstract(BuildArg()); |
|
|
|
auto d_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node}); |
|
|
|
auto d_node = |
|
|
|
g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node}); |
|
|
|
d_node->set_abstract(BuildArg()); |
|
|
|
g->set_output(d_node); |
|
|
|
return g; |
|
|
|
} |
|
|
|
|
|
|
|
FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) { |
|
|
|
auto k_pynative_cell = GradPynativeCellBegin(primal_fg->parameters()); |
|
|
|
auto input_params = primal_fg->parameters(); |
|
|
|
std::vector<ValuePtr> input_param_values; |
|
|
|
std::for_each(input_params.begin(), input_params.end(), |
|
|
|
[&](const AnfNodePtr ¶m) { input_param_values.emplace_back(param->abstract()->BuildValue()); }); |
|
|
|
auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values); |
|
|
|
auto node_list = TopoSort(primal_fg->output()); |
|
|
|
for (auto node : node_list) { |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
|