You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.cc 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common.h"
  17. #include <algorithm>
  18. #include <fstream>
  19. #include <string>
  20. #include <vector>
  21. #include "minddata/dataset/core/client.h"
  22. #include "minddata/dataset/core/config_manager.h"
  23. #include "minddata/dataset/core/pybind_support.h"
  24. #include "minddata/dataset/core/tensor.h"
  25. #include "minddata/dataset/core/tensor_shape.h"
  26. #include "minddata/dataset/engine/datasetops/batch_op.h"
  27. #include "minddata/dataset/engine/datasetops/repeat_op.h"
  28. #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
  29. namespace UT {
  30. #ifdef __cplusplus
  31. #if __cplusplus
  32. extern "C" {
  33. #endif
  34. #endif
  35. void DatasetOpTesting::SetUp() {
  36. std::string install_home = "data/dataset";
  37. datasets_root_path_ = install_home;
  38. mindrecord_root_path_ = "data/mindrecord";
  39. }
  40. std::vector<mindspore::dataset::TensorShape> DatasetOpTesting::ToTensorShapeVec(
  41. const std::vector<std::vector<int64_t>> &v) {
  42. std::vector<mindspore::dataset::TensorShape> ret_v;
  43. std::transform(v.begin(), v.end(), std::back_inserter(ret_v),
  44. [](const auto &s) { return mindspore::dataset::TensorShape(s); });
  45. return ret_v;
  46. }
  47. std::vector<mindspore::dataset::DataType> DatasetOpTesting::ToDETypes(const std::vector<mindspore::DataType> &t) {
  48. std::vector<mindspore::dataset::DataType> ret_t;
  49. std::transform(t.begin(), t.end(), std::back_inserter(ret_t), [](const mindspore::DataType &t) {
  50. return mindspore::dataset::MSTypeToDEType(static_cast<mindspore::TypeId>(t));
  51. });
  52. return ret_t;
  53. }
  54. // Function to read a file into an MSTensor
  55. // Note: This provides the analogous support for DETensor's CreateFromFile.
  56. mindspore::MSTensor DatasetOpTesting::ReadFileToTensor(const std::string &file) {
  57. if (file.empty()) {
  58. MS_LOG(ERROR) << "Pointer file is nullptr; return an empty Tensor.";
  59. return mindspore::MSTensor();
  60. }
  61. std::ifstream ifs(file);
  62. if (!ifs.good()) {
  63. MS_LOG(ERROR) << "File: " << file << " does not exist; return an empty Tensor.";
  64. return mindspore::MSTensor();
  65. }
  66. if (!ifs.is_open()) {
  67. MS_LOG(ERROR) << "File: " << file << " open failed; return an empty Tensor.";
  68. return mindspore::MSTensor();
  69. }
  70. ifs.seekg(0, std::ios::end);
  71. size_t size = ifs.tellg();
  72. mindspore::MSTensor buf("file", mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
  73. ifs.seekg(0, std::ios::beg);
  74. ifs.read(reinterpret_cast<char *>(buf.MutableData()), size);
  75. ifs.close();
  76. return buf;
  77. }
  78. // Helper function to create a batch op
  79. std::shared_ptr<mindspore::dataset::BatchOp> DatasetOpTesting::Batch(int32_t batch_size, bool drop,
  80. mindspore::dataset::PadInfo pad_map) {
  81. /*
  82. std::shared_ptr<mindspore::dataset::ConfigManager> cfg = mindspore::dataset::GlobalContext::config_manager();
  83. int32_t num_workers = cfg->num_parallel_workers();
  84. int32_t op_connector_size = cfg->op_connector_size();
  85. std::vector<std::string> output_columns = {};
  86. std::vector<std::string> input_columns = {};
  87. mindspore::dataset::py::function batch_size_func;
  88. mindspore::dataset::py::function batch_map_func;
  89. bool pad = false;
  90. if (!pad_map.empty()) {
  91. pad = true;
  92. }
  93. std::shared_ptr<mindspore::dataset::BatchOp> op =
  94. std::make_shared<mindspore::dataset::BatchOp>(batch_size, drop, pad, op_connector_size, num_workers, input_columns,
  95. output_columns, batch_size_func, batch_map_func, pad_map); return op;
  96. */
  97. Status rc;
  98. std::shared_ptr<mindspore::dataset::BatchOp> op;
  99. rc = mindspore::dataset::BatchOp::Builder(batch_size).SetDrop(drop).SetPaddingMap(pad_map).Build(&op);
  100. EXPECT_TRUE(rc.IsOk());
  101. return std::move(op);
  102. }
  103. std::shared_ptr<mindspore::dataset::RepeatOp> DatasetOpTesting::Repeat(int repeat_cnt) {
  104. std::shared_ptr<mindspore::dataset::RepeatOp> op = std::make_shared<mindspore::dataset::RepeatOp>(repeat_cnt);
  105. return std::move(op);
  106. }
  107. std::shared_ptr<mindspore::dataset::TFReaderOp> DatasetOpTesting::TFReader(std::string file, int num_works) {
  108. std::shared_ptr<mindspore::dataset::ConfigManager> config_manager =
  109. mindspore::dataset::GlobalContext::config_manager();
  110. auto op_connector_size = config_manager->op_connector_size();
  111. auto worker_connector_size = config_manager->worker_connector_size();
  112. std::vector<std::string> columns_to_load = {};
  113. std::vector<std::string> files = {file};
  114. std::shared_ptr<mindspore::dataset::TFReaderOp> so = std::make_shared<mindspore::dataset::TFReaderOp>(
  115. num_works, worker_connector_size, 0, files, std::make_unique<mindspore::dataset::DataSchema>(), op_connector_size,
  116. columns_to_load, false, 1, 0, false);
  117. (void)so->Init();
  118. return std::move(so);
  119. }
  120. std::shared_ptr<mindspore::dataset::ExecutionTree> DatasetOpTesting::Build(
  121. std::vector<std::shared_ptr<mindspore::dataset::DatasetOp>> ops) {
  122. std::shared_ptr<mindspore::dataset::ExecutionTree> tree = std::make_shared<mindspore::dataset::ExecutionTree>();
  123. for (int i = 0; i < ops.size(); i++) {
  124. tree->AssociateNode(ops[i]);
  125. if (i > 0) {
  126. ops[i]->AddChild(std::move(ops[i - 1]));
  127. }
  128. if (i == ops.size() - 1) {
  129. tree->AssignRoot(ops[i]);
  130. }
  131. }
  132. return std::move(tree);
  133. }
  134. #ifdef __cplusplus
  135. #if __cplusplus
  136. }
  137. #endif
  138. #endif
  139. } // namespace UT