You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_op.cc 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "dataset/kernels/tensor_op.h"
  17. #include <iostream>
  18. #include <memory>
  19. #include <mutex>
  20. #include <vector>
  21. namespace mindspore {
  22. namespace dataset {
  23. // Name: Compute()
  24. // Description: This Compute() take 1 Tensor and produce 1 Tensor.
  25. // The derived class should override this function otherwise error.
  26. Status TensorOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  27. IO_CHECK(input, output);
  28. if (!OneToOne()) {
  29. return Status(StatusCode::kUnexpectedError, "Wrong Compute() function is called. This is not 1-1 TensorOp.");
  30. } else {
  31. return Status(StatusCode::kUnexpectedError,
  32. "Is this TensorOp 1-1? If yes, please implement this Compute() in the derived class.");
  33. }
  34. }
  35. // Name: Compute()
  36. // Description: This Compute() take multiple Tensors from different columns and produce multiple Tensors too.
  37. // The derived class should override this function otherwise error.
  38. Status TensorOp::Compute(const TensorRow &input, TensorRow *output) {
  39. IO_CHECK_VECTOR(input, output);
  40. if (OneToOne()) {
  41. output->resize(1);
  42. return Compute(input[0], &(*output)[0]);
  43. }
  44. return Status(StatusCode::kUnexpectedError,
  45. "Is this TensorOp oneToOne? If no, please implement this Compute() in the derived class.");
  46. }
  47. void TensorOp::Print(std::ostream &out) const { out << "TensorOp" << std::endl; }
  48. Status TensorOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
  49. if (inputs.size() != NumInput())
  50. return Status(StatusCode::kUnexpectedError,
  51. "The size of the input argument vector does not match the number of inputs");
  52. outputs = inputs;
  53. return Status::OK();
  54. }
  55. Status TensorOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
  56. if (inputs.size() != NumInput())
  57. return Status(StatusCode::kUnexpectedError,
  58. "The size of the input argument vector does not match the number of inputs");
  59. outputs = inputs;
  60. return Status::OK();
  61. }
  62. } // namespace dataset
  63. } // namespace mindspore