You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

execute.cc 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "minddata/dataset/include/execute.h"
  17. #include "minddata/dataset/core/de_tensor.h"
  18. #include "minddata/dataset/core/tensor_row.h"
  19. #include "minddata/dataset/include/tensor.h"
  20. #include "minddata/dataset/include/type_id.h"
  21. #include "minddata/dataset/kernels/tensor_op.h"
  22. #ifndef ENABLE_ANDROID
  23. #include "utils/log_adapter.h"
  24. #else
  25. #include "mindspore/lite/src/common/log_adapter.h"
  26. #endif
  27. namespace mindspore {
  28. namespace dataset {
  29. Execute::Execute(std::shared_ptr<TensorOperation> op) { ops_.emplace_back(std::move(op)); }
  30. Execute::Execute(std::vector<std::shared_ptr<TensorOperation>> ops) : ops_(std::move(ops)) {}
  31. Status Execute::operator()(const mindspore::MSTensor &input, mindspore::MSTensor *output) {
  32. // Validate input tensor
  33. CHECK_FAIL_RETURN_UNEXPECTED(input.DataSize() > 0, "Input Tensor has no data");
  34. CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");
  35. // Validate and build runtime ops
  36. std::vector<std::shared_ptr<TensorOp>> transforms;
  37. for (int32_t i = 0; i < ops_.size(); i++) {
  38. CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
  39. RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
  40. transforms.emplace_back(ops_[i]->Build());
  41. }
  42. // Convert mindspore::Tensor to dataset::Tensor
  43. std::shared_ptr<dataset::Tensor> de_tensor;
  44. Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(input.Shape()),
  45. MSTypeToDEType(static_cast<TypeId>(input.DataType())),
  46. (const uchar *)(input.Data().get()), input.DataSize(), &de_tensor);
  47. RETURN_IF_NOT_OK(rc);
  48. // Apply transforms on tensor
  49. for (auto &t : transforms) {
  50. std::shared_ptr<dataset::Tensor> de_output;
  51. RETURN_IF_NOT_OK(t->Compute(de_tensor, &de_output));
  52. // For next transform
  53. de_tensor = std::move(de_output);
  54. }
  55. // Convert dataset::Tensor to mindspore::Tensor
  56. CHECK_FAIL_RETURN_UNEXPECTED(de_tensor->HasData(), "Apply transform failed, output tensor has no data");
  57. *output = mindspore::MSTensor(std::make_shared<DETensor>(de_tensor));
  58. return Status::OK();
  59. }
  60. Status Execute::operator()(const std::vector<MSTensor> &input_tensor_list, std::vector<MSTensor> *output_tensor_list) {
  61. // Validate input tensor
  62. CHECK_FAIL_RETURN_UNEXPECTED(!input_tensor_list.empty(), "Input Tensor is not valid");
  63. for (auto &tensor : input_tensor_list) {
  64. CHECK_FAIL_RETURN_UNEXPECTED(tensor.DataSize() > 0, "Input Tensor has no data");
  65. }
  66. CHECK_FAIL_RETURN_UNEXPECTED(!ops_.empty(), "Input TensorOperation should be provided");
  67. // Validate and build runtime ops
  68. std::vector<std::shared_ptr<TensorOp>> transforms;
  69. for (int32_t i = 0; i < ops_.size(); i++) {
  70. CHECK_FAIL_RETURN_UNEXPECTED(ops_[i] != nullptr, "Input TensorOperation[" + std::to_string(i) + "] is null");
  71. RETURN_IF_NOT_OK(ops_[i]->ValidateParams());
  72. transforms.emplace_back(ops_[i]->Build());
  73. }
  74. TensorRow de_tensor_list;
  75. for (auto &tensor : input_tensor_list) {
  76. std::shared_ptr<dataset::Tensor> de_tensor;
  77. Status rc = dataset::Tensor::CreateFromMemory(dataset::TensorShape(tensor.Shape()),
  78. MSTypeToDEType(static_cast<TypeId>(tensor.DataType())),
  79. (const uchar *)(tensor.Data().get()), tensor.DataSize(), &de_tensor);
  80. RETURN_IF_NOT_OK(rc);
  81. de_tensor_list.emplace_back(std::move(de_tensor));
  82. }
  83. // Apply transforms on tensor
  84. for (auto &t : transforms) {
  85. TensorRow de_output_list;
  86. RETURN_IF_NOT_OK(t->Compute(de_tensor_list, &de_output_list));
  87. // For next transform
  88. de_tensor_list = std::move(de_output_list);
  89. }
  90. for (auto &tensor : de_tensor_list) {
  91. CHECK_FAIL_RETURN_UNEXPECTED(tensor->HasData(), "Apply transform failed, output tensor has no data");
  92. auto ms_tensor = mindspore::MSTensor(std::make_shared<DETensor>(tensor));
  93. output_tensor_list->emplace_back(ms_tensor);
  94. }
  95. CHECK_FAIL_RETURN_UNEXPECTED(!output_tensor_list->empty(), "Output Tensor is not valid");
  96. return Status::OK();
  97. }
  98. } // namespace dataset
  99. } // namespace mindspore