You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_op.h 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DATASET_KERNELS_TENSOR_OP_H_
  17. #define DATASET_KERNELS_TENSOR_OP_H_
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/util/status.h"
  23. #define IO_CHECK(input, output) \
  24. do { \
  25. if (input == nullptr || output == nullptr) { \
  26. RETURN_STATUS_UNEXPECTED("input or output is null."); \
  27. } \
  28. } while (false)
  29. #define IO_CHECK_VECTOR(input, output) \
  30. do { \
  31. if (output == nullptr) { \
  32. RETURN_STATUS_UNEXPECTED("output is null."); \
  33. } \
  34. for (auto &_i : input) { \
  35. if (_i == nullptr) { \
  36. RETURN_STATUS_UNEXPECTED("input is null."); \
  37. } \
  38. } \
  39. } while (false)
  40. namespace mindspore {
  41. namespace dataset {
  42. // A class that does a computation on a Tensor
  43. class TensorOp {
  44. public:
  45. TensorOp() = default;
  46. virtual ~TensorOp() = default;
  47. // A function that prints info about the tensor operation
  48. // @param out
  49. virtual void Print(std::ostream &out) const;
  50. // Provide stream operator for displaying it
  51. // @param output stream
  52. // @param so the TensorOp object to be printed
  53. // @return output stream
  54. friend std::ostream &operator<<(std::ostream &out, const TensorOp &so) {
  55. so.Print(out);
  56. return out;
  57. }
  58. // Perform an operation on one Tensor and produce one Tensor. This is for 1-to-1 column MapOp
  59. // @param input shares the ownership of the Tensor (increase the ref count).
  60. // @param output the address to a shared_ptr where the result will be placed.
  61. // @return Status
  62. virtual Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
  63. // Perform an operation on Tensors from multiple columns, and produce multiple Tensors.
  64. // This is for m-to-n column MapOp.
  65. // @param input is a vector of shared_ptr to Tensor (pass by const reference).
  66. // @param output is the address to an empty vector of shared_ptr to Tensor.
  67. // @return Status
  68. virtual Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
  69. std::vector<std::shared_ptr<Tensor>> *output);
  70. // Returns true oif the TensorOp takes one input and returns one output.
  71. // @return true/false
  72. bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; }
  73. // Function to determine the number of inputs the TensorOp can take. 0: means undefined.
  74. // @return uint32_t
  75. virtual uint32_t NumInput() { return 1; }
  76. // Function to determine the number of output the TensorOp generates. 0: means undefined.
  77. // @return uint32_t
  78. virtual uint32_t NumOutput() { return 1; }
  79. // Function to determine the shapes of the output tensor given the input tensors' shapes.
  80. // If a subclass did not override this function, it means that the shape does not change.
  81. // @param inputs in: vector of the shapes of the input tensors.
  82. // @param outputs out: vector of the shapes of the output tensors to be filled.
  83. // @return Status
  84. virtual Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs);
  85. // Function to determine the types of the output tensor given the input tensor's types.
  86. // If a subclass did not override this function, it means that the type does not change.
  87. // @param inputs in: vector of the types of the input tensors.
  88. // @param outputs out: vector of the types of the output tensors to be filled.
  89. // @return Status
  90. virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs);
  91. };
  92. } // namespace dataset
  93. } // namespace mindspore
  94. #endif // DATASET_KERNELS_TENSOR_OP_H_