You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_utils.h 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. #ifndef DATASET_KERNELS_DATA_DATA_UTILS_H_
  17. #define DATASET_KERNELS_DATA_DATA_UTILS_H_
  18. #include <memory>
  19. #include <vector>
  20. #include "dataset/core/constants.h"
  21. #include "dataset/core/cv_tensor.h"
  22. #include "dataset/core/data_type.h"
  23. #include "dataset/core/tensor.h"
  24. namespace mindspore {
  25. namespace dataset {
  26. // Returns Onehot encoding of the input tensor.
  27. // Example: if input=2 and numClasses=3, the output is [0 0 1].
  28. // @param input: Tensor has type DE_UINT64, the non-one hot values are stored
  29. // along the first dimensions or rows..
  30. // If the rank of input is not 1 or the type is not DE_UINT64,
  31. // then it will fail.
  32. // @param output: Tensor. The shape of the output tensor is <input_shape, numClasses>
  33. // and the type is same as input.
  34. // @param num_classes: Number of classes to.
  35. Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes);
  36. Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
  37. dsize_t num_classes, int64_t index);
  38. Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
  39. int64_t index);
  40. // Returns a type changed input tensor.
  41. // Example: if input tensor is float64, the output will the specified dataType. See DataTypes.cpp
  42. // @param input Tensor
  43. // @param output Tensor. The shape of the output tensor is same as input with the type changed.
  44. // @param data_type: type of data to cast data to
  45. // @note: this operation will do a memcpy and if the value is truncated then precision will be lost
  46. template <typename T>
  47. void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
  48. template <typename FROM, typename TO>
  49. void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
  50. Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
  51. Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type);
  52. } // namespace dataset
  53. } // namespace mindspore
  54. #endif // DATASET_KERNELS_DATA_DATA_UTILS_H_