You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_graph_test.cc 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/common_test.h"
  17. #include "operator/ops.h"
  18. #include "session/kernel_graph.h"
  19. #include "session/anf_runtime_algorithm.h"
  20. #include "mindspore/ccsrc/device/kernel_info.h"
  21. #include "utils/utils.h"
  22. namespace mindspore {
  23. namespace session {
  24. using device::KernelInfo;
  25. using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
  26. class KernelGraphTest : public UT::Common {
  27. public:
  28. KernelGraphTest() = default;
  29. void SetUp() override {}
  30. void TearDown() override {}
  31. };
  32. TEST_F(KernelGraphTest, NewValueNode) {
  33. auto kernel_graph = std::make_shared<KernelGraph>();
  34. auto add_value = NewValueNode(MakeValue(0));
  35. MS_EXCEPTION_IF_NULL(add_value);
  36. std::vector<int> shape = {1};
  37. auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
  38. add_value->set_abstract(x_abstract);
  39. add_value->set_kernel_info(std::make_shared<KernelInfo>());
  40. auto mutable_kernel_info = add_value->kernel_info();
  41. MS_EXCEPTION_IF_NULL(mutable_kernel_info);
  42. std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
  43. builder->SetOutputsFormat({kOpFormat_FRAC_Z});
  44. builder->SetOutputsDeviceType({kFloat32->type_id()});
  45. mutable_kernel_info->set_select_kernel_build_info(builder->Build());
  46. auto new_value = kernel_graph->NewValueNode(add_value);
  47. EXPECT_NE(new_value, nullptr);
  48. EXPECT_EQ(AnfAlgo::GetOutputInferShape(new_value, 0)[0], 1);
  49. EXPECT_EQ(AnfAlgo::GetOutputInferDataType(new_value, 0), kFloat32->type_id());
  50. EXPECT_EQ(AnfAlgo::GetOutputFormat(new_value, 0), kOpFormat_DEFAULT);
  51. EXPECT_EQ(AnfAlgo::GetOutputDeviceDataType(new_value, 0), kTypeUnknown);
  52. }
  53. TEST_F(KernelGraphTest, NewParameter) {
  54. auto anf_graph = std::make_shared<FuncGraph>();
  55. auto kernel_graph = std::make_shared<KernelGraph>();
  56. // test nullptr as input
  57. auto new_paramter = kernel_graph->NewParameter(nullptr);
  58. EXPECT_NE(new_paramter, nullptr);
  59. EXPECT_TRUE(new_paramter->isa<Parameter>());
  60. EXPECT_EQ(AnfAlgo::GetOutputFormat(new_paramter, 0), kOpFormat_DEFAULT);
  61. EXPECT_EQ(AnfAlgo::GetOutputDeviceDataType(new_paramter, 0), kMetaTypeNone);
  62. // test non-weight parameter node as input
  63. std::vector<int> shape = {2, 32, 224, 224};
  64. auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
  65. auto non_weight_parameter = anf_graph->add_parameter();
  66. MS_EXCEPTION_IF_NULL(non_weight_parameter);
  67. non_weight_parameter->set_abstract(x_abstract);
  68. auto new_non_weight_parameter = kernel_graph->NewParameter(non_weight_parameter);
  69. EXPECT_NE(new_non_weight_parameter, nullptr);
  70. new_non_weight_parameter->set_name("non_weight_parameter");
  71. EXPECT_EQ(AnfAlgo::GetOutputInferShape(new_non_weight_parameter, 0)[1], 32);
  72. EXPECT_EQ(AnfAlgo::GetOutputInferDataType(new_non_weight_parameter, 0), kFloat32->type_id());
  73. EXPECT_EQ(AnfAlgo::GetOutputFormat(new_non_weight_parameter, 0), kOpFormat_DEFAULT);
  74. EXPECT_EQ(AnfAlgo::GetOutputDeviceDataType(new_non_weight_parameter, 0), kFloat32->type_id());
  75. EXPECT_EQ(new_non_weight_parameter->name(), "non_weight_parameter");
  76. // test weight parameter node as input
  77. auto weight_parameter_node = anf_graph->add_parameter();
  78. MS_EXCEPTION_IF_NULL(weight_parameter_node);
  79. py::object obj;
  80. weight_parameter_node->set_default_param(obj);
  81. weight_parameter_node->set_abstract(x_abstract);
  82. auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node);
  83. EXPECT_NE(new_weight_parameter_node, nullptr);
  84. EXPECT_TRUE(new_weight_parameter_node->has_default());
  85. EXPECT_EQ(AnfAlgo::GetOutputInferShape(new_weight_parameter_node, 0)[2], 224);
  86. EXPECT_EQ(AnfAlgo::GetOutputInferDataType(new_weight_parameter_node, 0), kFloat32->type_id());
  87. EXPECT_EQ(AnfAlgo::GetOutputFormat(new_weight_parameter_node, 0), kOpFormat_DEFAULT);
  88. EXPECT_EQ(AnfAlgo::GetOutputDeviceDataType(new_weight_parameter_node, 0), kTypeUnknown);
  89. }
  90. TEST_F(KernelGraphTest, NewCNode) {
  91. auto kernel_graph = std::make_shared<KernelGraph>();
  92. auto add_value = NewValueNode(prim::kPrimTensorAdd);
  93. std::vector<AnfNodePtr> inputs = {add_value};
  94. auto new_cnode = kernel_graph->NewCNode(inputs);
  95. EXPECT_NE(new_cnode, nullptr);
  96. EXPECT_EQ(AnfAlgo::GetCNodeName(new_cnode), prim::kPrimTensorAdd->name());
  97. EXPECT_TRUE(AnfAlgo::GetOutputInferShape(new_cnode, 0).empty());
  98. EXPECT_EQ(AnfAlgo::GetOutputInferDataType(new_cnode, 0), kMetaTypeNone);
  99. }
  100. TEST_F(KernelGraphTest, MutableInputs) {
  101. auto kernel_graph = std::make_shared<KernelGraph>();
  102. auto x_parameter = kernel_graph->add_parameter();
  103. MS_EXCEPTION_IF_NULL(x_parameter);
  104. x_parameter->set_name("x_parameter");
  105. auto y_parameter = kernel_graph->add_parameter();
  106. MS_EXCEPTION_IF_NULL(y_parameter);
  107. y_parameter->set_name("y_parameter");
  108. std::vector<AnfNodePtr> inputs = {x_parameter, y_parameter};
  109. auto mutable_inputs = kernel_graph->MutableInputs();
  110. MS_EXCEPTION_IF_NULL(mutable_inputs);
  111. *mutable_inputs = inputs;
  112. auto first_input = kernel_graph->inputs()[0];
  113. MS_EXCEPTION_IF_NULL(first_input);
  114. auto first_parameter = first_input->cast<ParameterPtr>();
  115. MS_EXCEPTION_IF_NULL(first_parameter);
  116. EXPECT_EQ(first_parameter->name(), "x_parameter");
  117. auto second_input = kernel_graph->inputs()[1];
  118. MS_EXCEPTION_IF_NULL(second_input);
  119. auto second_parameter = second_input->cast<ParameterPtr>();
  120. MS_EXCEPTION_IF_NULL(second_parameter);
  121. EXPECT_EQ(second_parameter->name(), "y_parameter");
  122. }
  123. TEST_F(KernelGraphTest, SetExecOrderByDefault) {
  124. /*
  125. * define kernel graph:
  126. * x ----- y
  127. * add ----- z
  128. * mul
  129. * return
  130. */
  131. auto kernel_graph = std::make_shared<KernelGraph>();
  132. std::vector<int> shape = {2, 32, 224, 224};
  133. auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
  134. auto x_parameter = kernel_graph->NewParameter();
  135. MS_EXCEPTION_IF_NULL(x_parameter);
  136. x_parameter->set_name("x_parameter");
  137. x_parameter->set_abstract(abstract);
  138. auto y_parameter = kernel_graph->NewParameter();
  139. MS_EXCEPTION_IF_NULL(y_parameter);
  140. y_parameter->set_name("y_parameter");
  141. y_parameter->set_abstract(abstract);
  142. std::vector<AnfNodePtr> add_inputs = {NewValueNode(prim::kPrimTensorAdd), x_parameter, y_parameter};
  143. auto add = kernel_graph->NewCNode(add_inputs);
  144. MS_EXCEPTION_IF_NULL(add);
  145. add->set_abstract(abstract);
  146. auto z_parameter = kernel_graph->NewParameter();
  147. MS_EXCEPTION_IF_NULL(z_parameter);
  148. z_parameter->set_name("z_parameter");
  149. z_parameter->set_abstract(abstract);
  150. std::vector<AnfNodePtr> mul_inputs = {NewValueNode(prim::kPrimMul), add, z_parameter};
  151. auto mul = kernel_graph->NewCNode(mul_inputs);
  152. MS_EXCEPTION_IF_NULL(mul);
  153. mul->set_abstract(abstract);
  154. std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), mul};
  155. auto make_tuple = kernel_graph->NewCNode(make_tuple_inputs);
  156. kernel_graph->set_output(make_tuple);
  157. // test outputs() function
  158. auto outputs = kernel_graph->outputs();
  159. EXPECT_EQ(outputs.size(), 1);
  160. EXPECT_EQ(AnfAlgo::GetCNodeName(outputs[0]), prim::kPrimMul->name());
  161. // test SetExecOrderByDefault() function
  162. kernel_graph->SetExecOrderByDefault();
  163. auto execution_order = kernel_graph->execution_order();
  164. EXPECT_EQ(execution_order.size(), 2);
  165. EXPECT_EQ(AnfAlgo::GetCNodeName(execution_order[0]), prim::kPrimTensorAdd->name());
  166. EXPECT_EQ(AnfAlgo::GetCNodeName(execution_order[1]), prim::kPrimMul->name());
  167. // test set_execution_order() function
  168. kernel_graph->set_execution_order({add});
  169. execution_order = kernel_graph->execution_order();
  170. EXPECT_EQ(execution_order.size(), 1);
  171. EXPECT_EQ(AnfAlgo::GetCNodeName(execution_order[0]), prim::kPrimTensorAdd->name());
  172. }
  173. TEST_F(KernelGraphTest, SetGraphId) {
  174. auto kernel_graph = std::make_shared<KernelGraph>();
  175. kernel_graph->set_graph_id(1);
  176. EXPECT_EQ(kernel_graph->graph_id(), 1);
  177. }
  178. } // namespace session
  179. } // namespace mindspore