You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

data_utils.cc 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/data/data_utils.h"
  17. #include <algorithm>
  18. #include <limits>
  19. #include <string>
  20. #include <vector>
  21. #include "dataset/core/constants.h"
  22. #include "dataset/core/data_type.h"
  23. #include "dataset/core/pybind_support.h"
  24. #include "dataset/core/tensor.h"
  25. #include "dataset/core/tensor_shape.h"
  26. #include "dataset/kernels/data/type_cast_op.h"
  27. #include "dataset/util/status.h"
  28. namespace mindspore {
  29. namespace dataset {
  30. Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
  31. dsize_t num_classes, int64_t index) {
  32. uint64_t class_idx;
  33. if (input->Rank() == 0) {
  34. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {}));
  35. } else {
  36. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {index}));
  37. }
  38. if (class_idx >= static_cast<uint64_t>(num_classes)) {
  39. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  40. }
  41. if (input->type() == DataType::DE_UINT64) {
  42. RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  43. } else if (input->type() == DataType::DE_UINT32) {
  44. RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  45. } else if (input->type() == DataType::DE_UINT16) {
  46. RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  47. } else if (input->type() == DataType::DE_UINT8) {
  48. RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  49. } else {
  50. RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input.");
  51. }
  52. return Status::OK();
  53. }
  54. Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
  55. int64_t index) {
  56. int64_t class_idx;
  57. if (input->Rank() == 0) {
  58. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {}));
  59. } else {
  60. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {index}));
  61. }
  62. if (class_idx >= static_cast<int64_t>(num_classes)) {
  63. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  64. }
  65. if (input->type() == DataType::DE_INT64) {
  66. RETURN_IF_NOT_OK((*output)->SetItemAt<int64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  67. } else if (input->type() == DataType::DE_INT32) {
  68. RETURN_IF_NOT_OK((*output)->SetItemAt<int32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  69. } else if (input->type() == DataType::DE_INT16) {
  70. RETURN_IF_NOT_OK((*output)->SetItemAt<int16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  71. } else if (input->type() == DataType::DE_INT8) {
  72. RETURN_IF_NOT_OK((*output)->SetItemAt<int8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  73. } else {
  74. RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input.");
  75. }
  76. return Status::OK();
  77. }
// One-hot encodes an integer tensor. After squeezing, the input must be a
// scalar or 1D tensor; the output has shape (num_elements, num_classes) in
// the same integer type, squeezed afterwards so a scalar input yields a 1D
// one-hot vector.
//
// @param input       Scalar or 1D integer tensor of class indices.
// @param output      Receives the one-hot encoded tensor.
// @param num_classes Width of the one-hot dimension.
// @return Status error for rank > 1, non-integer input, or out-of-range
//         index values (reported by the per-row encoders).
Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
  input->Squeeze();
  if (input->Rank() > 1) {  // We expect the input to be in the first dimension
    RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
  }
  if (!input->type().IsInt()) {
    RETURN_STATUS_UNEXPECTED("One hot does not support input of this type.");
  }
  try {
    dsize_t num_elements = 1;
    if (input->Rank() == 1) num_elements = input->shape()[0];
    TensorShape out_shape({num_elements, num_classes});
    std::shared_ptr<Tensor> out;
    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type()));
    // Start all-zero; the encoders below only set the hot entries.
    RETURN_IF_NOT_OK(out->Zero());
    for (dsize_t i = 0; i < num_elements; ++i) {
      // Dispatch on signedness so each element is read with the right width.
      if (input->type().IsUnsignedInt()) {
        RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i));
      } else {
        RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i));
      }
    }
    out->Squeeze();
    *output = out;
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp");
  }
}
  107. Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
  108. CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)),
  109. "Types do not match");
  110. CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
  111. std::shared_ptr<Tensor> out;
  112. const DataType &to = input->type();
  113. std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
  114. std::shared_ptr<Tensor> fill_output;
  115. RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
  116. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
  117. switch (input->type().value()) {
  118. case DataType::DE_BOOL: {
  119. bool value = 0;
  120. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  121. out->Fill<bool>(value);
  122. break;
  123. }
  124. case DataType::DE_INT8: {
  125. int8_t value = 0;
  126. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  127. out->Fill<int8_t>(value);
  128. break;
  129. }
  130. case DataType::DE_UINT8: {
  131. uint8_t value = 0;
  132. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  133. out->Fill<uint8_t>(value);
  134. break;
  135. }
  136. case DataType::DE_UINT16: {
  137. uint16_t value = 0;
  138. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  139. out->Fill<uint16_t>(value);
  140. break;
  141. }
  142. case DataType::DE_INT16: {
  143. int16_t value = 0;
  144. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  145. out->Fill<int16_t>(value);
  146. break;
  147. }
  148. case DataType::DE_UINT32: {
  149. uint32_t value = 0;
  150. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  151. out->Fill<uint32_t>(value);
  152. break;
  153. }
  154. case DataType::DE_INT32: {
  155. int32_t value = 0;
  156. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  157. out->Fill<int32_t>(value);
  158. break;
  159. }
  160. case DataType::DE_UINT64: {
  161. uint64_t value = 0;
  162. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  163. out->Fill<uint64_t>(value);
  164. break;
  165. }
  166. case DataType::DE_INT64: {
  167. int64_t value = 0;
  168. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  169. out->Fill<int64_t>(value);
  170. break;
  171. }
  172. case DataType::DE_FLOAT16: {
  173. int64_t value = 0;
  174. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  175. out->Fill<float>(value);
  176. break;
  177. }
  178. case DataType::DE_FLOAT32: {
  179. float value = 0;
  180. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  181. out->Fill<float>(value);
  182. break;
  183. }
  184. case DataType::DE_FLOAT64: {
  185. double value = 0;
  186. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  187. out->Fill<double>(value);
  188. break;
  189. }
  190. case DataType::DE_STRING: {
  191. std::vector<std::string> strings;
  192. std::string_view fill_string_view;
  193. RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
  194. std::string fill_string = std::string(fill_string_view);
  195. for (int i = 0; i < input->shape().NumOfElements(); i++) {
  196. strings.emplace_back(fill_string);
  197. }
  198. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
  199. break;
  200. }
  201. case DataType::DE_UNKNOWN: {
  202. RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
  203. break;
  204. }
  205. }
  206. *output = out;
  207. return Status::OK();
  208. }
  209. template <typename FROM, typename TO>
  210. void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  211. auto in_itr = input->begin<FROM>();
  212. auto out_itr = (*output)->begin<TO>();
  213. auto out_end = (*output)->end<TO>();
  214. for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
  215. *out_itr = static_cast<TO>(*in_itr);
  216. }
