You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

data_utils.cc 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/data/data_utils.h"
  17. #include <vector>
  18. #include "dataset/core/constants.h"
  19. #include "dataset/core/tensor.h"
  20. #include "dataset/core/tensor_shape.h"
  21. #include "dataset/core/data_type.h"
  22. #include "dataset/core/pybind_support.h"
  23. namespace mindspore {
  24. namespace dataset {
  25. Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
  26. dsize_t num_classes, int64_t index) {
  27. uint64_t class_idx;
  28. if (input->Rank() == 0) {
  29. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {}));
  30. } else {
  31. RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {index}));
  32. }
  33. if (class_idx >= static_cast<uint64_t>(num_classes)) {
  34. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  35. }
  36. if (input->type() == DataType::DE_UINT64) {
  37. RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  38. } else if (input->type() == DataType::DE_UINT32) {
  39. RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  40. } else if (input->type() == DataType::DE_UINT16) {
  41. RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  42. } else if (input->type() == DataType::DE_UINT8) {
  43. RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  44. } else {
  45. RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input.");
  46. }
  47. return Status::OK();
  48. }
  49. Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
  50. int64_t index) {
  51. int64_t class_idx;
  52. if (input->Rank() == 0) {
  53. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {}));
  54. } else {
  55. RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {index}));
  56. }
  57. if (class_idx >= static_cast<int64_t>(num_classes)) {
  58. RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
  59. }
  60. if (input->type() == DataType::DE_INT64) {
  61. RETURN_IF_NOT_OK((*output)->SetItemAt<int64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  62. } else if (input->type() == DataType::DE_INT32) {
  63. RETURN_IF_NOT_OK((*output)->SetItemAt<int32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  64. } else if (input->type() == DataType::DE_INT16) {
  65. RETURN_IF_NOT_OK((*output)->SetItemAt<int16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  66. } else if (input->type() == DataType::DE_INT8) {
  67. RETURN_IF_NOT_OK((*output)->SetItemAt<int8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
  68. } else {
  69. RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input.");
  70. }
  71. return Status::OK();
  72. }
  73. Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
  74. input->Squeeze();
  75. if (input->Rank() > 1) { // We expect the input to be int he first dimension
  76. RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
  77. }
  78. if (!input->type().IsInt()) {
  79. RETURN_STATUS_UNEXPECTED("One hot does not support input of this type.");
  80. }
  81. try {
  82. dsize_t num_elements = 1;
  83. if (input->Rank() == 1) num_elements = input->shape()[0];
  84. TensorShape out_shape({num_elements, num_classes});
  85. std::shared_ptr<Tensor> out;
  86. RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type()));
  87. RETURN_IF_NOT_OK(out->Zero());
  88. for (dsize_t i = 0; i < num_elements; ++i) {
  89. if (input->type().IsUnsignedInt()) {
  90. RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i));
  91. } else {
  92. RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i));
  93. }
  94. }
  95. out->Squeeze();
  96. *output = out;
  97. return Status::OK();
  98. } catch (const std::exception &e) {
  99. RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp");
  100. }
  101. }
  102. template <typename FROM, typename TO>
  103. void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  104. auto in_itr = input->begin<FROM>();
  105. auto out_itr = (*output)->begin<TO>();
  106. auto out_end = (*output)->end<TO>();
  107. for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
  108. *out_itr = static_cast<TO>(*in_itr);
  109. }
  110. template <typename T>
  111. void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  112. switch ((*output)->type().value()) {
  113. case DataType::DE_BOOL:
  114. Cast<T, bool>(input, output);
  115. break;
  116. case DataType::DE_INT8:
  117. Cast<T, int8_t>(input, output);
  118. break;
  119. case DataType::DE_UINT8:
  120. Cast<T, uint8_t>(input, output);
  121. break;
  122. case DataType::DE_INT16:
  123. Cast<T, int16_t>(input, output);
  124. break;
  125. case DataType::DE_UINT16:
  126. Cast<T, uint16_t>(input, output);
  127. break;
  128. case DataType::DE_INT32:
  129. Cast<T, int32_t>(input, output);
  130. break;
  131. case DataType::DE_UINT32:
  132. Cast<T, uint32_t>(input, output);
  133. break;
  134. case DataType::DE_INT64:
  135. Cast<T, int64_t>(input, output);
  136. break;
  137. case DataType::DE_UINT64:
  138. Cast<T, uint64_t>(input, output);
  139. break;
  140. case DataType::DE_FLOAT16:
  141. Cast<T, float16>(input, output);
  142. break;
  143. case DataType::DE_FLOAT32:
  144. Cast<T, float>(input, output);
  145. break;
  146. case DataType::DE_FLOAT64:
  147. Cast<T, double>(input, output);
  148. break;
  149. case DataType::DE_UNKNOWN:
  150. MS_LOG(ERROR) << "Unknown data type.";
  151. break;
  152. }
  153. }
  154. // Type cast operator
  155. Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
  156. RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type));
  157. RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  158. switch (input->type().value()) {
  159. case DataType::DE_BOOL:
  160. CastFrom<bool>(input, output);
  161. break;
  162. case DataType::DE_INT8:
  163. CastFrom<int8_t>(input, output);
  164. break;
  165. case DataType::DE_UINT8:
  166. CastFrom<uint8_t>(input, output);
  167. break;
  168. case DataType::DE_INT16:
  169. CastFrom<int16_t>(input, output);
  170. break;
  171. case DataType::DE_UINT16:
  172. CastFrom<uint16_t>(input, output);
  173. break;
  174. case DataType::DE_INT32:
  175. CastFrom<int32_t>(input, output);
  176. break;
  177. case DataType::DE_UINT32:
  178. CastFrom<uint32_t>(input, output);
  179. break;
  180. case DataType::DE_INT64:
  181. CastFrom<int64_t>(input, output);
  182. break;
  183. case DataType::DE_UINT64:
  184. CastFrom<uint64_t>(input, output);
  185. break;
  186. case DataType::DE_FLOAT16:
  187. CastFrom<float16>(input, output);
  188. break;
  189. case DataType::DE_FLOAT32:
  190. CastFrom<float>(input, output);
  191. break;
  192. case DataType::DE_FLOAT64:
  193. CastFrom<double>(input, output);
  194. break;
  195. case DataType::DE_UNKNOWN:
  196. // sanity check, unreachable code.
  197. RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type.");
  198. }
  199. return Status::OK();
  200. }
  201. Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  202. // initiate new tensor for type cast
  203. DataType new_type = DataType("float16");
  204. RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type));
  205. RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
  206. auto in_itr = input->begin<float>();
  207. auto out_itr = (*output)->begin<float16>();
  208. auto out_end = (*output)->end<float16>();
  209. for (; out_itr != out_end; in_itr++, out_itr++) *out_itr = Eigen::half(*in_itr);
  210. return Status::OK();
  211. }
  212. } // namespace dataset
  213. } // namespace mindspore