You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

data_utils.cc 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/data/data_utils.h"
  17. #include <algorithm>
  18. #include <limits>
  19. #include <string>
  20. #include <vector>
  21. #include "dataset/core/constants.h"
  22. #include "dataset/core/data_type.h"
  23. #ifdef ENABLE_PYTHON
  24. #include "dataset/core/pybind_support.h"
  25. #endif
  26. #include "dataset/core/tensor.h"
  27. #include "dataset/core/tensor_shape.h"
  28. #include "dataset/kernels/data/type_cast_op.h"
  29. #include "dataset/util/status.h"
  30. namespace mindspore {
  31. namespace dataset {
  32. Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
  33. dsize_t num_classes, int64_t index) {
  34. uint64_t class_idx;
  35. if (input->Rank() == 0) {
  36. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {}));
  37. } else {
  38. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {index}));
  39. }
  40. if (class_idx >= static_cast<uint64_t>(num_classes)) {
  41. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  42. }
  43. if (input->type() == DataType::DE_UINT64) {
  44. RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  45. } else if (input->type() == DataType::DE_UINT32) {
  46. RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  47. } else if (input->type() == DataType::DE_UINT16) {
  48. RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  49. } else if (input->type() == DataType::DE_UINT8) {
  50. RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  51. } else {
  52. RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input.");
  53. }
  54. return Status::OK();
  55. }
  56. Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
  57. int64_t index) {
  58. int64_t class_idx;
  59. if (input->Rank() == 0) {
  60. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {}));
  61. } else {
  62. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {index}));
  63. }
  64. if (class_idx >= static_cast<int64_t>(num_classes)) {
  65. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  66. }
  67. if (input->type() == DataType::DE_INT64) {
  68. RETURN_IF_NOT_OK((*output)->SetItemAt<int64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  69. } else if (input->type() == DataType::DE_INT32) {
  70. RETURN_IF_NOT_OK((*output)->SetItemAt<int32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  71. } else if (input->type() == DataType::DE_INT16) {
  72. RETURN_IF_NOT_OK((*output)->SetItemAt<int16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  73. } else if (input->type() == DataType::DE_INT8) {
  74. RETURN_IF_NOT_OK((*output)->SetItemAt<int8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  75. } else {
  76. RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input.");
  77. }
  78. return Status::OK();
  79. }
  80. Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
  81. input->Squeeze();
  82. if (input->Rank() > 1) { // We expect the input to be int he first dimension
  83. RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
  84. }
  85. if (!input->type().IsInt()) {
  86. RETURN_STATUS_UNEXPECTED("One hot does not support input of this type.");
  87. }
  88. try {
  89. dsize_t num_elements = 1;
  90. if (input->Rank() == 1) num_elements = input->shape()[0];
  91. TensorShape out_shape({num_elements, num_classes});
  92. std::shared_ptr<Tensor> out;
  93. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type()));
  94. RETURN_IF_NOT_OK(out->Zero());
  95. for (dsize_t i = 0; i < num_elements; ++i) {
  96. if (input->type().IsUnsignedInt()) {
  97. RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i));
  98. } else {
  99. RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i));
  100. }
  101. }
  102. out->Squeeze();
  103. *output = out;
  104. return Status::OK();
  105. } catch (const std::exception &e) {
  106. RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp");
  107. }
  108. }
  109. Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
  110. const DataType &fill_type = fill_value->type();
  111. const DataType &input_type = input->type();
  112. const TensorShape &input_shape = input->shape();
  113. CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)),
  114. "Types do not match");
  115. CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
  116. std::shared_ptr<Tensor> out, fill_output;
  117. if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
  118. std::unique_ptr<TypeCastOp> op(new TypeCastOp(input_type));
  119. RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
  120. } else {
  121. fill_output = fill_value;
  122. }
  123. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type));
  124. switch (input_type.value()) {
  125. case DataType::DE_BOOL: {
  126. bool value = 0;
  127. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  128. out->Fill<bool>(value);
  129. break;
  130. }
  131. case DataType::DE_INT8: {
  132. int8_t value = 0;
  133. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  134. out->Fill<int8_t>(value);
  135. break;
  136. }
  137. case DataType::DE_UINT8: {
  138. uint8_t value = 0;
  139. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  140. out->Fill<uint8_t>(value);
  141. break;
  142. }
  143. case DataType::DE_UINT16: {
  144. uint16_t value = 0;
  145. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  146. out->Fill<uint16_t>(value);
  147. break;
  148. }
  149. case DataType::DE_INT16: {
  150. int16_t value = 0;
  151. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  152. out->Fill<int16_t>(value);
  153. break;
  154. }
  155. case DataType::DE_UINT32: {
  156. uint32_t value = 0;
  157. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  158. out->Fill<uint32_t>(value);
  159. break;
  160. }
  161. case DataType::DE_INT32: {
  162. int32_t value = 0;
  163. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  164. out->Fill<int32_t>(value);
  165. break;
  166. }
  167. case DataType::DE_UINT64: {
  168. uint64_t value = 0;
  169. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  170. out->Fill<uint64_t>(value);
  171. break;
  172. }
  173. case DataType::DE_INT64: {
  174. int64_t value = 0;
  175. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  176. out->Fill<int64_t>(value);
  177. break;
  178. }
  179. case DataType::DE_FLOAT16: {
  180. int64_t value = 0;
  181. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  182. out->Fill<float>(value);
  183. break;
  184. }
  185. case DataType::DE_FLOAT32: {
  186. float value = 0;
  187. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  188. out->Fill<float>(value);
  189. break;
  190. }
  191. case DataType::DE_FLOAT64: {
  192. double value = 0;
  193. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  194. out->Fill<double>(value);
  195. break;
  196. }
  197. case DataType::DE_STRING: {
  198. std::vector<std::string> strings;
  199. std::string_view fill_string_view;
  200. RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
  201. std::string fill_string = std::string(fill_string_view);
  202. for (int i = 0; i < input_shape.NumOfElements(); i++) {
  203. strings.emplace_back(fill_string);
  204. }
  205. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape));
  206. break;
  207. }
  208. case DataType::DE_UNKNOWN: {
  209. RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
  210. break;
  211. }
  212. }
  213. *output = out;
  214. return Status::OK();
  215. }
  216. template <typename FROM, typename TO>
  217. void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  218. auto in_itr = input->begin<FROM>();
  219. auto out_itr = (*output)->begin<TO>();
  220. auto out_end = (*output)->end<TO>();
  221. for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
  222. *out_itr = static_cast<TO>(*in_itr);
  223. }
  224. template <typename T>
  225. void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  226. switch ((*output)->type().value()) {
  227. case DataType::DE_BOOL:
  228. Cast<T, bool>(input, output);
  229. break;
  230. case DataType::DE_INT8:
  231. Cast<T, int8_t>(input, output);
  232. break;
  233. case DataType::DE_UINT8:
  234. Cast<T, uint8_t>(input, output);
  235. break;
  236. case DataType::DE_INT16:
  237. Cast<T, int16_t>(input, output);
  238. break;
  239. case DataType::DE_UINT16:
  240. Cast<T, uint16_t>(input, output);
  241. break;
  242. case DataType::DE_INT32:
  243. Cast<T, int32_t>(input, output);
  244. break;
  245. case DataType::DE_UINT32:
  246. Cast<T, uint32_t>(input, output);
  247. break;
  248. case DataType::DE_INT64:
  249. Cast<T, int64_t>(input, output);
  250. break;
  251. case DataType::DE_UINT64:
  252. Cast<T, uint64_t>(input, output);
  253. break;
  254. case DataType::DE_FLOAT16:
  255. Cast<T, float16>(input, output);
  256. break;
  257. case DataType::DE_FLOAT32:
  258. Cast<T, float>(input, output);
  259. break;
  260. case DataType::DE_FLOAT64:
  261. Cast<T, double>(input, output);
  262. break;
  263. case DataType::DE_UNKNOWN:
  264. MS_LOG(ERROR) << "Unknown data type.";
  265. break;
  266. }
  267. }
  268. // Type cast operator
  269. Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
  270. RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type));
  271. RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  272. switch (input->type().value()) {
  273. case DataType::DE_BOOL:
  274. CastFrom<bool>(input, output);
  275. break;
  276. case DataType::DE_INT8:
  277. CastFrom<int8_t>(input, output);
  278. break;
  279. case DataType::DE_UINT8:
  280. CastFrom<uint8_t>(input, output);
  281. break;
  282. case DataType::DE_INT16:
  283. CastFrom<int16_t>(input, output);
  284. break;
  285. case DataType::DE_UINT16:
  286. CastFrom<uint16_t>(input, output);
  287. break;
  288. case DataType::DE_INT32:
  289. CastFrom<int32_t>(input, output);
  290. break;
  291. case DataType::DE_UINT32:
  292. CastFrom<uint32_t>(input, output);
  293. break;
  294. case DataType::DE_INT64:
  295. CastFrom<int64_t>(input, output);
  296. break;
  297. case DataType::DE_UINT64:
  298. CastFrom<uint64_t>(input, output);
  299. break;
  300. case DataType::DE_FLOAT16:
  301. CastFrom<float16>(input, output);
  302. break;
  303. case DataType::DE_FLOAT32:
  304. CastFrom<float>(input, output);
  305. break;
  306. case DataType::DE_FLOAT64:
  307. CastFrom<double>(input, output);
  308. break;
  309. case DataType::DE_UNKNOWN:
  310. // sanity check, unreachable code.
  311. RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type.");
  312. }
  313. return Status::OK();
  314. }
  315. Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  316. // initiate new tensor for type cast
  317. DataType new_type = DataType("float16");
  318. RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type));
  319. RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  320. auto in_itr = input->begin<float>();
  321. auto out_itr = (*output)->begin<float16>();
  322. auto out_end = (*output)->end<float16>();
  323. for (; out_itr != out_end; in_itr++, out_itr++) {
  324. float element = *in_itr;
  325. float float16_max = static_cast<float>(std::numeric_limits<Eigen::half>::max());
  326. float float16_min = static_cast<float>(std::numeric_limits<Eigen::half>::lowest());
  327. if (element > float16_max || element < float16_min) {
  328. RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" +
  329. std::to_string(float16_max) + ", " + std::to_string(float16_min) + "].");
  330. }
  331. *out_itr = Eigen::half(*in_itr);
  332. }
  333. return Status::OK();
  334. }
  335. Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
  336. const std::shared_ptr<Tensor> &pad_val) {
  337. if (pad_val == nullptr) {
  338. if (src->type().IsNumeric()) {
  339. return PadEndNumeric(src, dst, pad_shape, 0);
  340. } else {
  341. return PadEndString(src, dst, pad_shape, "");
  342. }
  343. }
  344. CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
  345. "Source and pad_value tensors are not of the same type.");
  346. if (pad_val->type().IsNumeric()) {
  347. std::shared_ptr<Tensor> float_pad_value;
  348. RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
  349. float val = 0;
  350. RETURN_IF_NOT_OK(float_pad_value->GetItemAt<float>(&val, {}));
  351. return PadEndNumeric(src, dst, pad_shape, val);
  352. }
  353. std::string_view val;
  354. RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
  355. return PadEndString(src, dst, pad_shape, std::string(val));
  356. }
  357. Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
  358. const std::vector<dsize_t> &pad_shape, float pad_val) {
  359. CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
  360. if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
  361. (*dst) = src; // if no padding, copy the pointer
  362. } else {
  363. CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
  364. RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
  365. auto tensor_type = src->type().value();
  366. if (pad_val == 0) { // if pad with zero, don't care what type it is
  367. RETURN_IF_NOT_OK((*dst)->Zero());
  368. } else if (tensor_type == DataType::DE_INT8) {
  369. RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
  370. } else if (tensor_type == DataType::DE_BOOL) {
  371. RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
  372. } else if (tensor_type == DataType::DE_UINT8) {
  373. RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
  374. } else if (tensor_type == DataType::DE_INT16) {
  375. RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
  376. } else if (tensor_type == DataType::DE_FLOAT16) {
  377. RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
  378. } else if (tensor_type == DataType::DE_UINT16) {
  379. RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
  380. } else if (tensor_type == DataType::DE_INT32) {
  381. RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
  382. } else if (tensor_type == DataType::DE_UINT32) {
  383. RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
  384. } else if (tensor_type == DataType::DE_INT64) {
  385. RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
  386. } else if (tensor_type == DataType::DE_UINT64) {
  387. RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
  388. } else if (tensor_type == DataType::DE_FLOAT32) {
  389. RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
  390. } else if (tensor_type == DataType::DE_FLOAT64) {
  391. RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
  392. } else {
  393. RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
  394. }
  395. std::vector<dsize_t> cur_ind(src->Rank(), 0);
  396. RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0));
  397. }
  398. return Status::OK();
  399. }
  400. Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
  401. std::vector<dsize_t> cur_ind, size_t cur_dim) {
  402. if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
  403. dst->CopyLastDimAt(src, cur_ind);
  404. } else { // not the last dimension, keep doing recursion
  405. dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
  406. for (dsize_t i = 0; i < min_ind; i++) {
  407. cur_ind[cur_dim] = i;
  408. RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1));
  409. }
  410. }
  411. return Status::OK();
  412. }
  413. Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
  414. const std::vector<dsize_t> &pad_shape, const std::string &pad_val) {
  415. CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
  416. if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
  417. (*dst) = src; // if no padding, copy the pointer
  418. } else {
  419. CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
  420. std::vector<dsize_t> cur_ind(src->Rank(), 0);
  421. std::vector<std::string> strings;
  422. RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val));
  423. RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape)));
  424. }
  425. return Status::OK();
  426. }