// Second half of the two-level cast dispatch: the source element type T is
// already fixed; this switches on the DESTINATION type recorded on *output
// and runs the element-wise Cast<T, TO>. The caller (TypeCast) has already
// allocated *output with the same element count as input.
template <typename T>
void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  switch ((*output)->type().value()) {
    case DataType::DE_BOOL:
      Cast<T, bool>(input, output);
      break;
    case DataType::DE_INT8:
      Cast<T, int8_t>(input, output);
      break;
    case DataType::DE_UINT8:
      Cast<T, uint8_t>(input, output);
      break;
    case DataType::DE_INT16:
      Cast<T, int16_t>(input, output);
      break;
    case DataType::DE_UINT16:
      Cast<T, uint16_t>(input, output);
      break;
    case DataType::DE_INT32:
      Cast<T, int32_t>(input, output);
      break;
    case DataType::DE_UINT32:
      Cast<T, uint32_t>(input, output);
      break;
    case DataType::DE_INT64:
      Cast<T, int64_t>(input, output);
      break;
    case DataType::DE_UINT64:
      Cast<T, uint64_t>(input, output);
      break;
    case DataType::DE_FLOAT16:
      Cast<T, float16>(input, output);
      break;
    case DataType::DE_FLOAT32:
      Cast<T, float>(input, output);
      break;
    case DataType::DE_FLOAT64:
      Cast<T, double>(input, output);
      break;
    case DataType::DE_UNKNOWN:
      // Cannot cast into an unknown type; log and leave *output untouched.
      // Note this returns void, so the error is not propagated to callers.
      MS_LOG(ERROR) << "Unknown data type.";
      break;
  }
}
// Type cast operator: creates *output with the same shape as input but
// element type `data_type`, then dispatches on the SOURCE type; CastFrom
// finishes the dispatch on the destination type and copies element-wise.
//
// @param input     Source tensor (any numeric/bool type).
// @param output    Receives the newly created, converted tensor.
// @param data_type Destination element type.
// @return Status error only for DE_UNKNOWN input.
Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type));
  // The iterators in Cast write through the raw buffer, so allocate it now.
  RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  switch (input->type().value()) {
    case DataType::DE_BOOL:
      CastFrom<bool>(input, output);
      break;
    case DataType::DE_INT8:
      CastFrom<int8_t>(input, output);
      break;
    case DataType::DE_UINT8:
      CastFrom<uint8_t>(input, output);
      break;
    case DataType::DE_INT16:
      CastFrom<int16_t>(input, output);
      break;
    case DataType::DE_UINT16:
      CastFrom<uint16_t>(input, output);
      break;
    case DataType::DE_INT32:
      CastFrom<int32_t>(input, output);
      break;
    case DataType::DE_UINT32:
      CastFrom<uint32_t>(input, output);
      break;
    case DataType::DE_INT64:
      CastFrom<int64_t>(input, output);
      break;
    case DataType::DE_UINT64:
      CastFrom<uint64_t>(input, output);
      break;
    case DataType::DE_FLOAT16:
      CastFrom<float16>(input, output);
      break;
    case DataType::DE_FLOAT32:
      CastFrom<float>(input, output);
      break;
    case DataType::DE_FLOAT64:
      CastFrom<double>(input, output);
      break;
    case DataType::DE_UNKNOWN:
      // sanity check, unreachable code.
      RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type.");
  }
  return Status::OK();
}
  308. Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  309. // initiate new tensor for type cast
  310. DataType new_type = DataType("float16");
  311. RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type));
  312. RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  313. auto in_itr = input->begin<float>();
  314. auto out_itr = (*output)->begin<float16>();
  315. auto out_end = (*output)->end<float16>();
  316. for (; out_itr != out_end; in_itr++, out_itr++) {
  317. float element = *in_itr;
  318. float float16_max = static_cast<float>(std::numeric_limits<Eigen::half>::max());
  319. float float16_min = static_cast<float>(std::numeric_limits<Eigen::half>::lowest());
  320. if (element > float16_max || element < float16_min) {
  321. RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" +
  322. std::to_string(float16_max) + ", " + std::to_string(float16_min) + "].");
  323. }
  324. *out_itr = Eigen::half(*in_itr);
  325. }
  326. return Status::OK();
  327. }
  328. Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
  329. const std::shared_ptr<Tensor> &pad_val) {
  330. if (pad_val == nullptr) {
  331. if (src->type().IsNumeric()) {
  332. return PadEndNumeric(src, dst, pad_shape, 0);
  333. } else {
  334. return PadEndString(src, dst, pad_shape, "");
  335. }
  336. }
  337. CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
  338. "Source and pad_value tensors are not of the same type.");
  339. if (pad_val->type().IsNumeric()) {
  340. std::shared_ptr<Tensor> float_pad_value;
  341. RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
  342. float val = 0;
  343. RETURN_IF_NOT_OK(float_pad_value->GetItemAt<float>(&val, {}));
  344. return PadEndNumeric(src, dst, pad_shape, val);
  345. }
  346. std::string_view val;
  347. RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
  348. return PadEndString(src, dst, pad_shape, std::string(val));
  349. }
