You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_test.cc 8.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <memory>
  18. #include "common/common_test.h"
  19. #include "common/py_func_graph_fetcher.h"
  20. #include "pipeline/jit/static_analysis/prim.h"
  21. #include "frontend/operator/ops.h"
  22. #include "abstract/utils.h"
  23. namespace mindspore {
  24. namespace abstract {
  25. class TestData : public UT::Common {
  26. public:
  27. void SetUp();
  28. void TearDown();
  29. };
  30. void TestData::SetUp() { UT::InitPythonPath(); }
  31. void TestData::TearDown() {
  32. // destroy resource
  33. }
  34. TEST_F(TestData, test_build_value) {
  35. // assert build_value(S(1)) == 1
  36. AbstractScalar s1 = AbstractScalar(static_cast<int64_t>(1));
  37. ASSERT_EQ(1, s1.BuildValue()->cast<Int64ImmPtr>()->value());
  38. // assert build_value(S(t=ty.Int[64]), default=ANYTHING) is ANYTHING
  39. s1 = AbstractScalar(kAnyValue, kInt64);
  40. ASSERT_TRUE(s1.BuildValue()->isa<AnyValue>());
  41. ASSERT_TRUE(s1.BuildValue()->isa<AnyValue>());
  42. // assert build_value(T([S(1), S(2)])) == (1, 2)
  43. AbstractBasePtr base1 = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
  44. AbstractBasePtr base2 = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
  45. AbstractBasePtrList base_list = {base1, base2};
  46. AbstractTuple t1 = AbstractTuple(base_list);
  47. std::vector<ValuePtr> value_list = {MakeValue(static_cast<int64_t>(1)), MakeValue(static_cast<int64_t>(2))};
  48. auto tup = t1.BuildValue()->cast<ValueTuplePtr>()->value();
  49. ASSERT_TRUE(tup.size() == value_list.size());
  50. for (int i = 0; i < value_list.size(); i++) {
  51. ASSERT_EQ(*tup[i], *value_list[i]);
  52. }
  53. // BuildValue(AbstractFunction) should return kAnyValue.
  54. AbstractBasePtr abs_f1 = FromValue(prim::kPrimReturn, false);
  55. ValuePtr abs_f1_built = abs_f1->BuildValue();
  56. ASSERT_EQ(abs_f1_built, prim::kPrimReturn);
  57. FuncGraphPtr fg1 = std::make_shared<FuncGraph>();
  58. AbstractBasePtr abs_fg1 = FromValue(fg1, false);
  59. ValuePtr abs_fg1_built = abs_fg1->BuildValue();
  60. ASSERT_EQ(abs_fg1_built, kAnyValue);
  61. // BuildValue(Tuple(AbstractFunction)) should return kAnyValue;
  62. AbstractBasePtr abs_f2 = FromValue(prim::kPrimScalarAdd, false);
  63. AbstractBasePtr abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_f1, abs_f2}));
  64. ValuePtr func_tuple_built = abs_func_tuple->BuildValue();
  65. ASSERT_EQ(*func_tuple_built,
  66. ValueTuple(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
  67. // BuildValue(List(AbstractFunction)) should return kAnyValue;
  68. AbstractBasePtr abs_func_list = std::make_shared<AbstractList>(AbstractBasePtrList({abs_f1, abs_f2}));
  69. ValuePtr func_list_built = abs_func_list->BuildValue();
  70. ASSERT_EQ(*func_list_built,
  71. ValueList(std::vector<ValuePtr>{prim::kPrimReturn, prim::kPrimScalarAdd}));
  72. // BuildValue(Tuple(AnyAbstractBase, AbstractFunction)) should return kAnyValue
  73. abs_func_tuple = std::make_shared<AbstractTuple>(AbstractBasePtrList({base1, abs_f2}));
  74. func_tuple_built = abs_func_tuple->BuildValue();
  75. ASSERT_EQ(*func_tuple_built,
  76. ValueTuple(std::vector<ValuePtr>{std::make_shared<Int64Imm>(1), prim::kPrimScalarAdd}));
  77. }
  78. TEST_F(TestData, test_build_type) {
  79. AbstractBasePtr s1 = FromValue(static_cast<int64_t>(1), false);
  80. AbstractBasePtr s2 = FromValue(static_cast<int64_t>(2), false);
  81. ASSERT_TRUE(Int(64) == *s1->BuildType());
  82. AbstractFunctionPtr f1 = std::make_shared<PrimitiveAbstractClosure>(nullptr, nullptr);
  83. ASSERT_TRUE(Function() == *f1->BuildType());
  84. AbstractList l1 = AbstractList({s1, s2});
  85. ASSERT_TRUE(List({std::make_shared<Int>(64), std::make_shared<Int>(64)}) == *l1.BuildType());
  86. }
  87. TEST_F(TestData, test_build_shape) {
  88. AbstractBasePtr s1 = FromValue(static_cast<int64_t>(1), false);
  89. AbstractBasePtr s2 = FromValue(static_cast<int64_t>(2), false);
  90. ASSERT_TRUE(NoShape() == *s1->BuildShape());
  91. AbstractFunctionPtr f1 = std::make_shared<PrimitiveAbstractClosure>(nullptr, nullptr);
  92. ASSERT_TRUE(NoShape() == *f1->BuildShape());
  93. AbstractList l1 = AbstractList({s1, s2});
  94. auto lshape = l1.BuildShape();
  95. ASSERT_TRUE(lshape);
  96. std::vector<int64_t> weight1_dims = {2, 20, 5, 5};
  97. std::vector<int64_t> weight2_dims = {2, 2, 5, 5};
  98. tensor::TensorPtr weight1 = std::make_shared<tensor::Tensor>(kNumberTypeInt64, weight1_dims);
  99. tensor::TensorPtr weight2 = std::make_shared<tensor::Tensor>(kNumberTypeInt64, weight2_dims);
  100. AbstractBasePtr abstract_weight1 = FromValue(weight1, true);
  101. AbstractBasePtr abstract_weight2 = FromValue(weight2, true);
  102. ShapePtr shape_weight = dyn_cast<Shape>(abstract_weight1->BuildShape());
  103. ASSERT_TRUE(shape_weight);
  104. ASSERT_EQ(weight1_dims, shape_weight->shape());
  105. std::vector<ValuePtr> vec({weight1, weight2});
  106. AbstractBasePtr abstract_tup = FromValue(vec, true);
  107. std::shared_ptr<TupleShape> shape_tuple = dyn_cast<TupleShape>(abstract_tup->BuildShape());
  108. ASSERT_TRUE(shape_tuple);
  109. const std::vector<BaseShapePtr>& ptr_vec = shape_tuple->shape();
  110. ASSERT_EQ(ptr_vec.size(), 2);
  111. ShapePtr shape1 = dyn_cast<Shape>(ptr_vec[0]);
  112. ASSERT_TRUE(shape1);
  113. ASSERT_EQ(weight1_dims, shape1->shape());
  114. ShapePtr shape2 = dyn_cast<Shape>(ptr_vec[1]);
  115. ASSERT_TRUE(shape2);
  116. ASSERT_EQ(weight2_dims, shape2->shape());
  117. }
  118. TEST_F(TestData, test_clone) {
  119. AbstractBasePtr s1 = FromValue(static_cast<int64_t>(1), false);
  120. AbstractBasePtr s2 = s1->Clone();
  121. ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
  122. ASSERT_TRUE(s1->GetValueTrack() == s2->GetValueTrack());
  123. ASSERT_TRUE(*s1->GetShapeTrack() == *s2->GetShapeTrack());
  124. AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
  125. AnalysisContext::DummyContext());
  126. AbstractBasePtr f2 = f1->Clone();
  127. ASSERT_TRUE(*f2 == *f1);
  128. AbstractList l1 = AbstractList({s1, s2});
  129. AbstractBasePtr l2 = l1.Clone();
  130. AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
  131. ASSERT_TRUE(l2_cast != nullptr);
  132. ASSERT_TRUE(l2_cast->GetValueTrack() == l1.GetValueTrack());
  133. std::vector<AbstractAttribute> attr = {{"x", std::make_shared<AbstractScalar>(kAnyValue, kInt64)},
  134. {"y", std::make_shared<AbstractScalar>(kAnyValue, kInt64)}};
  135. std::unordered_map<std::string, ValuePtr> methods;
  136. AbstractBasePtr c1 = std::make_shared<AbstractClass>(Named("Point"), attr, methods);
  137. AbstractBasePtr c2 = c1->Clone();
  138. ASSERT_EQ(*c1, *c2);
  139. }
  140. TEST_F(TestData, test_join) {
  141. int64_t int1 = 1;
  142. AbstractBasePtr s1 = FromValue(int1, false);
  143. AbstractBasePtr s2 = s1->Broaden();
  144. std::vector<AbstractBasePtr> xx = {s1, s2};
  145. AbstractListPtr l1 = std::make_shared<AbstractList>(xx);
  146. AbstractListPtr l2 = std::make_shared<AbstractList>(xx);
  147. l1->Join(l2);
  148. }
  149. TEST_F(TestData, test_broaden) {
  150. int64_t int1 = 1;
  151. AbstractBasePtr s1 = FromValue(int1, false);
  152. AbstractBasePtr s2 = s1->Broaden();
  153. ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack());
  154. ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1));
  155. ASSERT_TRUE(s2->GetValueTrack()->isa<Int64Imm>());
  156. AbstractFunctionPtr f1 = std::make_shared<FuncGraphAbstractClosure>(std::make_shared<FuncGraph>(),
  157. AnalysisContext::DummyContext());
  158. AbstractBasePtr f2 = f1->Broaden();
  159. ASSERT_TRUE(f2 == f1);
  160. AbstractList l1 = AbstractList({s1, s2});
  161. AbstractBasePtr l2 = l1.Broaden();
  162. AbstractList* l2_cast = dynamic_cast<AbstractList*>(l2.get());
  163. ASSERT_TRUE(l2_cast != nullptr);
  164. AbstractBasePtr csr = AbstractJoin(l2_cast->elements());
  165. ASSERT_TRUE(csr->GetValueTrack()->isa<Int64Imm>());
  166. }
  167. } // namespace abstract
  168. } // namespace mindspore