/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include <string>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "securec.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/core/data_type.h"

using namespace mindspore::dataset;
namespace py = pybind11;
class MindDataTestTensorDE : public UT::Common {
 public:
  MindDataTestTensorDE() {}
  void SetUp() { GlobalInit(); }
};
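// Basics: create an empty 2x3 uint64 tensor, check shape/type/size/rank, set and read back
// individual elements (out-of-range indices must return an error), verify ToString() output,
// and compare against a tensor built from a vector.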
TEST_F(MindDataTestTensorDE, Basics) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_UINT64), &t);
  ASSERT_EQ(t->shape(), TensorShape({2, 3}));
  ASSERT_EQ(t->type(), DataType::DE_UINT64);
  ASSERT_EQ(t->SizeInBytes(), 2 * 3 * 8);
  ASSERT_EQ(t->Rank(), 2);
  t->SetItemAt<uint64_t>({0, 0}, 1);
  t->SetItemAt<uint64_t>({0, 1}, 2);
  t->SetItemAt<uint64_t>({0, 2}, 3);
  t->SetItemAt<uint64_t>({1, 0}, 4);
  t->SetItemAt<uint64_t>({1, 1}, 5);
  t->SetItemAt<uint64_t>({1, 2}, 6);
  Status rc = t->SetItemAt<uint64_t>({2, 3}, 7);
  ASSERT_TRUE(rc.IsError());
  uint64_t o;
  t->GetItemAt<uint64_t>(&o, {0, 0});
  ASSERT_EQ(o, 1);
  t->GetItemAt<uint64_t>(&o, {0, 1});
  ASSERT_EQ(o, 2);
  t->GetItemAt<uint64_t>(&o, {0, 2});
  ASSERT_EQ(o, 3);
  t->GetItemAt<uint64_t>(&o, {1, 0});
  ASSERT_EQ(o, 4);
  t->GetItemAt<uint64_t>(&o, {1, 1});
  ASSERT_EQ(o, 5);
  t->GetItemAt<uint64_t>(&o, {1, 2});
  ASSERT_EQ(o, 6);
  rc = t->GetItemAt<uint64_t>(&o, {2, 3});
  ASSERT_TRUE(rc.IsError());
  ASSERT_EQ(t->ToString(), "Tensor (shape: <2,3>, Type: uint64)\n[[1,2,3],[4,5,6]]");
  std::vector<uint64_t> x = {1, 2, 3, 4, 5, 6};
  std::shared_ptr<Tensor> t2;
  Tensor::CreateFromVector(x, TensorShape({2, 3}), &t2);
  ASSERT_EQ(*t == *t2, true);
  ASSERT_EQ(*t != *t2, false);
}
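// Fill: fill a float32 tensor with a constant and compare it against one created from a vector.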
TEST_F(MindDataTestTensorDE, Fill) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32), &t);
  t->Fill<float>(2.5);
  std::vector<float> x = {2.5, 2.5, 2.5, 2.5};
  std::shared_ptr<Tensor> t2;
  Tensor::CreateFromVector(x, TensorShape({2, 2}), &t2);
  ASSERT_EQ(*t == *t2, true);
}
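// Reshape: flatten a 2x2 tensor to {4}, reject an incompatible shape ({5}), and grow the rank
// with ExpandDim; an out-of-range axis must return an error.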
TEST_F(MindDataTestTensorDE, Reshape) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t);
  t->Fill<uint8_t>(254);
  t->Reshape(TensorShape({4}));
  std::vector<uint8_t> x = {254, 254, 254, 254};
  std::shared_ptr<Tensor> t2;
  Tensor::CreateFromVector(x, &t2);
  ASSERT_EQ(*t == *t2, true);
  Status rc = t->Reshape(TensorShape({5}));
  ASSERT_TRUE(rc.IsError());
  t2->ExpandDim(0);
  ASSERT_EQ(t2->shape(), TensorShape({1, 4}));
  t2->ExpandDim(2);
  ASSERT_EQ(t2->shape(), TensorShape({1, 4, 1}));
  rc = t2->ExpandDim(4);
  ASSERT_TRUE(rc.IsError());
}
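// CopyTensor: move-construct a new tensor; the buffer pointer must transfer unchanged and the
// moved-from tensor must be left invalidated (unknown shape/type, null buffer).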
TEST_F(MindDataTestTensorDE, CopyTensor) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({}), DataType(DataType::DE_INT16), &t);
  t->SetItemAt<int16_t>({}, -66);
  ASSERT_EQ(t->shape(), TensorShape({}));
  ASSERT_EQ(t->type(), DataType::DE_INT16);
  int16_t o;
  t->GetItemAt<int16_t>(&o, {});
  ASSERT_EQ(o, -66);
  const unsigned char *addr = t->GetBuffer();
  auto t2 = std::make_shared<Tensor>(std::move(*t));
  ASSERT_EQ(t2->shape(), TensorShape({}));
  ASSERT_EQ(t2->type(), DataType::DE_INT16);
  t2->GetItemAt<int16_t>(&o, {});
  ASSERT_EQ(o, -66);
  const unsigned char *new_addr = t2->GetBuffer();
  ASSERT_EQ(addr, new_addr);
  ASSERT_EQ(t->shape(), TensorShape::CreateUnknownRankShape());
  ASSERT_EQ(t->type(), DataType::DE_UNKNOWN);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  Status rc = t->GetItemAt<int16_t>(&o, {});
  ASSERT_TRUE(rc.IsError());
}
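// InsertTensor: insert row tensors and a scalar at given coordinates; inserts with mismatched
// shapes or out-of-range coordinates must return kUnexpectedError.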
TEST_F(MindDataTestTensorDE, InsertTensor) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 3}), DataType(DataType::DE_FLOAT64), &t);
  std::vector<double> x = {1.1, 2.1, 3.1};
  std::shared_ptr<Tensor> t2;
  Tensor::CreateFromVector(x, &t2);
  std::vector<double> y = {1.2, 2.2, 3.2};
  std::shared_ptr<Tensor> t3;
  Tensor::CreateFromVector(y, &t3);
  ASSERT_TRUE(t->InsertTensor({0}, t2).OK());
  ASSERT_TRUE(t->InsertTensor({1}, t3).OK());
  std::vector<double> z = {1.1, 2.1, 3.1, 1.2, 2.2, 3.2};
  std::shared_ptr<Tensor> t4;
  Tensor::CreateFromVector(z, TensorShape({2, 3}), &t4);
  ASSERT_EQ(*t == *t4, true);
  std::shared_ptr<Tensor> t5;
  Tensor::CreateScalar<double>(0, &t5);
  ASSERT_TRUE(t->InsertTensor({1, 2}, t5).OK());
  z[5] = 0;
  std::shared_ptr<Tensor> t6;
  Tensor::CreateFromVector(z, TensorShape({2, 3}), &t6);
  ASSERT_EQ(*t == *t6, true);
  ASSERT_EQ(t->InsertTensor({2}, t5).get_code(), StatusCode::kUnexpectedError);
  ASSERT_EQ(t->InsertTensor({1}, t5).get_code(), StatusCode::kUnexpectedError);
  ASSERT_EQ(t->InsertTensor({1, 2}, t6).get_code(), StatusCode::kUnexpectedError);
  t6->Fill<double>(-1);
  ASSERT_TRUE(t->InsertTensor({}, t6).OK());
  ASSERT_EQ(*t == *t6, true);
}
// Regression test: Tensor::ToString() used to fail for tensors that store bool values.
TEST_F(MindDataTestTensorDE, BoolTensor) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2}), DataType(DataType::DE_BOOL), &t);
  t->SetItemAt<bool>({0}, true);
  t->SetItemAt<bool>({1}, true);
  std::string out = t->ToString();
  ASSERT_TRUE(out.find("Template type and Tensor type are not compatible") == std::string::npos);
}
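// GetItemAt: read elements back through wider integer/float types than the tensor's own type
// and check that the values convert correctly.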
TEST_F(MindDataTestTensorDE, GetItemAt) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t);
  t->Fill<uint8_t>(254);
  uint64_t o1;
  t->GetItemAt<uint64_t>(&o1, {0, 0});
  ASSERT_EQ(o1, 254);
  uint32_t o2;
  t->GetItemAt<uint32_t>(&o2, {0, 1});
  ASSERT_EQ(o2, 254);
  uint16_t o3;
  t->GetItemAt<uint16_t>(&o3, {1, 0});
  ASSERT_EQ(o3, 254);
  uint8_t o4;
  t->GetItemAt<uint8_t>(&o4, {1, 1});
  ASSERT_EQ(o4, 254);
  std::shared_ptr<Tensor> t2;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_INT8), &t2);
  t2->Fill<int8_t>(-10);
  int64_t o5;
  t2->GetItemAt<int64_t>(&o5, {0, 0});
  ASSERT_EQ(o5, -10);
  int32_t o6;
  t2->GetItemAt<int32_t>(&o6, {0, 1});
  ASSERT_EQ(o6, -10);
  int16_t o7;
  t2->GetItemAt<int16_t>(&o7, {1, 0});
  ASSERT_EQ(o7, -10);
  int8_t o8;
  t2->GetItemAt<int8_t>(&o8, {1, 1});
  ASSERT_EQ(o8, -10);
  std::shared_ptr<Tensor> t3;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_FLOAT32), &t3);
  t3->Fill<float>(1.1);
  double o9;
  t3->GetItemAt<double>(&o9, {0, 0});
  ASSERT_FLOAT_EQ(o9, 1.1);
  float o10;
  t3->GetItemAt<float>(&o10, {0, 1});
  ASSERT_FLOAT_EQ(o10, 1.1);
}
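// OperatorAssign: move-assign one tensor into another and verify the destination holds the data.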
TEST_F(MindDataTestTensorDE, OperatorAssign) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t);
  t->Fill<uint8_t>(1);
  std::shared_ptr<Tensor> t2;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t2);
  *t2 = std::move(*t);
  uint8_t o;
  t2->GetItemAt(&o, {0, 0});
  ASSERT_EQ(o, 1);
  t2->GetItemAt(&o, {0, 1});
  ASSERT_EQ(o, 1);
  t2->GetItemAt(&o, {1, 0});
  ASSERT_EQ(o, 1);
  t2->GetItemAt(&o, {1, 1});
  ASSERT_EQ(o, 1);
}
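// Strides: strides are reported in bytes, so they scale with the element size
// ({4, 2, 1} for uint8 vs. {16, 8, 4} for uint32 on the same {4, 2, 2} shape).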
TEST_F(MindDataTestTensorDE, Strides) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT8), &t);
  std::vector<dsize_t> x1 = t->Strides();
  std::vector<dsize_t> x2 = {4, 2, 1};
  ASSERT_EQ(x1, x2);
  Tensor::CreateEmpty(TensorShape({4, 2, 2}), DataType(DataType::DE_UINT32), &t);
  x1 = t->Strides();
  x2 = {16, 8, 4};
  ASSERT_EQ(x1, x2);
}
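// Helper: create an empty CVTensor and check that its cv::Mat view shares the buffer and reports
// the expected depth, dims, rows/cols, and channels for the given shape. Shapes of rank < 4 map
// to a 2-D Mat (rank 3 uses channels); higher ranks map to an n-dimensional Mat.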
void checkCvMat(TensorShape shape, DataType type) {
  std::shared_ptr<CVTensor> t;
  CVTensor::CreateEmpty(shape, type, &t);
  cv::Mat m = t->mat();
  ASSERT_EQ(m.data, t->GetBuffer());
  ASSERT_EQ(static_cast<uchar>(m.type()) & static_cast<uchar>(CV_MAT_DEPTH_MASK), type.AsCVType());
  if (shape.Rank() < 4) {
    if (shape.Rank() > 1) {
      for (dsize_t i = 0; i < 2; i++) ASSERT_EQ(m.size[static_cast<int>(i)], shape[i]);
    } else if (shape.Rank() == 0) {
      ASSERT_EQ(m.size[0], 1);
      ASSERT_EQ(m.size[1], 1);
    } else {
      ASSERT_EQ(m.size[0], shape[0]);
    }
    if (shape.Rank() == 3) {
      ASSERT_EQ(m.channels(), shape[2]);
    }
    ASSERT_EQ(m.dims, 2);
    ASSERT_EQ(m.size.dims(), 2);
    if (shape.Rank() > 0) {
      ASSERT_EQ(m.rows, shape[0]);
    }
    if (shape.Rank() > 1) {
      ASSERT_EQ(m.cols, shape[1]);
    }
  } else {
    for (dsize_t i = 0; i < shape.Rank(); i++) ASSERT_EQ(m.size[static_cast<int>(i)], shape[i]);
    ASSERT_EQ(m.dims, shape.Rank());
    ASSERT_EQ(m.size.dims(), shape.Rank());
    ASSERT_EQ(m.rows, -1);
    ASSERT_EQ(m.cols, -1);
  }
}
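// CVTensorBasics: run the Mat-layout checks above for a range of shapes and element types.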
TEST_F(MindDataTestTensorDE, CVTensorBasics) {
  checkCvMat(TensorShape({4, 5}), DataType(DataType::DE_UINT8));
  checkCvMat(TensorShape({4, 5, 3}), DataType(DataType::DE_UINT8));
  checkCvMat(TensorShape({4, 5, 10}), DataType(DataType::DE_UINT8));
  checkCvMat(TensorShape({4, 5, 3, 2}), DataType(DataType::DE_UINT8));
  checkCvMat(TensorShape({4}), DataType(DataType::DE_UINT8));
  checkCvMat(TensorShape({}), DataType(DataType::DE_INT16));
  checkCvMat(TensorShape({4, 5}), DataType(DataType::DE_INT16));
  checkCvMat(TensorShape({4, 5, 3}), DataType(DataType::DE_INT16));
  checkCvMat(TensorShape({4, 5, 10}), DataType(DataType::DE_INT16));
  checkCvMat(TensorShape({4, 5, 3, 2}), DataType(DataType::DE_INT16));
  checkCvMat(TensorShape({4}), DataType(DataType::DE_INT16));
  checkCvMat(TensorShape({}), DataType(DataType::DE_INT16));
}
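// CVTensorFromMat: wrap existing cv::Mat objects (2-D and 1-D) as CVTensors and compare them
// element-wise with tensors populated through SetItemAt.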
TEST_F(MindDataTestTensorDE, CVTensorFromMat) {
  cv::Mat m(2, 2, CV_8U);
  m.at<uint8_t>(0, 0) = 10;
  m.at<uint8_t>(0, 1) = 20;
  m.at<uint8_t>(1, 0) = 30;
  m.at<uint8_t>(1, 1) = 40;
  std::shared_ptr<CVTensor> cvt;
  CVTensor::CreateFromMat(m, &cvt);
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({2, 2}), DataType(DataType::DE_UINT8), &t);
  t->SetItemAt<uint8_t>({0, 0}, 10);
  t->SetItemAt<uint8_t>({0, 1}, 20);
  t->SetItemAt<uint8_t>({1, 0}, 30);
  t->SetItemAt<uint8_t>({1, 1}, 40);
  ASSERT_TRUE(*t == *cvt);
  int size[] = {4};
  cv::Mat m2(1, size, CV_8U);
  m2.at<uint8_t>(0) = 10;
  m2.at<uint8_t>(1) = 20;
  m2.at<uint8_t>(2) = 30;
  m2.at<uint8_t>(3) = 40;
  std::shared_ptr<CVTensor> cvt2;
  CVTensor::CreateFromMat(m2, &cvt2);
  std::shared_ptr<Tensor> t2;
  Tensor::CreateEmpty(TensorShape({4}), DataType(DataType::DE_UINT8), &t2);
  t2->SetItemAt<uint8_t>({0}, 10);
  t2->SetItemAt<uint8_t>({1}, 20);
  t2->SetItemAt<uint8_t>({2}, 30);
  t2->SetItemAt<uint8_t>({3}, 40);
  t2->ExpandDim(1);
  ASSERT_TRUE(*t2 == *cvt2);
}
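// CVTensorAs: AsCVTensor takes over the tensor's buffer (the source is left empty). Doubling the
// data through mat() keeps the same buffer and produces the 4.4-filled reference values, while
// doubling a matCopy() leaves the CVTensor's contents untouched.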
TEST_F(MindDataTestTensorDE, CVTensorAs) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateEmpty(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64), &t);
  t->Fill<double>(2.2);
  const unsigned char *addr = t->GetBuffer();
  std::shared_ptr<Tensor> t2;
  Tensor::CreateEmpty(TensorShape({3, 2}), DataType(DataType::DE_FLOAT64), &t2);
  t2->Fill<double>(4.4);
  std::shared_ptr<CVTensor> ctv = CVTensor::AsCVTensor(t);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_EQ(ctv->GetBuffer(), addr);
  cv::Mat m = ctv->mat();
  m = 2 * m;
  ASSERT_EQ(ctv->GetBuffer(), addr);
  ASSERT_TRUE(*t2 == *ctv);
  MS_LOG(DEBUG) << *t2 << std::endl << *ctv;
  cv::Mat m2 = ctv->matCopy();
  m2 = 2 * m2;
  ASSERT_EQ(ctv->GetBuffer(), addr);
  ASSERT_TRUE(*t2 == *ctv);
}
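// CVTensorMatSlice: MatAtIndex returns the Mat slice at the given index; each slice is compared
// against a Mat built by hand with the same values.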
TEST_F(MindDataTestTensorDE, CVTensorMatSlice) {
  cv::Mat m(2, 3, CV_32S);
  m.at<int32_t>(0, 0) = 10;
  m.at<int32_t>(0, 1) = 20;
  m.at<int32_t>(0, 2) = 30;
  m.at<int32_t>(1, 0) = 40;
  m.at<int32_t>(1, 1) = 50;
  m.at<int32_t>(1, 2) = 60;
  std::shared_ptr<CVTensor> cvt;
  CVTensor::CreateFromMat(m, &cvt);
  cv::Mat mat;
  cvt->MatAtIndex({1}, &mat);
  cv::Mat m2(3, 1, CV_32S);
  m2.at<int32_t>(0) = 40;
  m2.at<int32_t>(1) = 50;
  m2.at<int32_t>(2) = 60;
  std::shared_ptr<CVTensor> cvt2;
  CVTensor::CreateFromMat(mat, &cvt2);
  std::shared_ptr<CVTensor> cvt3;
  CVTensor::CreateFromMat(m2, &cvt3);
  ASSERT_TRUE(*cvt2 == *cvt3);
  cvt->MatAtIndex({0}, &mat);
  m2.at<int32_t>(0) = 10;
  m2.at<int32_t>(1) = 20;
  m2.at<int32_t>(2) = 30;
  CVTensor::CreateFromMat(mat, &cvt2);
  CVTensor::CreateFromMat(m2, &cvt3);
  ASSERT_TRUE(*cvt2 == *cvt3);
}
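// TensorIterator: begin()/end() visit every element in row-major order regardless of shape,
// and writing through the iterator modifies the tensor in place.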
TEST_F(MindDataTestTensorDE, TensorIterator) {
  std::vector<uint32_t> values = {1, 2, 3, 4, 5, 6};
  std::vector<uint32_t> values2 = {2, 3, 4, 5, 6, 7};
  std::shared_ptr<Tensor> t;
  Tensor::CreateFromVector(values, &t);
  auto i = t->begin<uint32_t>();
  auto j = values.begin();
  uint32_t ctr = 0;
  for (; i != t->end<uint32_t>(); i++, j++) {
    ASSERT_TRUE(*i == *j);
    ctr++;
  }
  ASSERT_TRUE(ctr == 6);
  t->Reshape(TensorShape{2, 3});
  i = t->begin<uint32_t>();
  j = values.begin();
  ctr = 0;
  for (; i != t->end<uint32_t>(); i++, j++) {
    ASSERT_TRUE(*i == *j);
    ctr++;
  }
  ASSERT_TRUE(ctr == 6);
  for (auto it = t->begin<uint32_t>(); it != t->end<uint32_t>(); it++) {
    *it = *it + 1;
  }
  i = t->begin<uint32_t>();
  j = values2.begin();
  ctr = 0;
  for (; i != t->end<uint32_t>(); i++, j++) {
    ASSERT_TRUE(*i == *j);
    ctr++;
  }
  ASSERT_TRUE(ctr == 6);
}
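// TensorSlice: Slice with a list of indices gathers the selected elements into a new tensor.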
TEST_F(MindDataTestTensorDE, TensorSlice) {
  std::shared_ptr<Tensor> t;
  Tensor::CreateFromVector(std::vector<dsize_t>{0, 1, 2, 3, 4}, &t);
  std::shared_ptr<Tensor> t2;
  auto x = std::vector<dsize_t>{0, 3, 4};
  std::shared_ptr<Tensor> expected;
  Tensor::CreateFromVector(x, &expected);
  t->Slice(&t2, x);
  ASSERT_EQ(*t2, *expected);
  t->Slice(&t2, std::vector<dsize_t>{0, 1, 2, 3, 4});
  ASSERT_EQ(*t2, *t);
}
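// TensorPartialInsert: with the third argument set to true (partial insert), InsertTensor copies
// the source elements into the destination starting at the given offset.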
TEST_F(MindDataTestTensorDE, TensorPartialInsert) {
  std::vector<uint32_t> values1 = {1, 2, 3, 0, 0, 0};
  std::vector<uint32_t> values2 = {4, 5, 6};
  std::vector<uint32_t> expected = {1, 2, 3, 4, 5, 6};
  std::shared_ptr<Tensor> t1;
  Tensor::CreateFromVector(values1, &t1);
  std::shared_ptr<Tensor> t2;
  Tensor::CreateFromVector(values2, &t2);
  std::shared_ptr<Tensor> out;
  Tensor::CreateFromVector(expected, &out);
  Status s = t1->InsertTensor({3}, t2, true);
  EXPECT_TRUE(s.IsOk());
  auto i = out->begin<uint32_t>();
  auto j = t1->begin<uint32_t>();
  for (; i != out->end<uint32_t>(); i++, j++) {
    ASSERT_TRUE(*i == *j);
  }
  // Should fail: inserting at this offset would overrun the destination tensor.
  s = t1->InsertTensor({5}, t2, true);
  EXPECT_FALSE(s.IsOk());
}
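// TensorEmpty: tensors with zero elements (numeric or string, built via CreateEmpty,
// CreateFromVector, or CreateFromMemory) report size 0, a null buffer, and no data;
// Invalidate() clears a non-empty tensor's data.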
TEST_F(MindDataTestTensorDE, TensorEmpty) {
  TensorPtr t;
  Status rc = Tensor::CreateEmpty(TensorShape({0}), DataType(DataType::DE_UINT64), &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({0}));
  ASSERT_EQ(t->type(), DataType::DE_UINT64);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_TRUE(!t->HasData());
  rc = t->SetItemAt<uint64_t>({0}, 7);
  ASSERT_TRUE(rc.IsError());
  rc = Tensor::CreateEmpty(TensorShape({1, 0}), DataType(DataType::DE_STRING), &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({1, 0}));
  ASSERT_EQ(t->type(), DataType::DE_STRING);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_TRUE(!t->HasData());
  std::vector<uint16_t> data;
  rc = Tensor::CreateFromVector(data, &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({0}));
  ASSERT_EQ(t->type(), DataType::DE_UINT16);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_TRUE(!t->HasData());
  std::vector<std::string> data2;
  rc = Tensor::CreateFromVector(data2, &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({0}));
  ASSERT_EQ(t->type(), DataType::DE_STRING);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_TRUE(!t->HasData());
  rc = Tensor::CreateFromVector(data, TensorShape({0, 2}), &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({0, 2}));
  ASSERT_EQ(t->type(), DataType::DE_UINT16);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_TRUE(!t->HasData());
  rc = Tensor::CreateFromVector(data2, TensorShape({0, 0, 6}), &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({0, 0, 6}));
  ASSERT_EQ(t->type(), DataType::DE_STRING);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_TRUE(!t->HasData());
  rc = Tensor::CreateFromMemory(TensorShape({0}), DataType(DataType::DE_INT8), nullptr, &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({0}));
  ASSERT_EQ(t->type(), DataType::DE_INT8);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  ASSERT_TRUE(!t->HasData());
  rc = Tensor::CreateFromMemory(TensorShape({0}), DataType(DataType::DE_STRING), nullptr, &t);
  ASSERT_TRUE(rc.IsOk());
  ASSERT_EQ(t->shape(), TensorShape({0}));
  ASSERT_EQ(t->type(), DataType::DE_STRING);
  ASSERT_EQ(t->SizeInBytes(), 0);
  ASSERT_EQ(t->GetBuffer(), nullptr);
  std::vector<uint32_t> values = {1, 2, 3, 0, 0, 0};
  std::shared_ptr<Tensor> t2;
  Tensor::CreateFromVector(values, &t2);
  ASSERT_TRUE(t2->HasData());
  t2->Invalidate();
  ASSERT_TRUE(!t2->HasData());
}