// Pads a numeric tensor out to `pad_shape`, filling new cells with pad_val.
// If the shapes already match (or src is a scalar), the source tensor is
// shared rather than copied.
//
// @param src       Numeric tensor to pad.
// @param dst       Receives the padded tensor (or src itself if no padding).
// @param pad_shape Target shape; must have the same rank as src.
// @param pad_val   Pad value, carried as float and narrowed to src's type.
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
                     const std::vector<dsize_t> &pad_shape, float pad_val) {
  CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
  if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
    (*dst) = src;  // if no padding, copy the pointer
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
    RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
    auto tensor_type = src->type().value();
    // Pre-fill the whole destination with the pad value; the src data is
    // copied over it afterwards by PadEndNumericHelper.
    if (pad_val == 0) {  // if pad with zero, don't care what type it is
      RETURN_IF_NOT_OK((*dst)->Zero());
    } else if (tensor_type == DataType::DE_INT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
    } else if (tensor_type == DataType::DE_BOOL) {
      RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
    } else if (tensor_type == DataType::DE_UINT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
    } else if (tensor_type == DataType::DE_UINT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
    } else if (tensor_type == DataType::DE_UINT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
    } else if (tensor_type == DataType::DE_UINT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
    } else {
      RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
    }
    // Recursively copy the source data into the top-left corner.
    std::vector<dsize_t> cur_ind(src->Rank(), 0);
    RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0));
  }
  return Status::OK();
}
  393. Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
  394. std::vector<dsize_t> cur_ind, size_t cur_dim) {
  395. if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
  396. dst->CopyLastDimAt(src, cur_ind);
  397. } else { // not the last dimension, keep doing recursion
  398. dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
  399. for (dsize_t i = 0; i < min_ind; i++) {
  400. cur_ind[cur_dim] = i;
  401. RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1));
  402. }
  403. }
  404. return Status::OK();
  405. }
  406. Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
  407. const std::vector<dsize_t> &pad_shape, const std::string &pad_val) {
  408. CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
  409. if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
  410. (*dst) = src; // if no padding, copy the pointer
  411. } else {
  412. CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
  413. std::vector<dsize_t> cur_ind(src->Rank(), 0);
  414. std::vector<std::string> strings;
  415. RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val));
  416. RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape)));
  417. }
  418. return Status::OK();
  419. }
