You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kpynative_test.cc 4.6 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <unordered_map>
  18. #include "frontend/optimizer/ad/kpynative.h"
  19. #include "common/common_test.h"
  20. #include "common/py_func_graph_fetcher.h"
  21. #include "ir/manager.h"
  22. #include "ir/value.h"
  23. #include "ir/func_graph_cloner.h"
  24. #include "utils/log_adapter.h"
  25. #include "ir/graph_utils.h"
  26. #include "pipeline/jit/resource.h"
  27. #include "pipeline/jit/parse/parse.h"
  28. #include "debug/anf_ir_utils.h"
  29. #include "frontend/operator/ops.h"
  30. namespace mindspore {
  31. namespace ad {
  32. class TestKPynative : public UT::Common {
  33. public:
  34. pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  35. protected:
  36. AbstractBasePtr BuildArg() {
  37. std::vector<int64_t> shp = {2, 2};
  38. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp);
  39. auto abstract = tensor->ToAbstract();
  40. return abstract;
  41. }
  42. FuncGraphPtr BuildPrimalFuncGraph(const std::string &testCase) {
  43. auto g = std::make_shared<FuncGraph>();
  44. auto x = g->add_parameter();
  45. auto y = g->add_parameter();
  46. x->set_abstract(BuildArg());
  47. y->set_abstract(BuildArg());
  48. auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
  49. c_node->set_abstract(BuildArg());
  50. g->set_output(c_node);
  51. return g;
  52. }
  53. // a = x * y
  54. // b = stop_gradient(a)
  55. // c = b * y
  56. // return c
  57. FuncGraphPtr BuildStopGradient(const std::string &testCase) {
  58. auto g = std::make_shared<FuncGraph>();
  59. auto x = g->add_parameter();
  60. x->debug_info()->set_name("x");
  61. auto y = g->add_parameter();
  62. y->debug_info()->set_name("y");
  63. x->set_abstract(BuildArg());
  64. y->set_abstract(BuildArg());
  65. auto a_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), x, y});
  66. a_node->set_abstract(BuildArg());
  67. auto b_node = g->NewCNode({NewValueNode(prim::kPrimStopGradient), a_node});
  68. b_node->set_abstract(BuildArg());
  69. auto c_node = g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), b_node, y});
  70. c_node->set_abstract(BuildArg());
  71. auto d_node =
  72. g->NewCNode({NewValueNode(prim::GetPythonOps("tensor_mul", "mindspore.ops.functional")), a_node, c_node});
  73. d_node->set_abstract(BuildArg());
  74. g->set_output(d_node);
  75. return g;
  76. }
  77. FuncGraphPtr BuildBpropFuncGraph(const FuncGraphPtr &primal_fg) {
  78. auto input_params = primal_fg->parameters();
  79. std::vector<ValuePtr> input_param_values;
  80. std::for_each(input_params.begin(), input_params.end(),
  81. [&](const AnfNodePtr &param) { input_param_values.emplace_back(param->abstract()->BuildValue()); });
  82. auto k_pynative_cell = GradPynativeCellBegin(input_params, input_param_values);
  83. auto node_list = TopoSort(primal_fg->output());
  84. for (auto node : node_list) {
  85. if (node->isa<CNode>()) {
  86. auto c_node = node->cast<CNodePtr>();
  87. auto out = c_node->abstract()->GetValueTrack();
  88. ValuePtrList args;
  89. for (size_t i = 1; i < c_node->inputs().size(); ++i) {
  90. args.push_back(c_node->input(i)->abstract()->GetValueTrack());
  91. }
  92. GradPynativeOp(k_pynative_cell, c_node, args, out);
  93. }
  94. }
  95. auto bprop_fg = GradPynativeCellEnd(k_pynative_cell, AnfNodePtrList{}, std::vector<size_t>{0}, true, false, false,
  96. true);
  97. return bprop_fg;
  98. }
  99. };
  100. TEST_F(TestKPynative, test_simple_add) {
  101. auto primal_fg = BuildPrimalFuncGraph("test_simple_add");
  102. resource->manager()->KeepRoots({primal_fg});
  103. auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  104. resource->manager()->KeepRoots({bprop_fg});
  105. }
  106. TEST_F(TestKPynative, test_stop_gradient) {
  107. auto primal_fg = BuildStopGradient("test_stop_gradient");
  108. resource->manager()->KeepRoots({primal_fg});
  109. auto bprop_fg = BuildBpropFuncGraph(primal_fg);
  110. resource->manager()->KeepRoots({bprop_fg});
  111. }
  112. } // namespace ad
  113. } // namespace mindspore