You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

data_utils.cc 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/data/data_utils.h"
  17. #include <algorithm>
  18. #include <limits>
  19. #include <string>
  20. #include <vector>
  21. #include "dataset/core/constants.h"
  22. #include "dataset/core/data_type.h"
  23. #include "dataset/core/pybind_support.h"
  24. #include "dataset/core/tensor.h"
  25. #include "dataset/core/tensor_shape.h"
  26. #include "dataset/kernels/data/type_cast_op.h"
  27. #include "dataset/util/status.h"
  28. namespace mindspore {
  29. namespace dataset {
  30. Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
  31. dsize_t num_classes, int64_t index) {
  32. uint64_t class_idx;
  33. if (input->Rank() == 0) {
  34. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {}));
  35. } else {
  36. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {index}));
  37. }
  38. if (class_idx >= static_cast<uint64_t>(num_classes)) {
  39. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  40. }
  41. if (input->type() == DataType::DE_UINT64) {
  42. RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  43. } else if (input->type() == DataType::DE_UINT32) {
  44. RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  45. } else if (input->type() == DataType::DE_UINT16) {
  46. RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  47. } else if (input->type() == DataType::DE_UINT8) {
  48. RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  49. } else {
  50. RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input.");
  51. }
  52. return Status::OK();
  53. }
  54. Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
  55. int64_t index) {
  56. int64_t class_idx;
  57. if (input->Rank() == 0) {
  58. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {}));
  59. } else {
  60. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {index}));
  61. }
  62. if (class_idx >= static_cast<int64_t>(num_classes)) {
  63. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  64. }
  65. if (input->type() == DataType::DE_INT64) {
  66. RETURN_IF_NOT_OK((*output)->SetItemAt<int64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  67. } else if (input->type() == DataType::DE_INT32) {
  68. RETURN_IF_NOT_OK((*output)->SetItemAt<int32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  69. } else if (input->type() == DataType::DE_INT16) {
  70. RETURN_IF_NOT_OK((*output)->SetItemAt<int16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  71. } else if (input->type() == DataType::DE_INT8) {
  72. RETURN_IF_NOT_OK((*output)->SetItemAt<int8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  73. } else {
  74. RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input.");
  75. }
  76. return Status::OK();
  77. }
// One-hot encodes a scalar or 1-D integer tensor into a (num_elements, num_classes)
// tensor of the same integer type, then squeezes size-1 dimensions from the result.
// Returns an error status for rank > 1, non-integer input, or out-of-range indices.
Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
  input->Squeeze();
  if (input->Rank() > 1) {  // We expect the input to be in the first dimension
    RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
  }
  if (!input->type().IsInt()) {
    RETURN_STATUS_UNEXPECTED("One hot does not support input of this type.");
  }
  try {
    // A scalar encodes to one row; a 1-D tensor encodes one row per element.
    dsize_t num_elements = 1;
    if (input->Rank() == 1) num_elements = input->shape()[0];
    TensorShape out_shape({num_elements, num_classes});
    std::shared_ptr<Tensor> out;
    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type()));
    // Zero-fill first; the per-row helpers only set the single hot position.
    RETURN_IF_NOT_OK(out->Zero());
    for (dsize_t i = 0; i < num_elements; ++i) {
      if (input->type().IsUnsignedInt()) {
        RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i));
      } else {
        RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i));
      }
    }
    // Collapse the leading dimension back out for scalar inputs.
    out->Squeeze();
    *output = out;
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp");
  }
}
  107. Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
  108. const DataType &fill_type = fill_value->type();
  109. const DataType &input_type = input->type();
  110. const TensorShape &input_shape = input->shape();
  111. CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)),
  112. "Types do not match");
  113. CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
  114. std::shared_ptr<Tensor> out, fill_output;
  115. if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
  116. std::unique_ptr<TypeCastOp> op(new TypeCastOp(input_type));
  117. RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
  118. } else {
  119. fill_output = fill_value;
  120. }
  121. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type));
  122. switch (input_type.value()) {
  123. case DataType::DE_BOOL: {
  124. bool value = 0;
  125. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  126. out->Fill<bool>(value);
  127. break;
  128. }
  129. case DataType::DE_INT8: {
  130. int8_t value = 0;
  131. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  132. out->Fill<int8_t>(value);
  133. break;
  134. }
  135. case DataType::DE_UINT8: {
  136. uint8_t value = 0;
  137. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  138. out->Fill<uint8_t>(value);
  139. break;
  140. }
  141. case DataType::DE_UINT16: {
  142. uint16_t value = 0;
  143. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  144. out->Fill<uint16_t>(value);
  145. break;
  146. }
  147. case DataType::DE_INT16: {
  148. int16_t value = 0;
  149. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  150. out->Fill<int16_t>(value);
  151. break;
  152. }
  153. case DataType::DE_UINT32: {
  154. uint32_t value = 0;
  155. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  156. out->Fill<uint32_t>(value);
  157. break;
  158. }
  159. case DataType::DE_INT32: {
  160. int32_t value = 0;
  161. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  162. out->Fill<int32_t>(value);
  163. break;
  164. }
  165. case DataType::DE_UINT64: {
  166. uint64_t value = 0;
  167. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  168. out->Fill<uint64_t>(value);
  169. break;
  170. }
  171. case DataType::DE_INT64: {
  172. int64_t value = 0;
  173. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  174. out->Fill<int64_t>(value);
  175. break;
  176. }
  177. case DataType::DE_FLOAT16: {
  178. int64_t value = 0;
  179. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  180. out->Fill<float>(value);
  181. break;
  182. }
  183. case DataType::DE_FLOAT32: {
  184. float value = 0;
  185. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  186. out->Fill<float>(value);
  187. break;
  188. }
  189. case DataType::DE_FLOAT64: {
  190. double value = 0;
  191. RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
  192. out->Fill<double>(value);
  193. break;
  194. }
  195. case DataType::DE_STRING: {
  196. std::vector<std::string> strings;
  197. std::string_view fill_string_view;
  198. RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
  199. std::string fill_string = std::string(fill_string_view);
  200. for (int i = 0; i < input_shape.NumOfElements(); i++) {
  201. strings.emplace_back(fill_string);
  202. }
  203. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape));
  204. break;
  205. }
  206. case DataType::DE_UNKNOWN: {
  207. RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
  208. break;
  209. }
  210. }
  211. *output = out;
  212. return Status::OK();
  213. }
  214. template <typename FROM, typename TO>
  215. void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  216. auto in_itr = input->begin<FROM>();
  217. auto out_itr = (*output)->begin<TO>();
  218. auto out_end = (*output)->end<TO>();
  219. for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
  220. *out_itr = static_cast<TO>(*in_itr);
  221. }