// Recursive worker for PadEndString: emits the padded tensor's elements in
// row-major order into `dst`. For each dimension it first recurses/copies
// over the extent shared with src, then appends pad_value for the cells
// that exist only in the destination shape.
//
// @param src       Source string tensor.
// @param dst       Flat row-major element list being built up.
// @param dst_shape Target (padded) shape.
// @param cur_ind   Index prefix fixed by the enclosing recursion levels.
// @param cur_dim   Dimension currently being walked.
// @param pad_value String written into every padded cell.
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
                          const std::string &pad_value) {
  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      std::string_view item;
      RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind));
      dst->emplace_back(item);
    }
    // Pad the tail of the innermost row.
    for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) {
      dst->emplace_back(pad_value);
    }
  } else {  // not the last dimension, keep doing recursion
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value));
    }
    // Whole sub-blocks beyond src's extent are pure padding: the stride
    // gives the number of flat elements per index of this dimension.
    dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim];
    for (dsize_t i = 0; i < count; i++) {
      dst->emplace_back(pad_value);
    }
  }
  return Status::OK();
}
// Compares every element of `input` against the scalar in `value_tensor`
// with relational operator `op`, writing the boolean result element-wise
// into `output` (assumed pre-allocated as DE_BOOL with the same element
// count — see Mask).
//
// The switch stays inside the loop so an unknown operator is only reported
// when there is at least one element to compare (matching prior behavior).
template <typename T>
Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
                  const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
  T value;
  RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
  auto in_itr = input->begin<T>();
  auto out_itr = output->begin<bool>();
  for (; in_itr != input->end<T>(); in_itr++, out_itr++) {
    switch (op) {
      case RelationalOp::kEqual:
        *out_itr = (*in_itr == value);
        break;
      case RelationalOp::kNotEqual:
        *out_itr = (*in_itr != value);
        break;
      case RelationalOp::kGreater:
        *out_itr = (*in_itr > value);
        break;
      case RelationalOp::kGreaterEqual:
        *out_itr = (*in_itr >= value);
        break;
      case RelationalOp::kLess:
        *out_itr = (*in_itr < value);
        break;
      case RelationalOp::kLessEqual:
        *out_itr = (*in_itr <= value);
        break;
      default:
        RETURN_STATUS_UNEXPECTED("Unknown relational operator.");
    }
  }
  return Status::OK();
}
// Produces a boolean mask tensor the same shape as `input`, where each
// element is the result of comparing the corresponding input element to the
// scalar `value` with operator `op`.
//
// @param input  Tensor to compare (numeric or string).
// @param output Receives the DE_BOOL mask tensor.
// @param value  Scalar comparison value; cast to input's type when numeric,
//               used as-is for string input.
// @param op     Relational operator to apply.
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
            RelationalOp op) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
                               "Cannot convert constant value to the type of the input tensor.");
  CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar");
  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL)));
  // Cast the scalar to the input's type so MaskHelper<T> can read it as T;
  // string values are used directly (no cast defined for strings).
  std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type()));
  std::shared_ptr<Tensor> casted_value;
  if (input->type().IsNumeric()) {
    RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
  } else {
    casted_value = value;
  }
  // Dispatch on the input element type.
  switch (input->type().value()) {
    case DataType::DE_BOOL:
      RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT8:
      RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT8:
      RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT16:
      RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT16:
      RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT32:
      RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT32:
      RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT64:
      RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT64:
      RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT16:
      RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT32:
      RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT64:
      RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
      break;
    case DataType::DE_STRING:
      RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
      break;
    case DataType::DE_UNKNOWN:
      RETURN_STATUS_UNEXPECTED("Unsupported input type.");
      break;
  }
  return Status::OK();
}
  539. Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
  540. std::shared_ptr<Tensor> append) {
  541. CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
  542. CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
  543. axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
  544. CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
  545. std::shared_ptr<Tensor> out;
  546. if (prepend != nullptr) {
  547. CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
  548. RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0]));
  549. } else {
  550. out = input[0];
  551. }
  552. for (dsize_t i = 1; i < input.size(); i++) {
  553. std::shared_ptr<Tensor> out_t;
  554. CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
  555. RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i]));
  556. out = out_t;
  557. }
  558. std::shared_ptr<Tensor> out_t;
  559. if (append != nullptr) {
  560. CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
  561. RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append));
  562. } else {
  563. out_t = out;
  564. }
  565. output->push_back(out_t);
  566. return Status::OK();
  567. }
  568. Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
  569. std::shared_ptr<Tensor> append) {
  570. CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match");
  571. TensorShape t({});
  572. for (dsize_t i = 0; i < input->shape().Rank(); i++) {
  573. if (i != axis) {
  574. t = t.AppendDim(input->shape()[i]);
  575. } else {
  576. dsize_t new_shape = input->shape()[i] + append->shape()[i];
  577. t = t.AppendDim(new_shape);
  578. }
  579. }
  580. std::shared_ptr<Tensor> out;
  581. if (input->type().IsNumeric()) {
  582. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type()));
  583. RETURN_IF_NOT_OK(out->Concatenate({0}, input));
  584. RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append));
  585. *output = out;
  586. } else {
  587. std::vector<std::string> strings;
  588. auto itr = input->begin<std::string_view>();
  589. for (; itr != input->end<std::string_view>(); itr++) {
  590. strings.emplace_back(*itr);
  591. }
  592. itr = append->begin<std::string_view>();
  593. for (; itr != append->end<std::string_view>(); itr++) {
  594. strings.emplace_back(*itr);
  595. }
  596. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t));
  597. *output = out;
  598. }
  599. return Status::OK();
  600. }
  601. } // namespace dataset
  602. } // namespace mindspore