You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mindapi_test.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <cmath>
  17. #include <memory>
  18. #include <sstream>
  19. #include <unordered_map>
  20. #include "common/common_test.h"
  21. #include "mindapi/base/logging.h"
  22. #include "mindapi/ir/func_graph.h"
  23. #include "mindapi/ir/tensor.h"
  24. #include "mindapi/ir/utils.h"
  25. namespace mindspore::api {
  26. class TestMindApi : public UT::Common {
  27. public:
  28. TestMindApi() = default;
  29. };
  30. /// Feature: MindAPI
  31. /// Description: test basic 'is()' 'cast()'
  32. /// Expectation: is/cast works correctly.
  33. TEST_F(TestMindApi, test_base_isa_cast) {
  34. auto value_node = MakeShared<ValueNode>(MakeValue(0));
  35. auto base = MakeShared<Base>(value_node->impl());
  36. ASSERT_TRUE(base->isa<Base>());
  37. ASSERT_TRUE(base->isa<AnfNode>());
  38. ASSERT_TRUE(base->isa<ValueNode>());
  39. ASSERT_FALSE(base->isa<AbstractBase>());
  40. auto anf_node = base->cast<AnfNodePtr>();
  41. ASSERT_TRUE(anf_node != nullptr);
  42. ASSERT_TRUE(anf_node->impl() == value_node->impl());
  43. ASSERT_TRUE(base->cast<AbstractBasePtr>() == nullptr);
  44. }
  45. /// Feature: MindAPI
  46. /// Description: test graph construction.
  47. /// Expectation: graph is constructed as expected.
  48. TEST_F(TestMindApi, test_graph_construction) {
  49. // fg(x) { return myprim(x, 1); }
  50. auto fg = FuncGraph::Create();
  51. auto x = fg->add_parameter();
  52. x->set_name("x");
  53. auto prim = MakeShared<Primitive>("myprim");
  54. auto prim_node = MakeShared<ValueNode>(prim);
  55. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  56. auto cnode = fg->NewCNode({prim_node, x, value_node});
  57. fg->set_output(cnode);
  58. // Now we check the graph.
  59. ASSERT_EQ(fg->parameters().size(), 1);
  60. ASSERT_TRUE(fg->parameters()[0]->isa<Parameter>());
  61. ASSERT_EQ(fg->parameters()[0]->cast<ParameterPtr>()->name(), "x");
  62. auto ret_node = fg->get_return();
  63. ASSERT_TRUE(ret_node != nullptr);
  64. auto output_node = fg->output();
  65. ASSERT_TRUE(output_node != nullptr);
  66. ASSERT_TRUE(output_node->isa<CNode>());
  67. auto output_cnode = output_node->cast<CNodePtr>();
  68. ASSERT_EQ(output_cnode->inputs().size(), 3);
  69. ASSERT_TRUE(output_cnode->input(0)->isa<ValueNode>());
  70. ASSERT_TRUE(output_cnode->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>());
  71. ASSERT_EQ(output_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()->name(), "myprim");
  72. ASSERT_TRUE(output_cnode->input(1)->isa<Parameter>());
  73. ASSERT_EQ(output_cnode->input(1)->cast<ParameterPtr>()->name(), "x");
  74. ASSERT_TRUE(output_cnode->input(2)->isa<ValueNode>());
  75. ASSERT_EQ(output_cnode->impl(), cnode->impl());
  76. }
  77. /// Feature: MindAPI
  78. /// Description: test value related functions.
  79. /// Expectation: value related functions work as expected.
  80. TEST_F(TestMindApi, test_values) {
  81. int64_t one = 1;
  82. auto s = MakeValue("hello");
  83. auto i = MakeValue(one);
  84. auto i2 = MakeValue(2);
  85. auto b = MakeValue(true);
  86. auto f = MakeValue(3.14f);
  87. auto seq = MakeValue(std::vector<int64_t>{3, 4, 5});
  88. auto seq_str = MakeValue(std::vector<std::string>({"this", "is", "mindspore", "api"}));
  89. ASSERT_TRUE(s->isa<StringImm>());
  90. ASSERT_TRUE(i->isa<Int64Imm>());
  91. ASSERT_TRUE(i2->isa<Int64Imm>());
  92. ASSERT_TRUE(b->isa<BoolImm>());
  93. ASSERT_TRUE(f->isa<FP32Imm>());
  94. ASSERT_TRUE(seq->isa<ValueSequence>());
  95. ASSERT_TRUE(seq_str->isa<ValueSequence>());
  96. ASSERT_EQ(GetValue<std::string>(s), "hello");
  97. ASSERT_EQ(GetValue<int64_t>(i), one);
  98. ASSERT_EQ(GetValue<int64_t>(i2), 2);
  99. ASSERT_TRUE(GetValue<bool>(b));
  100. ASSERT_TRUE(std::abs(GetValue<float>(f) - 3.14f) < 0.00001f);
  101. ASSERT_EQ(GetValue<std::string>(i), "");
  102. ASSERT_EQ(GetValue<int64_t>(s), 0);
  103. ASSERT_FALSE(GetValue<bool>(s));
  104. ASSERT_EQ(GetValue<float>(s), 0.0f);
  105. auto seq_ptr = seq->cast<ValueSequencePtr>();
  106. ASSERT_TRUE(seq_ptr != nullptr);
  107. ASSERT_EQ(seq_ptr->size(), 3);
  108. ASSERT_EQ(seq_ptr->value().size(), 3);
  109. ASSERT_TRUE(seq_ptr->value()[0]->isa<Int64Imm>());
  110. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[0]), 3);
  111. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[1]), 4);
  112. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[2]), 5);
  113. auto seq_values = GetValue<std::vector<int64_t>>(seq);
  114. ASSERT_EQ(seq_values.size(), 3);
  115. ASSERT_EQ(seq_values[0], 3);
  116. ASSERT_EQ(seq_values[1], 4);
  117. ASSERT_EQ(seq_values[2], 5);
  118. auto str_values = GetValue<std::vector<std::string>>(seq_str);
  119. ASSERT_EQ(str_values.size(), 4);
  120. ASSERT_EQ(str_values[0], "this");
  121. ASSERT_EQ(str_values[1], "is");
  122. ASSERT_EQ(str_values[2], "mindspore");
  123. ASSERT_EQ(str_values[3], "api");
  124. }
  125. /// Feature: MindAPI
  126. /// Description: test graph manager functions.
  127. /// Expectation: graph manager functions work as expected.
  128. TEST_F(TestMindApi, test_func_graph_manager) {
  129. // fg(x, y) { return myprim(add(x, y), 1); }
  130. auto fg = FuncGraph::Create();
  131. auto x = fg->add_parameter();
  132. x->set_name("x");
  133. auto y = fg->add_parameter();
  134. y->set_name("y");
  135. auto add = MakeShared<Primitive>("add");
  136. auto add_node = MakeShared<ValueNode>(add);
  137. auto add_cnode = fg->NewCNode({add_node, x, y});
  138. auto prim = MakeShared<Primitive>("myprim");
  139. auto prim_node = MakeShared<ValueNode>(prim);
  140. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  141. auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
  142. fg->set_output(cnode);
  143. auto mgr = FuncGraphManager::Manage(fg);
  144. ASSERT_TRUE(mgr != nullptr);
  145. ASSERT_TRUE(fg->manager() != nullptr);
  146. ASSERT_EQ(fg->manager()->impl(), mgr->impl());
  147. ASSERT_EQ(fg->manager(), mgr);
  148. ASSERT_EQ(cnode->input(1)->impl(), add_cnode->impl());
  149. mgr->Replace(add_cnode, x);
  150. ASSERT_EQ(cnode->input(1)->impl(), x->impl());
  151. mgr->SetEdge(cnode, 1, y);
  152. ASSERT_EQ(cnode->input(1)->impl(), y->impl());
  153. mgr->AddEdge(cnode, x);
  154. ASSERT_EQ(cnode->size(), 4);
  155. ASSERT_EQ(cnode->input(3)->impl(), x->impl());
  156. auto users = mgr->GetUsers(value_node);
  157. ASSERT_EQ(users.size(), 1);
  158. ASSERT_EQ(users[0].first, cnode);
  159. ASSERT_EQ(users[0].second, 2);
  160. }
  161. /// Feature: MindAPI
  162. /// Description: test value node utils.
  163. /// Expectation: value node utils work as expected.
  164. TEST_F(TestMindApi, test_value_node_utils) {
  165. auto fg = FuncGraph::Create();
  166. auto fg_node = MakeShared<ValueNode>(fg);
  167. auto prim = MakeShared<Primitive>("myprim");
  168. auto prim_node = MakeShared<ValueNode>(prim);
  169. auto one = MakeShared<ValueNode>(MakeValue(1));
  170. auto cnode = fg->NewCNode({fg_node, prim_node, one});
  171. ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode) == nullptr);
  172. auto fg1 = GetValueNode<FuncGraphPtr>(cnode->input(0));
  173. ASSERT_TRUE(fg1 != nullptr);
  174. ASSERT_TRUE(fg1->isa<FuncGraph>());
  175. auto prim1 = GetValueNode<PrimitivePtr>(cnode->input(1));
  176. ASSERT_TRUE(prim1 != nullptr);
  177. ASSERT_TRUE(prim1->isa<Primitive>());
  178. auto imm = GetValueNode<Int64ImmPtr>(cnode->input(2));
  179. ASSERT_TRUE(imm != nullptr);
  180. ASSERT_TRUE(imm->isa<Int64Imm>());
  181. ASSERT_EQ(imm->cast<Int64ImmPtr>()->value(), 1);
  182. auto value = GetValueNode(cnode->input(2));
  183. ASSERT_TRUE(value != nullptr);
  184. ASSERT_EQ(GetValue<int64_t>(value), 1);
  185. ASSERT_TRUE(GetValueNode<PrimitivePtr>(cnode->input(0)) == nullptr);
  186. ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode->input(1)) == nullptr);
  187. ASSERT_TRUE(GetValueNode<StringImmPtr>(cnode->input(2)) == nullptr);
  188. // Test NewValueNode.
  189. auto int_node = NewValueNode(1);
  190. auto bool_node = NewValueNode(true);
  191. auto float_node = NewValueNode(1.23f);
  192. auto str_node = NewValueNode("hello");
  193. ASSERT_TRUE(int_node->value()->isa<Int64Imm>());
  194. ASSERT_EQ(int_node->value()->cast<Int64ImmPtr>()->value(), 1);
  195. ASSERT_TRUE(bool_node->value()->isa<BoolImm>());
  196. ASSERT_TRUE(bool_node->value()->cast<BoolImmPtr>()->value());
  197. ASSERT_TRUE(float_node->value()->isa<FP32Imm>());
  198. ASSERT_TRUE(std::abs(float_node->value()->cast<FP32ImmPtr>()->value() - 1.23f) < 0.0000001f);
  199. ASSERT_TRUE(str_node->value()->isa<StringImm>());
  200. ASSERT_EQ(str_node->value()->cast<StringImmPtr>()->value(), "hello");
  201. }
  202. /// Feature: MindAPI
  203. /// Description: test SharedPtr.
  204. /// Expectation: SharedPtr work as expected.
  205. TEST_F(TestMindApi, test_object_ptr) {
  206. auto fg = FuncGraph::Create();
  207. auto fg_node = MakeShared<ValueNode>(fg);
  208. auto prim = MakeShared<Primitive>("myprim");
  209. auto prim_node = MakeShared<ValueNode>(prim);
  210. auto one = MakeShared<ValueNode>(MakeValue(1));
  211. auto cnode = fg->NewCNode({fg_node, prim_node, one});
  212. ASSERT_TRUE(fg != nullptr);
  213. ASSERT_FALSE(!fg);
  214. ASSERT_TRUE(fg ? true : false);
  215. ASSERT_TRUE((*cnode).input(0) == fg_node);
  216. ASSERT_TRUE(cnode->input(0) == fg_node);
  217. ASSERT_TRUE(cnode.get()->input(0) == fg_node);
  218. ASSERT_EQ(cnode->input(0), fg_node);
  219. ASSERT_EQ(cnode->input(1), prim_node);
  220. ASSERT_EQ(cnode->input(2), one);
  221. ASSERT_TRUE(cnode->input(0) != fg);
  222. AnfNodePtr p = fg_node;
  223. ASSERT_TRUE(p == fg_node);
  224. ASSERT_TRUE(p->isa<ValueNode>());
  225. ASSERT_TRUE(p->cast<ValueNodePtr>() != nullptr);
  226. ASSERT_TRUE(p->cast<ValueNodePtr>() == fg_node);
  227. p = cnode;
  228. ASSERT_TRUE(p == cnode);
  229. ASSERT_TRUE(p->isa<CNode>());
  230. ASSERT_TRUE(p->cast<CNodePtr>() != nullptr);
  231. ASSERT_TRUE(p->cast<CNodePtr>() == cnode);
  232. ASSERT_TRUE(p.get() == cnode.get());
  233. ASSERT_TRUE(p != nullptr);
  234. ASSERT_FALSE(p == nullptr);
  235. ASSERT_TRUE(p > nullptr);
  236. ASSERT_FALSE(p < nullptr);
  237. ASSERT_TRUE(p >= nullptr);
  238. ASSERT_FALSE(p <= nullptr);
  239. ASSERT_TRUE(nullptr != p);
  240. ASSERT_FALSE(nullptr == p);
  241. ASSERT_TRUE(nullptr < p);
  242. ASSERT_FALSE(nullptr > p);
  243. ASSERT_TRUE(nullptr <= p);
  244. ASSERT_FALSE(nullptr >= p);
  245. AnfNodePtr q = fg_node;
  246. ASSERT_TRUE(p != q);
  247. if (p.get()->impl() > q.get()->impl()) {
  248. ASSERT_TRUE(p > q);
  249. ASSERT_TRUE(p >= q);
  250. ASSERT_TRUE(q < p);
  251. ASSERT_TRUE(q <= p);
  252. } else {
  253. ASSERT_TRUE(p < q);
  254. ASSERT_TRUE(p <= q);
  255. ASSERT_TRUE(q > p);
  256. ASSERT_TRUE(q >= p);
  257. }
  258. std::stringstream ss1;
  259. std::stringstream ss2;
  260. ss1 << p;
  261. ss2 << cnode.get()->impl().get();
  262. ASSERT_EQ(ss1.str(), ss2.str());
  263. std::unordered_map<AnfNodePtr, AnfNodePtr> mymap;
  264. mymap.emplace(p, q);
  265. mymap.emplace(q, p);
  266. ASSERT_TRUE(mymap.find(p) != mymap.end());
  267. ASSERT_TRUE(mymap.find(q) != mymap.end());
  268. ASSERT_TRUE(mymap[p] == q);
  269. ASSERT_TRUE(mymap[q] == p);
  270. }
  271. /// Feature: MindAPI
  272. /// Description: test Tensor API.
  273. /// Expectation: Tensor API work as expected.
  274. TEST_F(TestMindApi, test_tensor_api) {
  275. ShapeVector shape{1, 2, 3};
  276. auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);
  277. ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
  278. ASSERT_EQ(tensor->shape(), shape);
  279. ASSERT_EQ(tensor->DataSize(), 6);
  280. ASSERT_EQ(tensor->Size(), 24);
  281. ShapeVector shape2{2, 3};
  282. tensor->set_data_type(kNumberTypeInt32);
  283. tensor->set_shape(shape2);
  284. ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
  285. ASSERT_EQ(tensor->shape(), shape2);
  286. }
  287. /// Feature: MindAPI
  288. /// Description: test utils API.
  289. /// Expectation: Tensor API work as expected.
  290. TEST_F(TestMindApi, test_api_utils) {
  291. // Test utils::isa, utils::cast.
  292. auto anf_node = NewValueNode("hello");
  293. ASSERT_TRUE(utils::isa<AnfNode>(anf_node));
  294. ASSERT_FALSE(utils::isa<AbstractBase>(anf_node));
  295. ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) != nullptr);
  296. ASSERT_TRUE(utils::cast<AbstractBasePtr>(anf_node) == nullptr);
  297. anf_node = nullptr;
  298. ASSERT_FALSE(utils::isa<AnfNode>(anf_node));
  299. ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) == nullptr);
  300. // Test clone graph.
  301. auto fg = FuncGraph::Create();
  302. auto x = fg->add_parameter();
  303. x->set_name("x");
  304. auto y = fg->add_parameter();
  305. y->set_name("y");
  306. auto add = MakeShared<Primitive>("add");
  307. auto add_node = MakeShared<ValueNode>(add);
  308. auto add_cnode = fg->NewCNode({add_node, x, y});
  309. auto prim = MakeShared<Primitive>("myprim");
  310. auto prim_node = MakeShared<ValueNode>(prim);
  311. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  312. auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
  313. fg->set_output(cnode);
  314. auto cloned_fg = utils::CloneGraph(fg);
  315. ASSERT_TRUE(cloned_fg != nullptr);
  316. ASSERT_EQ(cloned_fg->parameters().size(), 2);
  317. auto new_output = cloned_fg->output();
  318. ASSERT_TRUE(new_output != nullptr);
  319. ASSERT_TRUE(new_output->isa<CNode>());
  320. ASSERT_EQ(new_output->cast<CNodePtr>()->size(), cnode->size());
  321. ASSERT_TRUE(new_output != cnode);
  322. ASSERT_TRUE(new_output->cast<CNodePtr>() != cnode);
  323. // Test get pad mode.
  324. auto pm_lower = MakeValue("pad");
  325. auto pm_upper = MakeValue("PAD");
  326. ASSERT_EQ(utils::GetPadMode(pm_lower), 0);
  327. ASSERT_EQ(utils::GetPadMode(pm_lower, false), 0);
  328. ASSERT_EQ(utils::GetPadMode(pm_upper, true), 0);
  329. }
  330. /// Feature: MindAPI
  331. /// Description: test logging API.
  332. /// Expectation: logging work as expected.
  333. TEST_F(TestMindApi, test_api_logging) {
  334. MS_LOG(DEBUG) << "hello debug";
  335. MS_LOG(INFO) << "hello info";
  336. MS_LOG(WARNING) << "hello warning";
  337. MS_LOG(ERROR) << "hello error";
  338. try {
  339. MS_LOG(EXCEPTION) << "hello exception";
  340. ASSERT_TRUE(false);
  341. } catch (...) {
  342. }
  343. ASSERT_TRUE(true);
  344. }
  345. } // namespace mindspore::api