// Second half of the double dispatch used by TypeCast: the source element type T
// is fixed by the caller; this switches on the destination tensor's type and
// invokes the concrete Cast<T, TO> copy.
template <typename T>
void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  switch ((*output)->type().value()) {
    case DataType::DE_BOOL:
      Cast<T, bool>(input, output);
      break;
    case DataType::DE_INT8:
      Cast<T, int8_t>(input, output);
      break;
    case DataType::DE_UINT8:
      Cast<T, uint8_t>(input, output);
      break;
    case DataType::DE_INT16:
      Cast<T, int16_t>(input, output);
      break;
    case DataType::DE_UINT16:
      Cast<T, uint16_t>(input, output);
      break;
    case DataType::DE_INT32:
      Cast<T, int32_t>(input, output);
      break;
    case DataType::DE_UINT32:
      Cast<T, uint32_t>(input, output);
      break;
    case DataType::DE_INT64:
      Cast<T, int64_t>(input, output);
      break;
    case DataType::DE_UINT64:
      Cast<T, uint64_t>(input, output);
      break;
    case DataType::DE_FLOAT16:
      Cast<T, float16>(input, output);
      break;
    case DataType::DE_FLOAT32:
      Cast<T, float>(input, output);
      break;
    case DataType::DE_FLOAT64:
      Cast<T, double>(input, output);
      break;
    case DataType::DE_UNKNOWN:
      // No destination type to cast into; log the error instead of returning a status
      // (this helper is void).
      MS_LOG(ERROR) << "Unknown data type.";
      break;
  }
}
  266. // Type cast operator
  267. Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
  268. RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type));
  269. RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  270. switch (input->type().value()) {
  271. case DataType::DE_BOOL:
  272. CastFrom<bool>(input, output);
  273. break;
  274. case DataType::DE_INT8:
  275. CastFrom<int8_t>(input, output);
  276. break;
  277. case DataType::DE_UINT8:
  278. CastFrom<uint8_t>(input, output);
  279. break;
  280. case DataType::DE_INT16:
  281. CastFrom<int16_t>(input, output);
  282. break;
  283. case DataType::DE_UINT16:
  284. CastFrom<uint16_t>(input, output);
  285. break;
  286. case DataType::DE_INT32:
  287. CastFrom<int32_t>(input, output);
  288. break;
  289. case DataType::DE_UINT32:
  290. CastFrom<uint32_t>(input, output);
  291. break;
  292. case DataType::DE_INT64:
  293. CastFrom<int64_t>(input, output);
  294. break;
  295. case DataType::DE_UINT64:
  296. CastFrom<uint64_t>(input, output);
  297. break;
  298. case DataType::DE_FLOAT16:
  299. CastFrom<float16>(input, output);
  300. break;
  301. case DataType::DE_FLOAT32:
  302. CastFrom<float>(input, output);
  303. break;
  304. case DataType::DE_FLOAT64:
  305. CastFrom<double>(input, output);
  306. break;
  307. case DataType::DE_UNKNOWN:
  308. // sanity check, unreachable code.
  309. RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type.");
  310. }
  311. return Status::OK();
  312. }
  313. Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  314. // initiate new tensor for type cast
  315. DataType new_type = DataType("float16");
  316. RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type));
  317. RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  318. auto in_itr = input->begin<float>();
  319. auto out_itr = (*output)->begin<float16>();
  320. auto out_end = (*output)->end<float16>();
  321. for (; out_itr != out_end; in_itr++, out_itr++) {
  322. float element = *in_itr;
  323. float float16_max = static_cast<float>(std::numeric_limits<Eigen::half>::max());
  324. float float16_min = static_cast<float>(std::numeric_limits<Eigen::half>::lowest());
  325. if (element > float16_max || element < float16_min) {
  326. RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" +
  327. std::to_string(float16_max) + ", " + std::to_string(float16_min) + "].");
  328. }
  329. *out_itr = Eigen::half(*in_itr);
  330. }
  331. return Status::OK();
  332. }
