You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

mindapi_test.cc 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <cmath>
  17. #include <memory>
  18. #include <sstream>
  19. #include <unordered_map>
  20. #include "common/common_test.h"
  21. #include "mindapi/base/logging.h"
  22. #include "mindapi/ir/func_graph.h"
  23. #include "mindapi/ir/tensor.h"
  24. #include "mindapi/ir/utils.h"
  25. namespace mindspore::api {
  26. class TestMindApi : public UT::Common {
  27. public:
  28. TestMindApi() = default;
  29. };
  30. /// Feature: MindAPI
  31. /// Description: test basic 'is()' 'cast()'
  32. /// Expectation: is/cast works correctly.
  33. TEST_F(TestMindApi, test_base_isa_cast) {
  34. auto value_node = MakeShared<ValueNode>(MakeValue(0));
  35. auto base = MakeShared<Base>(value_node->impl());
  36. ASSERT_TRUE(base->isa<Base>());
  37. ASSERT_TRUE(base->isa<AnfNode>());
  38. ASSERT_TRUE(base->isa<ValueNode>());
  39. ASSERT_FALSE(base->isa<AbstractBase>());
  40. auto anf_node = base->cast<AnfNodePtr>();
  41. ASSERT_TRUE(anf_node != nullptr);
  42. ASSERT_TRUE(anf_node->impl() == value_node->impl());
  43. ASSERT_TRUE(base->cast<AbstractBasePtr>() == nullptr);
  44. }
  45. /// Feature: MindAPI
  46. /// Description: test graph construction.
  47. /// Expectation: graph is constructed as expected.
  48. TEST_F(TestMindApi, test_graph_construction) {
  49. // fg(x) { return myprim(x, 1); }
  50. auto fg = FuncGraph::Create();
  51. auto x = fg->add_parameter();
  52. x->set_name("x");
  53. auto prim = MakeShared<Primitive>("myprim");
  54. auto prim_node = MakeShared<ValueNode>(prim);
  55. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  56. auto cnode = fg->NewCNode({prim_node, x, value_node});
  57. fg->set_output(cnode);
  58. // Now we check the graph.
  59. ASSERT_EQ(fg->parameters().size(), 1);
  60. ASSERT_TRUE(fg->parameters()[0]->isa<Parameter>());
  61. ASSERT_EQ(fg->parameters()[0]->cast<ParameterPtr>()->name(), "x");
  62. auto ret_node = fg->get_return();
  63. ASSERT_TRUE(ret_node != nullptr);
  64. auto output_node = fg->output();
  65. ASSERT_TRUE(output_node != nullptr);
  66. ASSERT_TRUE(output_node->isa<CNode>());
  67. auto output_cnode = output_node->cast<CNodePtr>();
  68. ASSERT_EQ(output_cnode->inputs().size(), 3);
  69. ASSERT_TRUE(output_cnode->input(0)->isa<ValueNode>());
  70. ASSERT_TRUE(output_cnode->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>());
  71. ASSERT_EQ(output_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()->name(), "myprim");
  72. ASSERT_TRUE(output_cnode->input(1)->isa<Parameter>());
  73. ASSERT_EQ(output_cnode->input(1)->cast<ParameterPtr>()->name(), "x");
  74. ASSERT_TRUE(output_cnode->input(2)->isa<ValueNode>());
  75. ASSERT_EQ(output_cnode->impl(), cnode->impl());
  76. }
  77. /// Feature: MindAPI
  78. /// Description: test value related functions.
  79. /// Expectation: value related functions work as expected.
  80. TEST_F(TestMindApi, test_values) {
  81. int64_t one = 1;
  82. auto s = MakeValue("hello");
  83. auto i = MakeValue(one);
  84. auto i2 = MakeValue(2);
  85. auto b = MakeValue(true);
  86. auto f = MakeValue(3.14f);
  87. auto seq = MakeValue(std::vector<int64_t>{3, 4, 5});
  88. auto seq_str = MakeValue(std::vector<std::string>({"this", "is", "mindspore", "api"}));
  89. ASSERT_TRUE(s->isa<StringImm>());
  90. ASSERT_TRUE(i->isa<Int64Imm>());
  91. ASSERT_TRUE(i2->isa<Int64Imm>());
  92. ASSERT_TRUE(b->isa<BoolImm>());
  93. ASSERT_TRUE(f->isa<FP32Imm>());
  94. ASSERT_TRUE(seq->isa<ValueSequence>());
  95. ASSERT_TRUE(seq_str->isa<ValueSequence>());
  96. ASSERT_EQ(GetValue<std::string>(s), "hello");
  97. ASSERT_EQ(GetValue<int64_t>(i), one);
  98. ASSERT_EQ(GetValue<int64_t>(i2), 2);
  99. ASSERT_TRUE(GetValue<bool>(b));
  100. ASSERT_TRUE(std::abs(GetValue<float>(f) - 3.14f) < 0.00001f);
  101. ASSERT_EQ(GetValue<std::string>(i), "");
  102. ASSERT_EQ(GetValue<int64_t>(s), 0);
  103. ASSERT_FALSE(GetValue<bool>(s));
  104. ASSERT_EQ(GetValue<float>(s), 0.0f);
  105. auto seq_ptr = seq->cast<ValueSequencePtr>();
  106. ASSERT_TRUE(seq_ptr != nullptr);
  107. ASSERT_EQ(seq_ptr->size(), 3);
  108. ASSERT_EQ(seq_ptr->value().size(), 3);
  109. ASSERT_TRUE(seq_ptr->value()[0]->isa<Int64Imm>());
  110. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[0]), 3);
  111. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[1]), 4);
  112. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[2]), 5);
  113. auto seq_values = GetValue<std::vector<int64_t>>(seq);
  114. ASSERT_EQ(seq_values.size(), 3);
  115. ASSERT_EQ(seq_values[0], 3);
  116. ASSERT_EQ(seq_values[1], 4);
  117. ASSERT_EQ(seq_values[2], 5);
  118. auto str_values = GetValue<std::vector<std::string>>(seq_str);
  119. ASSERT_EQ(str_values.size(), 4);
  120. ASSERT_EQ(str_values[0], "this");
  121. ASSERT_EQ(str_values[1], "is");
  122. ASSERT_EQ(str_values[2], "mindspore");
  123. ASSERT_EQ(str_values[3], "api");
  124. auto value_list = GetValue<ValuePtrList>(seq);
  125. ASSERT_EQ(value_list.size(), 3);
  126. ASSERT_EQ(utils::cast<int64_t>(value_list[0]), 3);
  127. ASSERT_EQ(utils::cast<int64_t>(value_list[1]), 4);
  128. ASSERT_EQ(utils::cast<int64_t>(value_list[2]), 5);
  129. std::vector<uint8_t> vec_uint8{5, 6, 7};
  130. auto uint8_seq = MakeValue<std::vector<uint8_t>>(vec_uint8);
  131. ASSERT_TRUE(uint8_seq->isa<ValueSequence>());
  132. auto uint8_values = GetValue<std::vector<uint8_t>>(uint8_seq);
  133. ASSERT_EQ(uint8_values.size(), 3);
  134. ASSERT_EQ(uint8_values[0], 5);
  135. ASSERT_EQ(uint8_values[1], 6);
  136. ASSERT_EQ(uint8_values[2], 7);
  137. }
  138. /// Feature: MindAPI
  139. /// Description: test graph manager functions.
  140. /// Expectation: graph manager functions work as expected.
  141. TEST_F(TestMindApi, test_func_graph_manager) {
  142. // fg(x, y) { return myprim(add(x, y), 1); }
  143. auto fg = FuncGraph::Create();
  144. auto x = fg->add_parameter();
  145. x->set_name("x");
  146. auto y = fg->add_parameter();
  147. y->set_name("y");
  148. auto add = MakeShared<Primitive>("add");
  149. auto add_node = MakeShared<ValueNode>(add);
  150. auto add_cnode = fg->NewCNode({add_node, x, y});
  151. auto prim = MakeShared<Primitive>("myprim");
  152. auto prim_node = MakeShared<ValueNode>(prim);
  153. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  154. auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
  155. fg->set_output(cnode);
  156. auto mgr = FuncGraphManager::Manage(fg);
  157. ASSERT_TRUE(mgr != nullptr);
  158. ASSERT_TRUE(fg->manager() != nullptr);
  159. ASSERT_EQ(fg->manager()->impl(), mgr->impl());
  160. ASSERT_EQ(fg->manager(), mgr);
  161. ASSERT_EQ(cnode->input(1)->impl(), add_cnode->impl());
  162. mgr->Replace(add_cnode, x);
  163. ASSERT_EQ(cnode->input(1)->impl(), x->impl());
  164. mgr->SetEdge(cnode, 1, y);
  165. ASSERT_EQ(cnode->input(1)->impl(), y->impl());
  166. mgr->AddEdge(cnode, x);
  167. ASSERT_EQ(cnode->size(), 4);
  168. ASSERT_EQ(cnode->input(3)->impl(), x->impl());
  169. auto users = mgr->GetUsers(value_node);
  170. ASSERT_EQ(users.size(), 1);
  171. ASSERT_EQ(users[0].first, cnode);
  172. ASSERT_EQ(users[0].second, 2);
  173. }
  174. /// Feature: MindAPI
  175. /// Description: test value node utils.
  176. /// Expectation: value node utils work as expected.
  177. TEST_F(TestMindApi, test_value_node_utils) {
  178. auto fg = FuncGraph::Create();
  179. auto fg_node = MakeShared<ValueNode>(fg);
  180. auto prim = MakeShared<Primitive>("myprim");
  181. auto prim_node = MakeShared<ValueNode>(prim);
  182. auto one = MakeShared<ValueNode>(MakeValue(1));
  183. auto cnode = fg->NewCNode({fg_node, prim_node, one});
  184. ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode) == nullptr);
  185. auto fg1 = GetValueNode<FuncGraphPtr>(cnode->input(0));
  186. ASSERT_TRUE(fg1 != nullptr);
  187. ASSERT_TRUE(fg1->isa<FuncGraph>());
  188. auto prim1 = GetValueNode<PrimitivePtr>(cnode->input(1));
  189. ASSERT_TRUE(prim1 != nullptr);
  190. ASSERT_TRUE(prim1->isa<Primitive>());
  191. auto imm = GetValueNode<Int64ImmPtr>(cnode->input(2));
  192. ASSERT_TRUE(imm != nullptr);
  193. ASSERT_TRUE(imm->isa<Int64Imm>());
  194. ASSERT_EQ(imm->cast<Int64ImmPtr>()->value(), 1);
  195. auto value = GetValueNode(cnode->input(2));
  196. ASSERT_TRUE(value != nullptr);
  197. ASSERT_EQ(GetValue<int64_t>(value), 1);
  198. ASSERT_TRUE(GetValueNode<PrimitivePtr>(cnode->input(0)) == nullptr);
  199. ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode->input(1)) == nullptr);
  200. ASSERT_TRUE(GetValueNode<StringImmPtr>(cnode->input(2)) == nullptr);
  201. // Test NewValueNode.
  202. auto int_node = NewValueNode(1);
  203. auto bool_node = NewValueNode(true);
  204. auto float_node = NewValueNode(1.23f);
  205. auto str_node = NewValueNode("hello");
  206. ASSERT_TRUE(int_node->value()->isa<Int64Imm>());
  207. ASSERT_EQ(int_node->value()->cast<Int64ImmPtr>()->value(), 1);
  208. ASSERT_TRUE(bool_node->value()->isa<BoolImm>());
  209. ASSERT_TRUE(bool_node->value()->cast<BoolImmPtr>()->value());
  210. ASSERT_TRUE(float_node->value()->isa<FP32Imm>());
  211. ASSERT_TRUE(std::abs(float_node->value()->cast<FP32ImmPtr>()->value() - 1.23f) < 0.0000001f);
  212. ASSERT_TRUE(str_node->value()->isa<StringImm>());
  213. ASSERT_EQ(str_node->value()->cast<StringImmPtr>()->value(), "hello");
  214. }
  215. /// Feature: MindAPI
  216. /// Description: test SharedPtr.
  217. /// Expectation: SharedPtr work as expected.
  218. TEST_F(TestMindApi, test_object_ptr) {
  219. auto fg = FuncGraph::Create();
  220. auto fg_node = MakeShared<ValueNode>(fg);
  221. auto prim = MakeShared<Primitive>("myprim");
  222. auto prim_node = MakeShared<ValueNode>(prim);
  223. auto one = MakeShared<ValueNode>(MakeValue(1));
  224. auto cnode = fg->NewCNode({fg_node, prim_node, one});
  225. ASSERT_TRUE(fg != nullptr);
  226. ASSERT_FALSE(!fg);
  227. ASSERT_TRUE(fg ? true : false);
  228. ASSERT_TRUE((*cnode).input(0) == fg_node);
  229. ASSERT_TRUE(cnode->input(0) == fg_node);
  230. ASSERT_TRUE(cnode.get()->input(0) == fg_node);
  231. ASSERT_EQ(cnode->input(0), fg_node);
  232. ASSERT_EQ(cnode->input(1), prim_node);
  233. ASSERT_EQ(cnode->input(2), one);
  234. ASSERT_TRUE(cnode->input(0) != fg);
  235. AnfNodePtr p = fg_node;
  236. ASSERT_TRUE(p == fg_node);
  237. ASSERT_TRUE(p->isa<ValueNode>());
  238. ASSERT_TRUE(p->cast<ValueNodePtr>() != nullptr);
  239. ASSERT_TRUE(p->cast<ValueNodePtr>() == fg_node);
  240. p = cnode;
  241. ASSERT_TRUE(p == cnode);
  242. ASSERT_TRUE(p->isa<CNode>());
  243. ASSERT_TRUE(p->cast<CNodePtr>() != nullptr);
  244. ASSERT_TRUE(p->cast<CNodePtr>() == cnode);
  245. ASSERT_TRUE(p.get() == cnode.get());
  246. ASSERT_TRUE(p != nullptr);
  247. ASSERT_FALSE(p == nullptr);
  248. ASSERT_TRUE(p > nullptr);
  249. ASSERT_FALSE(p < nullptr);
  250. ASSERT_TRUE(p >= nullptr);
  251. ASSERT_FALSE(p <= nullptr);
  252. ASSERT_TRUE(nullptr != p);
  253. ASSERT_FALSE(nullptr == p);
  254. ASSERT_TRUE(nullptr < p);
  255. ASSERT_FALSE(nullptr > p);
  256. ASSERT_TRUE(nullptr <= p);
  257. ASSERT_FALSE(nullptr >= p);
  258. AnfNodePtr q = fg_node;
  259. ASSERT_TRUE(p != q);
  260. if (p.get()->impl() > q.get()->impl()) {
  261. ASSERT_TRUE(p > q);
  262. ASSERT_TRUE(p >= q);
  263. ASSERT_TRUE(q < p);
  264. ASSERT_TRUE(q <= p);
  265. } else {
  266. ASSERT_TRUE(p < q);
  267. ASSERT_TRUE(p <= q);
  268. ASSERT_TRUE(q > p);
  269. ASSERT_TRUE(q >= p);
  270. }
  271. std::stringstream ss1;
  272. std::stringstream ss2;
  273. ss1 << p;
  274. ss2 << cnode.get()->impl().get();
  275. ASSERT_EQ(ss1.str(), ss2.str());
  276. std::unordered_map<AnfNodePtr, AnfNodePtr> mymap;
  277. mymap.emplace(p, q);
  278. mymap.emplace(q, p);
  279. ASSERT_TRUE(mymap.find(p) != mymap.end());
  280. ASSERT_TRUE(mymap.find(q) != mymap.end());
  281. ASSERT_TRUE(mymap[p] == q);
  282. ASSERT_TRUE(mymap[q] == p);
  283. }
  284. /// Feature: MindAPI
  285. /// Description: test Tensor API.
  286. /// Expectation: Tensor API work as expected.
  287. TEST_F(TestMindApi, test_tensor_api) {
  288. ShapeVector shape{1, 2, 3};
  289. auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);
  290. ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
  291. ASSERT_EQ(tensor->shape(), shape);
  292. ASSERT_EQ(tensor->DataSize(), 6);
  293. ASSERT_EQ(tensor->Size(), 24);
  294. ShapeVector shape2{2, 3};
  295. tensor->set_data_type(kNumberTypeInt32);
  296. tensor->set_shape(shape2);
  297. ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
  298. ASSERT_EQ(tensor->shape(), shape2);
  299. // TensorType.
  300. TypePtr tensor_type = MakeShared<TensorType>(Type::GetType(TypeId::kNumberTypeFloat32));
  301. ASSERT_TRUE(tensor_type->isa<TensorType>());
  302. ASSERT_EQ(tensor_type->cast<TensorTypePtr>()->element()->type_id(), kNumberTypeFloat32);
  303. }
  304. /// Feature: MindAPI
  305. /// Description: test utils API.
  306. /// Expectation: Tensor API work as expected.
  307. TEST_F(TestMindApi, test_api_utils) {
  308. // Test utils::isa, utils::cast.
  309. auto anf_node = NewValueNode("hello");
  310. ASSERT_TRUE(utils::isa<AnfNode>(anf_node));
  311. ASSERT_TRUE(utils::isa<AnfNodePtr>(anf_node));
  312. ASSERT_FALSE(utils::isa<AbstractBase>(anf_node));
  313. ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) != nullptr);
  314. ASSERT_TRUE(utils::cast<AbstractBasePtr>(anf_node) == nullptr);
  315. ASSERT_TRUE(utils::isa<std::string>(anf_node->value()));
  316. ASSERT_EQ(utils::cast<std::string>(anf_node->value()), "hello");
  317. auto int_value = MakeValue(123);
  318. ASSERT_TRUE(utils::isa<int64_t>(int_value));
  319. ASSERT_EQ(utils::cast<int64_t>(int_value), 123);
  320. anf_node = nullptr;
  321. ASSERT_FALSE(utils::isa<AnfNode>(anf_node));
  322. ASSERT_FALSE(utils::isa<AnfNodePtr>(anf_node));
  323. ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) == nullptr);
  324. // Test clone graph.
  325. auto fg = FuncGraph::Create();
  326. auto x = fg->add_parameter();
  327. x->set_name("x");
  328. auto y = fg->add_parameter();
  329. y->set_name("y");
  330. auto add = MakeShared<Primitive>("add");
  331. auto add_node = MakeShared<ValueNode>(add);
  332. auto add_cnode = fg->NewCNode({add_node, x, y});
  333. auto prim = MakeShared<Primitive>("myprim");
  334. auto prim_node = MakeShared<ValueNode>(prim);
  335. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  336. auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
  337. fg->set_output(cnode);
  338. auto cloned_fg = utils::CloneGraph(fg);
  339. ASSERT_TRUE(cloned_fg != nullptr);
  340. ASSERT_EQ(cloned_fg->parameters().size(), 2);
  341. auto new_output = cloned_fg->output();
  342. ASSERT_TRUE(new_output != nullptr);
  343. ASSERT_TRUE(new_output->isa<CNode>());
  344. ASSERT_EQ(new_output->cast<CNodePtr>()->size(), cnode->size());
  345. ASSERT_TRUE(new_output != cnode);
  346. ASSERT_TRUE(new_output->cast<CNodePtr>() != cnode);
  347. // Test get pad mode.
  348. auto pm_lower = MakeValue("pad");
  349. auto pm_upper = MakeValue("PAD");
  350. ASSERT_EQ(utils::GetPadMode(pm_lower), 0);
  351. ASSERT_EQ(utils::GetPadMode(pm_lower, false), 0);
  352. ASSERT_EQ(utils::GetPadMode(pm_upper, true), 0);
  353. }
  354. /// Feature: MindAPI
  355. /// Description: test logging API.
  356. /// Expectation: logging work as expected.
  357. TEST_F(TestMindApi, test_api_logging) {
  358. MS_LOG(DEBUG) << "hello debug";
  359. MS_LOG(INFO) << "hello info";
  360. MS_LOG(WARNING) << "hello warning";
  361. MS_LOG(ERROR) << "hello error";
  362. try {
  363. MS_LOG(EXCEPTION) << "hello exception";
  364. ASSERT_TRUE(false);
  365. } catch (...) {
  366. }
  367. ASSERT_TRUE(true);
  368. }
  369. } // namespace mindspore::api