You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.cc 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "common.h"
#include <algorithm>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>
  21. namespace UT {
  22. #ifdef __cplusplus
  23. #if __cplusplus
  24. extern "C" {
  25. #endif
  26. #endif
  27. void DatasetOpTesting::SetUp() {
  28. std::string install_home = "data/dataset";
  29. datasets_root_path_ = install_home;
  30. mindrecord_root_path_ = "data/mindrecord";
  31. }
  32. std::vector<mindspore::dataset::TensorShape> DatasetOpTesting::ToTensorShapeVec(
  33. const std::vector<std::vector<int64_t>> &v) {
  34. std::vector<mindspore::dataset::TensorShape> ret_v;
  35. std::transform(v.begin(), v.end(), std::back_inserter(ret_v),
  36. [](const auto &s) { return mindspore::dataset::TensorShape(s); });
  37. return ret_v;
  38. }
  39. std::vector<mindspore::dataset::DataType> DatasetOpTesting::ToDETypes(const std::vector<mindspore::DataType> &t) {
  40. std::vector<mindspore::dataset::DataType> ret_t;
  41. std::transform(t.begin(), t.end(), std::back_inserter(ret_t), [](const mindspore::DataType &t) {
  42. return mindspore::dataset::MSTypeToDEType(static_cast<mindspore::TypeId>(t));
  43. });
  44. return ret_t;
  45. }
  46. // Function to read a file into an MSTensor
  47. // Note: This provides the analogous support for DETensor's CreateFromFile.
  48. mindspore::MSTensor DatasetOpTesting::ReadFileToTensor(const std::string &file) {
  49. if (file.empty()) {
  50. MS_LOG(ERROR) << "Pointer file is nullptr; return an empty Tensor.";
  51. return mindspore::MSTensor();
  52. }
  53. std::ifstream ifs(file);
  54. if (!ifs.good()) {
  55. MS_LOG(ERROR) << "File: " << file << " does not exist; return an empty Tensor.";
  56. return mindspore::MSTensor();
  57. }
  58. if (!ifs.is_open()) {
  59. MS_LOG(ERROR) << "File: " << file << " open failed; return an empty Tensor.";
  60. return mindspore::MSTensor();
  61. }
  62. ifs.seekg(0, std::ios::end);
  63. size_t size = ifs.tellg();
  64. mindspore::MSTensor buf("file", mindspore::DataType::kNumberTypeUInt8, {static_cast<int64_t>(size)}, nullptr, size);
  65. ifs.seekg(0, std::ios::beg);
  66. ifs.read(reinterpret_cast<char *>(buf.MutableData()), size);
  67. ifs.close();
  68. return buf;
  69. }
  70. #ifdef __cplusplus
  71. #if __cplusplus
  72. }
  73. #endif
  74. #endif
  75. } // namespace UT