// Recursively walks `src` in row-major order and appends the padded string
// contents into the flat vector `dst`, which the caller turns into a tensor
// of shape `dst_shape`. Positions beyond src's extent are filled with
// `pad_value`. `cur_ind` is the multi-dimensional index being visited
// (passed by value so each recursion level owns its copy); `cur_dim` is the
// dimension currently being expanded.
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
                          const std::string &pad_value) {
  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
    // Copy the elements that exist in the source row...
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      std::string_view item;
      RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind));
      dst->emplace_back(item);
    }
    // ...then pad the remainder of this row.
    for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) {
      dst->emplace_back(pad_value);
    }
  } else {  // not the last dimension, keep doing recursion
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value));
    }
    // Trailing indices along this dimension have no source data at all:
    // emit one full sub-block of pad_value per missing index, where the
    // sub-block size is the destination stride at this dimension.
    dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim];
    for (dsize_t i = 0; i < count; i++) {
      dst->emplace_back(pad_value);
    }
  }
  return Status::OK();
}
  454. template <typename T>
  455. Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
  456. const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
  457. T value;
  458. RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
  459. auto in_itr = input->begin<T>();
  460. auto out_itr = output->begin<bool>();
  461. for (; in_itr != input->end<T>(); in_itr++, out_itr++) {
  462. switch (op) {
  463. case RelationalOp::kEqual:
  464. *out_itr = (*in_itr == value);
  465. break;
  466. case RelationalOp::kNotEqual:
  467. *out_itr = (*in_itr != value);
  468. break;
  469. case RelationalOp::kGreater:
  470. *out_itr = (*in_itr > value);
  471. break;
  472. case RelationalOp::kGreaterEqual:
  473. *out_itr = (*in_itr >= value);
  474. break;
  475. case RelationalOp::kLess:
  476. *out_itr = (*in_itr < value);
  477. break;
  478. case RelationalOp::kLessEqual:
  479. *out_itr = (*in_itr <= value);
  480. break;
  481. default:
  482. RETURN_STATUS_UNEXPECTED("Unknown relational operator.");
  483. }
  484. }
  485. return Status::OK();
  486. }
