You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

minddata_eager.cc 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <unistd.h>
  17. #include <unordered_map>
  18. #include "minddata/dataset/include/minddata_eager.h"
  19. #include "minddata/dataset/include/vision.h"
  20. #include "minddata/dataset/core/tensor.h"
  21. #include "minddata/dataset/kernels/tensor_op.h"
  22. #include "minddata/dataset/util/path.h"
  23. namespace mindspore {
  24. namespace api {
  25. MindDataEager::MindDataEager(std::vector<std::shared_ptr<dataset::TensorOperation>> ops) : ops_(ops) {}
  26. // Helper function to convert Type from DE to MS
  27. DataType ToMSType(dataset::DataType type) {
  28. switch (dataset::DataType::Type(type)) {
  29. case dataset::DataType::DE_BOOL:
  30. return DataType::kMsBool;
  31. case dataset::DataType::DE_UINT8:
  32. return DataType::kMsUint8;
  33. case dataset::DataType::DE_INT32:
  34. return DataType::kMsInt32;
  35. case dataset::DataType::DE_INT64:
  36. return DataType::kMsInt64;
  37. case dataset::DataType::DE_FLOAT32:
  38. return DataType::kMsFloat32;
  39. default:
  40. return DataType::kMsUnknown;
  41. }
  42. }
  43. // Helper function to convert Type from MS to DE
  44. dataset::DataType ToDEType(DataType type) {
  45. switch (type) {
  46. case DataType::kMsBool:
  47. return dataset::DataType(dataset::DataType::DE_BOOL);
  48. case DataType::kMsUint8:
  49. return dataset::DataType(dataset::DataType::DE_UINT8);
  50. case DataType::kMsInt32:
  51. return dataset::DataType(dataset::DataType::DE_INT32);
  52. case DataType::kMsInt64:
  53. return dataset::DataType(dataset::DataType::DE_INT64);
  54. case DataType::kMsFloat32:
  55. return dataset::DataType(dataset::DataType::DE_FLOAT32);
  56. default:
  57. return dataset::DataType(dataset::DataType::DE_UNKNOWN);
  58. }
  59. }
  60. Status MindDataEager::LoadImageFromDir(const std::string &image_dir, std::vector<std::shared_ptr<Tensor>> *images) {
  61. // Check target directory
  62. dataset::Path image_dir_(image_dir);
  63. if (!image_dir_.Exists() || !image_dir_.IsDirectory()) {
  64. std::string err_msg = "Target directory: " + image_dir + " does not exist or not a directory.";
  65. MS_LOG(ERROR) << err_msg;
  66. return Status(StatusCode::FAILED, err_msg);
  67. }
  68. if (access(image_dir_.toString().c_str(), R_OK) == -1) {
  69. std::string err_msg = "No access to target directory: " + image_dir;
  70. MS_LOG(ERROR) << err_msg;
  71. return Status(StatusCode::FAILED, err_msg);
  72. }
  73. // Start reading images and constructing tensors
  74. auto path_itr = dataset::Path::DirIterator::OpenDirectory(&image_dir_);
  75. while (path_itr->hasNext()) {
  76. dataset::Path file = path_itr->next();
  77. std::shared_ptr<dataset::Tensor> image;
  78. dataset::Tensor::CreateFromFile(file.toString(), &image);
  79. std::shared_ptr<Tensor> ms_image = std::make_shared<Tensor>("image", DataType(kMsUint8), image->shape().AsVector(),
  80. image->GetBuffer(), image->SizeInBytes());
  81. images->push_back(ms_image);
  82. }
  83. // Check if read images or not
  84. if (images->empty()) {
  85. std::string err_msg = "No images found in target directory: " + image_dir;
  86. MS_LOG(ERROR) << err_msg;
  87. return Status(StatusCode::FAILED, err_msg);
  88. }
  89. return Status(StatusCode::SUCCESS);
  90. }
  91. std::shared_ptr<Tensor> MindDataEager::operator()(std::shared_ptr<Tensor> input) {
  92. // Validate ops
  93. if (ops_.empty()) {
  94. MS_LOG(ERROR) << "Input TensorOperation should be provided";
  95. return nullptr;
  96. }
  97. for (int32_t i = 0; i < ops_.size(); i++) {
  98. if (ops_[i] == nullptr) {
  99. MS_LOG(ERROR) << "Input TensorOperation[" << i << "] is invalid or null";
  100. return nullptr;
  101. }
  102. }
  103. // Validate input tensor
  104. if (input == nullptr) {
  105. MS_LOG(ERROR) << "Input Tensor should not be null";
  106. return nullptr;
  107. }
  108. // Start applying transforms in ops
  109. std::shared_ptr<dataset::Tensor> de_input;
  110. dataset::Tensor::CreateFromMemory(dataset::TensorShape(input->Shape()), ToDEType(input->DataType()),
  111. (const uchar *)(input->Data()), &de_input);
  112. for (int32_t i = 0; i < ops_.size(); i++) {
  113. // Build runtime op and run
  114. std::shared_ptr<dataset::Tensor> de_output;
  115. std::shared_ptr<dataset::TensorOp> transform = ops_[i]->Build();
  116. dataset::Status rc = transform->Compute(de_input, &de_output);
  117. // check execution failed
  118. if (rc.IsError()) {
  119. MS_LOG(ERROR) << "Operation execution failed : " << rc.ToString();
  120. return nullptr;
  121. }
  122. // For next transform
  123. de_input = std::move(de_output);
  124. }
  125. // Convert DETensor to Tensor
  126. if (!de_input->HasData()) {
  127. MS_LOG(ERROR) << "Apply transform failed, output tensor has no data";
  128. return nullptr;
  129. }
  130. std::shared_ptr<Tensor> output =
  131. std::make_shared<Tensor>("transfomed", ToMSType(de_input->type()), de_input->shape().AsVector(),
  132. de_input->GetBuffer(), de_input->SizeInBytes());
  133. return output;
  134. }
  135. } // namespace api
  136. } // namespace mindspore