You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mindapi_test.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <cmath>
  17. #include <memory>
  18. #include <sstream>
  19. #include <unordered_map>
  20. #include "common/common_test.h"
  21. #include "mindapi/base/logging.h"
  22. #include "mindapi/ir/func_graph.h"
  23. #include "mindapi/ir/tensor.h"
  24. #include "mindapi/ir/utils.h"
  25. namespace mindspore::api {
  26. class TestMindApi : public UT::Common {
  27. public:
  28. TestMindApi() = default;
  29. };
  30. /// Feature: MindAPI
  31. /// Description: test basic 'is()' 'cast()'
  32. /// Expectation: is/cast works correctly.
  33. TEST_F(TestMindApi, test_base_isa_cast) {
  34. auto value_node = MakeShared<ValueNode>(MakeValue(0));
  35. auto base = MakeShared<Base>(value_node->impl());
  36. ASSERT_TRUE(base->isa<Base>());
  37. ASSERT_TRUE(base->isa<AnfNode>());
  38. ASSERT_TRUE(base->isa<ValueNode>());
  39. ASSERT_FALSE(base->isa<AbstractBase>());
  40. auto anf_node = base->cast<AnfNodePtr>();
  41. ASSERT_TRUE(anf_node != nullptr);
  42. ASSERT_TRUE(anf_node->impl() == value_node->impl());
  43. ASSERT_TRUE(base->cast<AbstractBasePtr>() == nullptr);
  44. }
  45. /// Feature: MindAPI
  46. /// Description: test graph construction.
  47. /// Expectation: graph is constructed as expected.
  48. TEST_F(TestMindApi, test_graph_construction) {
  49. // fg(x) { return myprim(x, 1); }
  50. auto fg = FuncGraph::Create();
  51. auto x = fg->add_parameter();
  52. x->set_name("x");
  53. auto prim = MakeShared<Primitive>("myprim");
  54. auto prim_node = MakeShared<ValueNode>(prim);
  55. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  56. auto cnode = fg->NewCNode({prim_node, x, value_node});
  57. fg->set_output(cnode);
  58. // Now we check the graph.
  59. ASSERT_EQ(fg->parameters().size(), 1);
  60. ASSERT_TRUE(fg->parameters()[0]->isa<Parameter>());
  61. ASSERT_EQ(fg->parameters()[0]->cast<ParameterPtr>()->name(), "x");
  62. auto ret_node = fg->get_return();
  63. ASSERT_TRUE(ret_node != nullptr);
  64. auto output_node = fg->output();
  65. ASSERT_TRUE(output_node != nullptr);
  66. ASSERT_TRUE(output_node->isa<CNode>());
  67. auto output_cnode = output_node->cast<CNodePtr>();
  68. ASSERT_EQ(output_cnode->inputs().size(), 3);
  69. ASSERT_TRUE(output_cnode->input(0)->isa<ValueNode>());
  70. ASSERT_TRUE(output_cnode->input(0)->cast<ValueNodePtr>()->value()->isa<Primitive>());
  71. ASSERT_EQ(output_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>()->name(), "myprim");
  72. ASSERT_TRUE(output_cnode->input(1)->isa<Parameter>());
  73. ASSERT_EQ(output_cnode->input(1)->cast<ParameterPtr>()->name(), "x");
  74. ASSERT_TRUE(output_cnode->input(2)->isa<ValueNode>());
  75. ASSERT_EQ(output_cnode->impl(), cnode->impl());
  76. }
  77. /// Feature: MindAPI
  78. /// Description: test value related functions.
  79. /// Expectation: value related functions work as expected.
  80. TEST_F(TestMindApi, test_values) {
  81. int64_t one = 1;
  82. auto s = MakeValue("hello");
  83. auto i = MakeValue(one);
  84. auto i2 = MakeValue(2);
  85. auto b = MakeValue(true);
  86. auto f = MakeValue(3.14f);
  87. auto seq = MakeValue(std::vector<int64_t>{3, 4, 5});
  88. auto seq_str = MakeValue(std::vector<std::string>({"this", "is", "mindspore", "api"}));
  89. ASSERT_TRUE(s->isa<StringImm>());
  90. ASSERT_TRUE(i->isa<Int64Imm>());
  91. ASSERT_TRUE(i2->isa<Int64Imm>());
  92. ASSERT_TRUE(b->isa<BoolImm>());
  93. ASSERT_TRUE(f->isa<FP32Imm>());
  94. ASSERT_TRUE(seq->isa<ValueSequence>());
  95. ASSERT_TRUE(seq_str->isa<ValueSequence>());
  96. ASSERT_EQ(GetValue<std::string>(s), "hello");
  97. ASSERT_EQ(GetValue<int64_t>(i), one);
  98. ASSERT_EQ(GetValue<int64_t>(i2), 2);
  99. ASSERT_TRUE(GetValue<bool>(b));
  100. ASSERT_TRUE(std::abs(GetValue<float>(f) - 3.14f) < 0.00001f);
  101. ASSERT_EQ(GetValue<std::string>(i), "");
  102. ASSERT_EQ(GetValue<int64_t>(s), 0);
  103. ASSERT_FALSE(GetValue<bool>(s));
  104. ASSERT_EQ(GetValue<float>(s), 0.0f);
  105. auto seq_ptr = seq->cast<ValueSequencePtr>();
  106. ASSERT_TRUE(seq_ptr != nullptr);
  107. ASSERT_EQ(seq_ptr->size(), 3);
  108. ASSERT_EQ(seq_ptr->value().size(), 3);
  109. ASSERT_TRUE(seq_ptr->value()[0]->isa<Int64Imm>());
  110. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[0]), 3);
  111. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[1]), 4);
  112. ASSERT_EQ(GetValue<int64_t>(seq_ptr->value()[2]), 5);
  113. auto seq_values = GetValue<std::vector<int64_t>>(seq);
  114. ASSERT_EQ(seq_values.size(), 3);
  115. ASSERT_EQ(seq_values[0], 3);
  116. ASSERT_EQ(seq_values[1], 4);
  117. ASSERT_EQ(seq_values[2], 5);
  118. auto str_values = GetValue<std::vector<std::string>>(seq_str);
  119. ASSERT_EQ(str_values.size(), 4);
  120. ASSERT_EQ(str_values[0], "this");
  121. ASSERT_EQ(str_values[1], "is");
  122. ASSERT_EQ(str_values[2], "mindspore");
  123. ASSERT_EQ(str_values[3], "api");
  124. auto value_list = GetValue<ValuePtrList>(seq);
  125. ASSERT_EQ(value_list.size(), 3);
  126. ASSERT_EQ(utils::cast<int64_t>(value_list[0]), 3);
  127. ASSERT_EQ(utils::cast<int64_t>(value_list[1]), 4);
  128. ASSERT_EQ(utils::cast<int64_t>(value_list[2]), 5);
  129. }
  130. /// Feature: MindAPI
  131. /// Description: test graph manager functions.
  132. /// Expectation: graph manager functions work as expected.
  133. TEST_F(TestMindApi, test_func_graph_manager) {
  134. // fg(x, y) { return myprim(add(x, y), 1); }
  135. auto fg = FuncGraph::Create();
  136. auto x = fg->add_parameter();
  137. x->set_name("x");
  138. auto y = fg->add_parameter();
  139. y->set_name("y");
  140. auto add = MakeShared<Primitive>("add");
  141. auto add_node = MakeShared<ValueNode>(add);
  142. auto add_cnode = fg->NewCNode({add_node, x, y});
  143. auto prim = MakeShared<Primitive>("myprim");
  144. auto prim_node = MakeShared<ValueNode>(prim);
  145. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  146. auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
  147. fg->set_output(cnode);
  148. auto mgr = FuncGraphManager::Manage(fg);
  149. ASSERT_TRUE(mgr != nullptr);
  150. ASSERT_TRUE(fg->manager() != nullptr);
  151. ASSERT_EQ(fg->manager()->impl(), mgr->impl());
  152. ASSERT_EQ(fg->manager(), mgr);
  153. ASSERT_EQ(cnode->input(1)->impl(), add_cnode->impl());
  154. mgr->Replace(add_cnode, x);
  155. ASSERT_EQ(cnode->input(1)->impl(), x->impl());
  156. mgr->SetEdge(cnode, 1, y);
  157. ASSERT_EQ(cnode->input(1)->impl(), y->impl());
  158. mgr->AddEdge(cnode, x);
  159. ASSERT_EQ(cnode->size(), 4);
  160. ASSERT_EQ(cnode->input(3)->impl(), x->impl());
  161. auto users = mgr->GetUsers(value_node);
  162. ASSERT_EQ(users.size(), 1);
  163. ASSERT_EQ(users[0].first, cnode);
  164. ASSERT_EQ(users[0].second, 2);
  165. }
  166. /// Feature: MindAPI
  167. /// Description: test value node utils.
  168. /// Expectation: value node utils work as expected.
  169. TEST_F(TestMindApi, test_value_node_utils) {
  170. auto fg = FuncGraph::Create();
  171. auto fg_node = MakeShared<ValueNode>(fg);
  172. auto prim = MakeShared<Primitive>("myprim");
  173. auto prim_node = MakeShared<ValueNode>(prim);
  174. auto one = MakeShared<ValueNode>(MakeValue(1));
  175. auto cnode = fg->NewCNode({fg_node, prim_node, one});
  176. ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode) == nullptr);
  177. auto fg1 = GetValueNode<FuncGraphPtr>(cnode->input(0));
  178. ASSERT_TRUE(fg1 != nullptr);
  179. ASSERT_TRUE(fg1->isa<FuncGraph>());
  180. auto prim1 = GetValueNode<PrimitivePtr>(cnode->input(1));
  181. ASSERT_TRUE(prim1 != nullptr);
  182. ASSERT_TRUE(prim1->isa<Primitive>());
  183. auto imm = GetValueNode<Int64ImmPtr>(cnode->input(2));
  184. ASSERT_TRUE(imm != nullptr);
  185. ASSERT_TRUE(imm->isa<Int64Imm>());
  186. ASSERT_EQ(imm->cast<Int64ImmPtr>()->value(), 1);
  187. auto value = GetValueNode(cnode->input(2));
  188. ASSERT_TRUE(value != nullptr);
  189. ASSERT_EQ(GetValue<int64_t>(value), 1);
  190. ASSERT_TRUE(GetValueNode<PrimitivePtr>(cnode->input(0)) == nullptr);
  191. ASSERT_TRUE(GetValueNode<FuncGraphPtr>(cnode->input(1)) == nullptr);
  192. ASSERT_TRUE(GetValueNode<StringImmPtr>(cnode->input(2)) == nullptr);
  193. // Test NewValueNode.
  194. auto int_node = NewValueNode(1);
  195. auto bool_node = NewValueNode(true);
  196. auto float_node = NewValueNode(1.23f);
  197. auto str_node = NewValueNode("hello");
  198. ASSERT_TRUE(int_node->value()->isa<Int64Imm>());
  199. ASSERT_EQ(int_node->value()->cast<Int64ImmPtr>()->value(), 1);
  200. ASSERT_TRUE(bool_node->value()->isa<BoolImm>());
  201. ASSERT_TRUE(bool_node->value()->cast<BoolImmPtr>()->value());
  202. ASSERT_TRUE(float_node->value()->isa<FP32Imm>());
  203. ASSERT_TRUE(std::abs(float_node->value()->cast<FP32ImmPtr>()->value() - 1.23f) < 0.0000001f);
  204. ASSERT_TRUE(str_node->value()->isa<StringImm>());
  205. ASSERT_EQ(str_node->value()->cast<StringImmPtr>()->value(), "hello");
  206. }
  207. /// Feature: MindAPI
  208. /// Description: test SharedPtr.
  209. /// Expectation: SharedPtr work as expected.
  210. TEST_F(TestMindApi, test_object_ptr) {
  211. auto fg = FuncGraph::Create();
  212. auto fg_node = MakeShared<ValueNode>(fg);
  213. auto prim = MakeShared<Primitive>("myprim");
  214. auto prim_node = MakeShared<ValueNode>(prim);
  215. auto one = MakeShared<ValueNode>(MakeValue(1));
  216. auto cnode = fg->NewCNode({fg_node, prim_node, one});
  217. ASSERT_TRUE(fg != nullptr);
  218. ASSERT_FALSE(!fg);
  219. ASSERT_TRUE(fg ? true : false);
  220. ASSERT_TRUE((*cnode).input(0) == fg_node);
  221. ASSERT_TRUE(cnode->input(0) == fg_node);
  222. ASSERT_TRUE(cnode.get()->input(0) == fg_node);
  223. ASSERT_EQ(cnode->input(0), fg_node);
  224. ASSERT_EQ(cnode->input(1), prim_node);
  225. ASSERT_EQ(cnode->input(2), one);
  226. ASSERT_TRUE(cnode->input(0) != fg);
  227. AnfNodePtr p = fg_node;
  228. ASSERT_TRUE(p == fg_node);
  229. ASSERT_TRUE(p->isa<ValueNode>());
  230. ASSERT_TRUE(p->cast<ValueNodePtr>() != nullptr);
  231. ASSERT_TRUE(p->cast<ValueNodePtr>() == fg_node);
  232. p = cnode;
  233. ASSERT_TRUE(p == cnode);
  234. ASSERT_TRUE(p->isa<CNode>());
  235. ASSERT_TRUE(p->cast<CNodePtr>() != nullptr);
  236. ASSERT_TRUE(p->cast<CNodePtr>() == cnode);
  237. ASSERT_TRUE(p.get() == cnode.get());
  238. ASSERT_TRUE(p != nullptr);
  239. ASSERT_FALSE(p == nullptr);
  240. ASSERT_TRUE(p > nullptr);
  241. ASSERT_FALSE(p < nullptr);
  242. ASSERT_TRUE(p >= nullptr);
  243. ASSERT_FALSE(p <= nullptr);
  244. ASSERT_TRUE(nullptr != p);
  245. ASSERT_FALSE(nullptr == p);
  246. ASSERT_TRUE(nullptr < p);
  247. ASSERT_FALSE(nullptr > p);
  248. ASSERT_TRUE(nullptr <= p);
  249. ASSERT_FALSE(nullptr >= p);
  250. AnfNodePtr q = fg_node;
  251. ASSERT_TRUE(p != q);
  252. if (p.get()->impl() > q.get()->impl()) {
  253. ASSERT_TRUE(p > q);
  254. ASSERT_TRUE(p >= q);
  255. ASSERT_TRUE(q < p);
  256. ASSERT_TRUE(q <= p);
  257. } else {
  258. ASSERT_TRUE(p < q);
  259. ASSERT_TRUE(p <= q);
  260. ASSERT_TRUE(q > p);
  261. ASSERT_TRUE(q >= p);
  262. }
  263. std::stringstream ss1;
  264. std::stringstream ss2;
  265. ss1 << p;
  266. ss2 << cnode.get()->impl().get();
  267. ASSERT_EQ(ss1.str(), ss2.str());
  268. std::unordered_map<AnfNodePtr, AnfNodePtr> mymap;
  269. mymap.emplace(p, q);
  270. mymap.emplace(q, p);
  271. ASSERT_TRUE(mymap.find(p) != mymap.end());
  272. ASSERT_TRUE(mymap.find(q) != mymap.end());
  273. ASSERT_TRUE(mymap[p] == q);
  274. ASSERT_TRUE(mymap[q] == p);
  275. }
  276. /// Feature: MindAPI
  277. /// Description: test Tensor API.
  278. /// Expectation: Tensor API work as expected.
  279. TEST_F(TestMindApi, test_tensor_api) {
  280. ShapeVector shape{1, 2, 3};
  281. auto tensor = MakeShared<Tensor>(kNumberTypeFloat32, shape);
  282. ASSERT_EQ(tensor->data_type(), kNumberTypeFloat32);
  283. ASSERT_EQ(tensor->shape(), shape);
  284. ASSERT_EQ(tensor->DataSize(), 6);
  285. ASSERT_EQ(tensor->Size(), 24);
  286. ShapeVector shape2{2, 3};
  287. tensor->set_data_type(kNumberTypeInt32);
  288. tensor->set_shape(shape2);
  289. ASSERT_EQ(tensor->data_type(), kNumberTypeInt32);
  290. ASSERT_EQ(tensor->shape(), shape2);
  291. // TensorType.
  292. TypePtr tensor_type = MakeShared<TensorType>(Type::GetType(TypeId::kNumberTypeFloat32));
  293. ASSERT_TRUE(tensor_type->isa<TensorType>());
  294. ASSERT_EQ(tensor_type->cast<TensorTypePtr>()->element()->type_id(), kNumberTypeFloat32);
  295. }
  296. /// Feature: MindAPI
  297. /// Description: test utils API.
  298. /// Expectation: Tensor API work as expected.
  299. TEST_F(TestMindApi, test_api_utils) {
  300. // Test utils::isa, utils::cast.
  301. auto anf_node = NewValueNode("hello");
  302. ASSERT_TRUE(utils::isa<AnfNode>(anf_node));
  303. ASSERT_TRUE(utils::isa<AnfNodePtr>(anf_node));
  304. ASSERT_FALSE(utils::isa<AbstractBase>(anf_node));
  305. ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) != nullptr);
  306. ASSERT_TRUE(utils::cast<AbstractBasePtr>(anf_node) == nullptr);
  307. ASSERT_TRUE(utils::isa<std::string>(anf_node->value()));
  308. ASSERT_EQ(utils::cast<std::string>(anf_node->value()), "hello");
  309. auto int_value = MakeValue(123);
  310. ASSERT_TRUE(utils::isa<int64_t>(int_value));
  311. ASSERT_EQ(utils::cast<int64_t>(int_value), 123);
  312. anf_node = nullptr;
  313. ASSERT_FALSE(utils::isa<AnfNode>(anf_node));
  314. ASSERT_FALSE(utils::isa<AnfNodePtr>(anf_node));
  315. ASSERT_TRUE(utils::cast<AnfNodePtr>(anf_node) == nullptr);
  316. // Test clone graph.
  317. auto fg = FuncGraph::Create();
  318. auto x = fg->add_parameter();
  319. x->set_name("x");
  320. auto y = fg->add_parameter();
  321. y->set_name("y");
  322. auto add = MakeShared<Primitive>("add");
  323. auto add_node = MakeShared<ValueNode>(add);
  324. auto add_cnode = fg->NewCNode({add_node, x, y});
  325. auto prim = MakeShared<Primitive>("myprim");
  326. auto prim_node = MakeShared<ValueNode>(prim);
  327. auto value_node = MakeShared<ValueNode>(MakeValue(1));
  328. auto cnode = fg->NewCNode({prim_node, add_cnode, value_node});
  329. fg->set_output(cnode);
  330. auto cloned_fg = utils::CloneGraph(fg);
  331. ASSERT_TRUE(cloned_fg != nullptr);
  332. ASSERT_EQ(cloned_fg->parameters().size(), 2);
  333. auto new_output = cloned_fg->output();
  334. ASSERT_TRUE(new_output != nullptr);
  335. ASSERT_TRUE(new_output->isa<CNode>());
  336. ASSERT_EQ(new_output->cast<CNodePtr>()->size(), cnode->size());
  337. ASSERT_TRUE(new_output != cnode);
  338. ASSERT_TRUE(new_output->cast<CNodePtr>() != cnode);
  339. // Test get pad mode.
  340. auto pm_lower = MakeValue("pad");
  341. auto pm_upper = MakeValue("PAD");
  342. ASSERT_EQ(utils::GetPadMode(pm_lower), 0);
  343. ASSERT_EQ(utils::GetPadMode(pm_lower, false), 0);
  344. ASSERT_EQ(utils::GetPadMode(pm_upper, true), 0);
  345. }
  346. /// Feature: MindAPI
  347. /// Description: test logging API.
  348. /// Expectation: logging work as expected.
  349. TEST_F(TestMindApi, test_api_logging) {
  350. MS_LOG(DEBUG) << "hello debug";
  351. MS_LOG(INFO) << "hello info";
  352. MS_LOG(WARNING) << "hello warning";
  353. MS_LOG(ERROR) << "hello error";
  354. try {
  355. MS_LOG(EXCEPTION) << "hello exception";
  356. ASSERT_TRUE(false);
  357. } catch (...) {
  358. }
  359. ASSERT_TRUE(true);
  360. }
  361. } // namespace mindspore::api