// Pads `src` at the end of each dimension up to `pad_shape`, filling new cells with
// `pad_val`. A null pad_val defaults to 0 for numeric tensors and "" for string
// tensors. Dispatches to the numeric or string implementation.
Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
              const std::shared_ptr<Tensor> &pad_val) {
  if (pad_val == nullptr) {
    if (src->type().IsNumeric()) {
      return PadEndNumeric(src, dst, pad_shape, 0);
    } else {
      return PadEndString(src, dst, pad_shape, "");
    }
  }
  CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
                               "Source and pad_value tensors are not of the same type.");
  if (pad_val->type().IsNumeric()) {
    // Normalize the numeric pad value through float32, matching PadEndNumeric's
    // float parameter.
    std::shared_ptr<Tensor> float_pad_value;
    RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
    float val = 0;
    RETURN_IF_NOT_OK(float_pad_value->GetItemAt<float>(&val, {}));
    return PadEndNumeric(src, dst, pad_shape, val);
  }
  std::string_view val;
  RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
  return PadEndString(src, dst, pad_shape, std::string(val));
}
// Numeric implementation of PadEnd: allocates a tensor of `pad_shape`, fills it
// entirely with `pad_val` (converted to the source element type), then copies the
// source data into the leading region via PadEndNumericHelper.
Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
                     const std::vector<dsize_t> &pad_shape, float pad_val) {
  CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
  if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
    (*dst) = src;  // if no padding, copy the pointer
  } else {
    CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
    RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
    auto tensor_type = src->type().value();
    if (pad_val == 0) {  // if pad with zero, don't care what type it is
      RETURN_IF_NOT_OK((*dst)->Zero());
    } else if (tensor_type == DataType::DE_INT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
    } else if (tensor_type == DataType::DE_BOOL) {
      RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
    } else if (tensor_type == DataType::DE_UINT8) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<float16>(static_cast<float16>(pad_val)));
    } else if (tensor_type == DataType::DE_UINT16) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
    } else if (tensor_type == DataType::DE_UINT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
    } else if (tensor_type == DataType::DE_INT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
    } else if (tensor_type == DataType::DE_UINT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT32) {
      RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
    } else if (tensor_type == DataType::DE_FLOAT64) {
      RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
    } else {
      RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
    }
    // Overwrite the leading region of the padded tensor with the source data.
    std::vector<dsize_t> cur_ind(src->Rank(), 0);
    RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0));
  }
  return Status::OK();
}
  398. Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
  399. std::vector<dsize_t> cur_ind, size_t cur_dim) {
  400. if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data
  401. dst->CopyLastDimAt(src, cur_ind);
  402. } else { // not the last dimension, keep doing recursion
  403. dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
  404. for (dsize_t i = 0; i < min_ind; i++) {
  405. cur_ind[cur_dim] = i;
  406. RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1));
  407. }
  408. }
  409. return Status::OK();
  410. }
  411. Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
  412. const std::vector<dsize_t> &pad_shape, const std::string &pad_val) {
  413. CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
  414. if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
  415. (*dst) = src; // if no padding, copy the pointer
  416. } else {
  417. CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
  418. std::vector<dsize_t> cur_ind(src->Rank(), 0);
  419. std::vector<std::string> strings;
  420. RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val));
  421. RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape)));
  422. }
  423. return Status::OK();
  424. }