// Computes an element-wise boolean mask: output[i] = (input[i] <op> value).
// The scalar `value` is cast to the input's element type first (numeric
// inputs only); string inputs are compared directly as string_views.
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
            RelationalOp op) {
  // Numeric inputs require a numeric constant, string inputs a string one.
  CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
                               "Cannot convert constant value to the type of the input tensor.");
  CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar");
  // The output is always a boolean tensor of the input's shape.
  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL)));
  std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type()));
  std::shared_ptr<Tensor> casted_value;
  if (input->type().IsNumeric()) {
    RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
  } else {
    // Strings are compared without casting.
    casted_value = value;
  }
  // Dispatch to the typed helper matching the input's element type.
  switch (input->type().value()) {
    case DataType::DE_BOOL:
      RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT8:
      RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT8:
      RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT16:
      RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT16:
      RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT32:
      RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT32:
      RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT64:
      RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT64:
      RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT16:
      RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT32:
      RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT64:
      RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
      break;
    case DataType::DE_STRING:
      RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
      break;
    case DataType::DE_UNKNOWN:
      RETURN_STATUS_UNEXPECTED("Unsupported input type.");
      break;
  }
  return Status::OK();
}
  546. Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
  547. std::shared_ptr<Tensor> append) {
  548. CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
  549. CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
  550. axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
  551. CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
  552. std::shared_ptr<Tensor> out;
  553. if (prepend != nullptr) {
  554. CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
  555. RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0]));
  556. } else {
  557. out = input[0];
  558. }
  559. for (dsize_t i = 1; i < input.size(); i++) {
  560. std::shared_ptr<Tensor> out_t;
  561. CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
  562. RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i]));
  563. out = out_t;
  564. }
  565. std::shared_ptr<Tensor> out_t;
  566. if (append != nullptr) {
  567. CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
  568. RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append));
  569. } else {
  570. out_t = out;
  571. }
  572. output->push_back(out_t);
  573. return Status::OK();
  574. }
  575. Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
  576. std::shared_ptr<Tensor> append) {
  577. CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match");
  578. TensorShape t({});
  579. for (dsize_t i = 0; i < input->shape().Rank(); i++) {
  580. if (i != axis) {
  581. t = t.AppendDim(input->shape()[i]);
  582. } else {
  583. dsize_t new_shape = input->shape()[i] + append->shape()[i];
  584. t = t.AppendDim(new_shape);
  585. }
  586. }
  587. std::shared_ptr<Tensor> out;
  588. if (input->type().IsNumeric()) {
  589. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type()));
  590. RETURN_IF_NOT_OK(out->Concatenate({0}, input));
  591. RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append));
  592. *output = out;
  593. } else {
  594. std::vector<std::string> strings;
  595. auto itr = input->begin<std::string_view>();
  596. for (; itr != input->end<std::string_view>(); itr++) {
  597. strings.emplace_back(*itr);
  598. }
  599. itr = append->begin<std::string_view>();
  600. for (; itr != append->end<std::string_view>(); itr++) {
  601. strings.emplace_back(*itr);
  602. }
  603. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t));
  604. *output = out;
  605. }
  606. return Status::OK();
  607. }
  608. } // namespace dataset
  609. } // namespace mindspore