Browse Source

fix_compile_ut_failed

tags/v1.3.0
lvliang chujinjin 5 years ago
parent
commit
c759effa94
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      tests/ut/cpp/optimizer/ad/kpynative_test.cc

+ 8
- 3
tests/ut/cpp/optimizer/ad/kpynative_test.cc View File

@@ -43,7 +43,7 @@ class TestKPynative : public UT::Common {
return abstract;
}

FuncGraphPtr BuildPrimalFuncGraph(const std::string& testCase) {
FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
auto g = std::make_shared<FuncGraph>();
auto x = g->add_parameter();
auto y = g->add_parameter();
@@ -73,14 +73,19 @@ class TestKPynative : public UT::Common {
b_node->set_abstract(BuildArg());
auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
c_node->set_abstract(BuildArg());
auto d_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
auto d_node =
g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
d_node->set_abstract(BuildArg());
g->set_output(d_node);
return g;
}

FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
auto k_pynative_cell = GradPynativeCellBegin(primal_fg->parameters());
auto input_params = primal_fg->parameters();
std::vector<ValuePtr> input_param_values;
std::for_each(input_params.begin(), input_params.end(),
[&](const AnfNodePtr &param) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
auto node_list = TopoSort(primal_fg->output());
for (auto node : node_list) {
if (node->isa<CNode>()) {


Loading…
Cancel
Save