// Recursively emits the padded string tensor in row-major order into `dst`:
// source elements within the overlapping extent, then `pad_value` for every
// cell of each dimension's padded tail.
Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
                          const std::string &pad_value) {
  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      std::string_view item;
      RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind));
      dst->emplace_back(item);
    }
    // Fill the remainder of this row with the pad value.
    for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) {
      dst->emplace_back(pad_value);
    }
  } else {  // not the last dimension, keep doing recursion
    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
    for (dsize_t i = 0; i < min_ind; i++) {
      cur_ind[cur_dim] = i;
      RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value));
    }
    // Whole padded slices beyond the source extent: the stride gives the number of
    // cells contained in each remaining slice of this dimension.
    dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim];
    for (dsize_t i = 0; i < count; i++) {
      dst->emplace_back(pad_value);
    }
  }
  return Status::OK();
}
// Compares every element of `input` against the scalar in `value_tensor` with the
// relational operator `op`, writing the boolean result element-wise into `output`.
// `output` must already be allocated as a bool tensor of the same element count.
template <typename T>
Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
                  const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
  T value;
  RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
  auto in_itr = input->begin<T>();
  auto out_itr = output->begin<bool>();
  for (; in_itr != input->end<T>(); in_itr++, out_itr++) {
    switch (op) {
      case RelationalOp::kEqual:
        *out_itr = (*in_itr == value);
        break;
      case RelationalOp::kNotEqual:
        *out_itr = (*in_itr != value);
        break;
      case RelationalOp::kGreater:
        *out_itr = (*in_itr > value);
        break;
      case RelationalOp::kGreaterEqual:
        *out_itr = (*in_itr >= value);
        break;
      case RelationalOp::kLess:
        *out_itr = (*in_itr < value);
        break;
      case RelationalOp::kLessEqual:
        *out_itr = (*in_itr <= value);
        break;
      default:
        RETURN_STATUS_UNEXPECTED("Unknown relational operator.");
    }
  }
  return Status::OK();
}
// Produces a bool tensor of the input's shape where each element is the result of
// comparing the corresponding input element with the scalar `value` under `op`.
// Numeric values are first cast to the input's type; strings are compared as-is.
Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &value,
            RelationalOp op) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
                               "Cannot convert constant value to the type of the input tensor.");
  CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar");
  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL)));
  // Cast the comparison value to the input's type so MaskHelper<T> reads it as T.
  std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type()));
  std::shared_ptr<Tensor> casted_value;
  if (input->type().IsNumeric()) {
    RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
  } else {
    casted_value = value;
  }
  switch (input->type().value()) {
    case DataType::DE_BOOL:
      RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT8:
      RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT8:
      RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT16:
      RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT16:
      RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT32:
      RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT32:
      RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_UINT64:
      RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_INT64:
      RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT16:
      RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT32:
      RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
      break;
    case DataType::DE_FLOAT64:
      RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
      break;
    case DataType::DE_STRING:
      RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
      break;
    case DataType::DE_UNKNOWN:
      RETURN_STATUS_UNEXPECTED("Unsupported input type.");
      break;
  }
  return Status::OK();
}
// Concatenates all 1-D tensors of `input` along axis 0, optionally with `prepend`
// placed first and `append` placed last, and pushes the single result tensor onto
// `output`. Only 1-D tensors and axis 0 (or -1, which maps to 0) are supported.
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
                   std::shared_ptr<Tensor> append) {
  CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
  CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
  // Normalize a negative axis; for rank 1 this maps -1 to 0.
  axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
  CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
  std::shared_ptr<Tensor> out;
  if (prepend != nullptr) {
    CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
    RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0]));
  } else {
    out = input[0];
  }
  // Fold the remaining row tensors into the running result one at a time.
  for (dsize_t i = 1; i < input.size(); i++) {
    std::shared_ptr<Tensor> out_t;
    CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
    RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i]));
    out = out_t;
  }
  std::shared_ptr<Tensor> out_t;
  if (append != nullptr) {
    CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
    RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append));
  } else {
    out_t = out;
  }
  output->push_back(out_t);
  return Status::OK();
}
// Concatenates `append` onto `input` along `axis`, producing a new tensor in
// `output`. Numeric tensors are copied buffer-to-buffer via Tensor::Concatenate;
// string tensors are rebuilt from a flat vector of their elements.
Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
                         std::shared_ptr<Tensor> append) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match");
  // Build the output shape: dimensions match the input except along `axis`,
  // where the two extents are summed.
  TensorShape t({});
  for (dsize_t i = 0; i < input->shape().Rank(); i++) {
    if (i != axis) {
      t = t.AppendDim(input->shape()[i]);
    } else {
      dsize_t new_shape = input->shape()[i] + append->shape()[i];
      t = t.AppendDim(new_shape);
    }
  }
  std::shared_ptr<Tensor> out;
  if (input->type().IsNumeric()) {
    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type()));
    // Copy input at offset 0, then append starting right after input's extent.
    RETURN_IF_NOT_OK(out->Concatenate({0}, input));
    RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append));
    *output = out;
  } else {
    // String tensors: collect all elements in order and rebuild.
    std::vector<std::string> strings;
    auto itr = input->begin<std::string_view>();
    for (; itr != input->end<std::string_view>(); itr++) {
      strings.emplace_back(*itr);
    }
    itr = append->begin<std::string_view>();
    for (; itr != append->end<std::string_view>(); itr++) {
      strings.emplace_back(*itr);
    }
    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t));
    *output = out;
  }
  return Status::OK();
}
  606. } // namespace dataset
  607. } // namespace mindspore