| @@ -146,9 +146,9 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, | |||||
| /// (Default = 0 means all samples.) | /// (Default = 0 means all samples.) | ||||
| /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal) | /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal) | ||||
| /// Can be any of: | /// Can be any of: | ||||
| /// ShuffleMode.kFalse - No shuffling is performed. | |||||
| /// ShuffleMode.kFiles - Shuffle files only. | |||||
| /// ShuffleMode.kGlobal - Shuffle both the files and samples. | |||||
| /// ShuffleMode::kFalse - No shuffling is performed. | |||||
| /// ShuffleMode::kFiles - Shuffle files only. | |||||
| /// ShuffleMode::kGlobal - Shuffle both the files and samples. | |||||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | /// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1) | ||||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | /// \param[in] shard_id The shard ID within num_shards. This argument should be | ||||
| /// specified only when num_shards is also specified. (Default = 0) | /// specified only when num_shards is also specified. (Default = 0) | ||||
| @@ -5,6 +5,7 @@ add_library(utils OBJECT | |||||
| buddy.cc | buddy.cc | ||||
| cache_pool.cc | cache_pool.cc | ||||
| circular_pool.cc | circular_pool.cc | ||||
| data_helper.cc | |||||
| memory_pool.cc | memory_pool.cc | ||||
| cond_var.cc | cond_var.cc | ||||
| intrp_service.cc | intrp_service.cc | ||||
| @@ -0,0 +1,142 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/util/data_helper.h" | |||||
| #include <algorithm> | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <sstream> | |||||
| #include <nlohmann/json.hpp> | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/core/tensor_shape.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "minddata/dataset/util/path.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| // Create a numbered json file from image folder | |||||
| Status DataHelper::CreateAlbum(const std::string &in_dir, const std::string &out_dir) { | |||||
| // in check | |||||
| Path base_dir = Path(in_dir); | |||||
| if (!base_dir.IsDirectory() || !base_dir.Exists()) { | |||||
| RETURN_STATUS_UNEXPECTED("Input dir is not a directory or doesn't exist"); | |||||
| } | |||||
| // check if output_dir exists and create it if it does not exist | |||||
| Path target_dir = Path(out_dir); | |||||
| RETURN_IF_NOT_OK(target_dir.CreateDirectory()); | |||||
| // iterate over in dir and create json for all images | |||||
| uint64_t index = 0; | |||||
| auto dir_it = Path::DirIterator::OpenDirectory(&base_dir); | |||||
| while (dir_it->hasNext()) { | |||||
| Path v = dir_it->next(); | |||||
| // check if found file fits image extension | |||||
| // create json file in output dir with the path | |||||
| std::string out_file = out_dir + "/" + std::to_string(index) + ".json"; | |||||
| UpdateValue(out_file, "image", v.toString(), out_file); | |||||
| index++; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // A print method typically used for debugging | |||||
| void DataHelper::Print(std::ostream &out) const { | |||||
| out << " Data Helper" | |||||
| << "\n"; | |||||
| } | |||||
| Status DataHelper::UpdateArray(const std::string &in_file, const std::string &key, | |||||
| const std::vector<std::string> &value, const std::string &out_file) { | |||||
| try { | |||||
| Path in = Path(in_file); | |||||
| nlohmann::json js; | |||||
| if (in.Exists()) { | |||||
| std::ifstream in_stream(in_file); | |||||
| MS_LOG(INFO) << "Filename: " << in_file << "."; | |||||
| in_stream >> js; | |||||
| in_stream.close(); | |||||
| } | |||||
| js[key] = value; | |||||
| MS_LOG(INFO) << "Write outfile is: " << js << "."; | |||||
| if (out_file == "") { | |||||
| std::ofstream o(in_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } else { | |||||
| std::ofstream o(out_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } | |||||
| } | |||||
| // Catch any exception and convert to Status return code | |||||
| catch (const std::exception &err) { | |||||
| RETURN_STATUS_UNEXPECTED("Update json failed "); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status DataHelper::RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file) { | |||||
| try { | |||||
| Path in = Path(in_file); | |||||
| nlohmann::json js; | |||||
| if (in.Exists()) { | |||||
| std::ifstream in_stream(in_file); | |||||
| MS_LOG(INFO) << "Filename: " << in_file << "."; | |||||
| in_stream >> js; | |||||
| in_stream.close(); | |||||
| } | |||||
| js.erase(key); | |||||
| MS_LOG(INFO) << "Write outfile is: " << js << "."; | |||||
| if (out_file == "") { | |||||
| std::ofstream o(in_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } else { | |||||
| std::ofstream o(out_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } | |||||
| } | |||||
| // Catch any exception and convert to Status return code | |||||
| catch (const std::exception &err) { | |||||
| RETURN_STATUS_UNEXPECTED("Update json failed "); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| size_t DataHelper::DumpTensor(const std::shared_ptr<Tensor> &input, void *addr, const size_t &buffer_size) { | |||||
| // get tensor size | |||||
| size_t tensor_size = input->SizeInBytes(); | |||||
| // iterate over entire tensor | |||||
| const unsigned char *tensor_addr = input->GetBuffer(); | |||||
| // tensor iterator print | |||||
| // write to address, input order is: destination, source | |||||
| errno_t ret = memcpy_s(addr, buffer_size, tensor_addr, tensor_size); | |||||
| if (ret != 0) { | |||||
| // memcpy failed | |||||
| MS_LOG(ERROR) << "memcpy tensor memory failed" | |||||
| << "."; | |||||
| return 0; // amount of data copied is 0, error | |||||
| } | |||||
| return tensor_size; | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,214 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_DATA_HELPER_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_DATA_HELPER_H_ | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include <nlohmann/json.hpp> | |||||
| #include "minddata/dataset/core/constants.h" | |||||
| #include "minddata/dataset/core/data_type.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/core/tensor_shape.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "minddata/dataset/util/path.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// \brief Simple class to do data manipulation, contains helper function to update json files in dataset | |||||
| class DataHelper { | |||||
| public: | |||||
| /// \brief constructor | |||||
| DataHelper() {} | |||||
| /// \brief Destructor | |||||
| ~DataHelper() = default; | |||||
| /// \brief Create an Album dataset while taking in a path to a image folder | |||||
| /// Creates the output directory if doesn't exist | |||||
| /// \param[in] in_dir Image folder directory that takes in images | |||||
| /// \param[in] out_dir Directory containing output json files | |||||
| Status CreateAlbum(const std::string &in_dir, const std::string &out_dir); | |||||
| /// \brief Update a json file field with a vector of integers | |||||
| /// \param in_file The input file name to read in | |||||
| /// \param key Key of field to write to | |||||
| /// \param value Value array to write to file | |||||
| /// \param out_file Optional input for output file path, will write to input file if not specified | |||||
| /// \return Status The error code return | |||||
| Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value, | |||||
| const std::string &out_file = ""); | |||||
| /// \brief Update a json file field with a vector of type T values | |||||
| /// \param in_file The input file name to read in | |||||
| /// \param key Key of field to write to | |||||
| /// \param value Value array to write to file | |||||
| /// \param out_file Optional parameter for output file path, will write to input file if not specified | |||||
| /// \return Status The error code return | |||||
| template <typename T> | |||||
| Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<T> &value, | |||||
| const std::string &out_file = "") { | |||||
| try { | |||||
| Path in = Path(in_file); | |||||
| nlohmann::json js; | |||||
| if (in.Exists()) { | |||||
| std::ifstream in(in_file); | |||||
| MS_LOG(INFO) << "Filename: " << in_file << "."; | |||||
| in >> js; | |||||
| in.close(); | |||||
| } | |||||
| js[key] = value; | |||||
| MS_LOG(INFO) << "Write outfile is: " << js << "."; | |||||
| if (out_file == "") { | |||||
| std::ofstream o(in_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } else { | |||||
| std::ofstream o(out_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } | |||||
| } | |||||
| // Catch any exception and convert to Status return code | |||||
| catch (const std::exception &err) { | |||||
| RETURN_STATUS_UNEXPECTED("Update json failed "); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| /// \brief Update a json file field with a single value of of type T | |||||
| /// \param in_file The input file name to read in | |||||
| /// \param key Key of field to write to | |||||
| /// \param value Value to write to file | |||||
| /// \param out_file Optional parameter for output file path, will write to input file if not specified | |||||
| /// \return Status The error code return | |||||
| template <typename T> | |||||
| Status UpdateValue(const std::string &in_file, const std::string &key, const T &value, | |||||
| const std::string &out_file = "") { | |||||
| try { | |||||
| Path in = Path(in_file); | |||||
| nlohmann::json js; | |||||
| if (in.Exists()) { | |||||
| std::ifstream in(in_file); | |||||
| MS_LOG(INFO) << "Filename: " << in_file << "."; | |||||
| in >> js; | |||||
| in.close(); | |||||
| } | |||||
| js[key] = value; | |||||
| MS_LOG(INFO) << "Write outfile is: " << js << "."; | |||||
| if (out_file == "") { | |||||
| std::ofstream o(in_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } else { | |||||
| std::ofstream o(out_file, std::ofstream::trunc); | |||||
| o << js; | |||||
| o.close(); | |||||
| } | |||||
| } | |||||
| // Catch any exception and convert to Status return code | |||||
| catch (const std::exception &err) { | |||||
| RETURN_STATUS_UNEXPECTED("Update json failed "); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| /// \brief Template function to write tensor to file | |||||
| /// \param[in] in_file File to write to | |||||
| /// \param[in] data Array of type T values | |||||
| /// \return Status The error code return | |||||
| template <typename T> | |||||
| Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) { | |||||
| try { | |||||
| std::ofstream o(in_file, std::ios::binary | std::ios::out); | |||||
| if (!o.is_open()) { | |||||
| RETURN_STATUS_UNEXPECTED("Error opening Bin file to write"); | |||||
| } | |||||
| size_t length = data.size(); | |||||
| o.write(reinterpret_cast<const char *>(&data[0]), std::streamsize(length * sizeof(T))); | |||||
| o.close(); | |||||
| } | |||||
| // Catch any exception and convert to Status return code | |||||
| catch (const std::exception &err) { | |||||
| RETURN_STATUS_UNEXPECTED("Write bin file failed "); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| /// \brief Write pointer to bin, use pointer to avoid memcpy | |||||
| /// \param[in] in_file File name to write to | |||||
| /// \param[in] data Pointer to data | |||||
| /// \param[in] length Length of values to write from pointer | |||||
| /// \return Status The error code return | |||||
| template <typename T> | |||||
| Status WriteBinFile(const std::string &in_file, T *data, size_t length) { | |||||
| try { | |||||
| std::ofstream o(in_file, std::ios::binary | std::ios::out); | |||||
| if (!o.is_open()) { | |||||
| RETURN_STATUS_UNEXPECTED("Error opening Bin file to write"); | |||||
| } | |||||
| o.write(reinterpret_cast<const char *>(data), std::streamsize(length * sizeof(T))); | |||||
| o.close(); | |||||
| } | |||||
| // Catch any exception and convert to Status return code | |||||
| catch (const std::exception &err) { | |||||
| RETURN_STATUS_UNEXPECTED("Write bin file failed "); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| /// \brief Helper function to copy content of a tensor to buffer | |||||
| /// \note This function iterates over the tensor in bytes, since | |||||
| /// \param[in] input The tensor to copy value from | |||||
| /// \param[out] addr The address to copy tensor data to | |||||
| /// \param[in] buffer_size The buffer size of addr | |||||
| /// \return The size of the tensor (bytes copied | |||||
| size_t DumpTensor(const std::shared_ptr<Tensor> &input, void *addr, const size_t &buffer_size); | |||||
| /// \brief Helper function to delete key in json file | |||||
| /// note This function will return okay even if key not found | |||||
| /// \param[in] in_file Json file to remove key from | |||||
| /// \param[in] key The key to remove | |||||
| /// \return Status The error code return | |||||
| Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = ""); | |||||
| /// \brief A print method typically used for debugging | |||||
| /// \param out - The output stream to write output to | |||||
| void Print(std::ostream &out) const; | |||||
| /// \brief << Stream output operator overload | |||||
| /// \notes This allows you to write the debug print info using stream operators | |||||
| /// \param out Reference to the output stream being overloaded | |||||
| /// \param ds Reference to the DataSchema to display | |||||
| /// \return The output stream must be returned | |||||
| friend std::ostream &operator<<(std::ostream &out, const DataHelper &dh) { | |||||
| dh.Print(out); | |||||
| return out; | |||||
| } | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_DATA_HELPER_H_ | |||||
| @@ -121,6 +121,7 @@ SET(DE_UT_SRCS | |||||
| solarize_op_test.cc | solarize_op_test.cc | ||||
| swap_red_blue_test.cc | swap_red_blue_test.cc | ||||
| distributed_sampler_test.cc | distributed_sampler_test.cc | ||||
| data_helper_test.cc | |||||
| ) | ) | ||||
| if (ENABLE_PYTHON) | if (ENABLE_PYTHON) | ||||
| @@ -0,0 +1,195 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/common.h" | |||||
| #include "minddata/dataset/core/client.h" | |||||
| #include "minddata/dataset/core/global_context.h" | |||||
| #include "minddata/dataset/core/tensor.h" | |||||
| #include "minddata/dataset/core/tensor_shape.h" | |||||
| #include "minddata/dataset/core/data_type.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||||
| #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||||
| #include "minddata/dataset/util/data_helper.h" | |||||
| #include "minddata/dataset/util/path.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "gtest/gtest.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "securec.h" | |||||
| using namespace mindspore::dataset; | |||||
| using mindspore::MsLogLevel::ERROR; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::LogStream; | |||||
| class MindDataTestDataHelper : public UT::DatasetOpTesting { | |||||
| protected: | |||||
| }; | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestHelper) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/images/1.json"; | |||||
| DataHelper dh; | |||||
| std::vector<std::string> new_label = {"3", "4"}; | |||||
| Status rc = dh.UpdateArray(file_path, "label", new_label); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during label update: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestAlbumGen) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/original"; | |||||
| std::string out_path = datasets_root_path_ + "/testAlbum/testout"; | |||||
| DataHelper dh; | |||||
| Status rc = dh.CreateAlbum(file_path, out_path); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during album generation: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestTemplateUpdateArrayInt) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/testout/2.json"; | |||||
| DataHelper dh; | |||||
| std::vector<int> new_label = {3, 4}; | |||||
| Status rc = dh.UpdateArray(file_path, "label", new_label); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during json int array update: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestTemplateUpdateArrayString) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/testout/3.json"; | |||||
| DataHelper dh; | |||||
| std::vector<std::string> new_label = {"3", "4"}; | |||||
| Status rc = dh.UpdateArray(file_path, "label", new_label); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during json string array update: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestTemplateUpdateValueInt) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/testout/4.json"; | |||||
| DataHelper dh; | |||||
| int new_label = 3; | |||||
| Status rc = dh.UpdateValue(file_path, "label", new_label); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during json int update: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestTemplateUpdateString) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/testout/5.json"; | |||||
| DataHelper dh; | |||||
| std::string new_label = "new label"; | |||||
| Status rc = dh.UpdateValue(file_path, "label", new_label); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during json string update: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestDeleteKey) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/testout/5.json"; | |||||
| DataHelper dh; | |||||
| Status rc = dh.RemoveKey(file_path, "label"); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during json key remove: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestBinWrite) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/1.bin"; | |||||
| DataHelper dh; | |||||
| std::vector<float> bin_content = {3, 4}; | |||||
| Status rc = dh.WriteBinFile(file_path, bin_content); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during bin file write: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestBinWritePointer) { | |||||
| std::string file_path = datasets_root_path_ + "/testAlbum/2.bin"; | |||||
| DataHelper dh; | |||||
| std::vector<float> bin_content = {3, 4}; | |||||
| Status rc = dh.WriteBinFile(file_path, &bin_content[0], bin_content.size()); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << "Return code error detected during binfile write: " << "."; | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestTensorWriteFloat) { | |||||
| // create tensor | |||||
| std::vector<float> y = {2.5, 3.0, 3.5, 4.0}; | |||||
| std::shared_ptr<Tensor> t; | |||||
| Tensor::CreateFromVector(y, &t); | |||||
| // create buffer using system mempool | |||||
| DataHelper dh; | |||||
| void *data = malloc(t->SizeInBytes()); | |||||
| auto bytes_copied = dh.DumpTensor(std::move(t), data, t->SizeInBytes()); | |||||
| if (bytes_copied != t->SizeInBytes()) { | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| float *array = static_cast<float *>(data); | |||||
| if (array[0] != 2.5) { EXPECT_TRUE(false); } | |||||
| if (array[1] != 3.0) { EXPECT_TRUE(false); } | |||||
| if (array[2] != 3.5) { EXPECT_TRUE(false); } | |||||
| if (array[3] != 4.0) { EXPECT_TRUE(false); } | |||||
| std::free(data); | |||||
| } | |||||
| TEST_F(MindDataTestDataHelper, MindDataTestTensorWriteUInt) { | |||||
| // create tensor | |||||
| std::vector<uint8_t> y = {1, 2, 3, 4}; | |||||
| std::shared_ptr<Tensor> t; | |||||
| Tensor::CreateFromVector(y, &t); | |||||
| uint8_t o; | |||||
| t->GetItemAt<uint8_t>(&o, {0, 0}); | |||||
| MS_LOG(INFO) << "before op :" << std::to_string(o) << "."; | |||||
| // create buffer using system mempool | |||||
| DataHelper dh; | |||||
| void *data = malloc(t->SizeInBytes()); | |||||
| auto bytes_copied = dh.DumpTensor(t, data, t->SizeInBytes()); | |||||
| if (bytes_copied != t->SizeInBytes()) { | |||||
| EXPECT_TRUE(false); | |||||
| } | |||||
| t->GetItemAt<uint8_t>(&o, {}); | |||||
| MS_LOG(INFO) << "after op :" << std::to_string(o) << "."; | |||||
| uint8_t *array = static_cast<uint8_t *>(data); | |||||
| if (array[0] != 1) { EXPECT_TRUE(false); } | |||||
| if (array[1] != 2) { EXPECT_TRUE(false); } | |||||
| if (array[2] != 3) { EXPECT_TRUE(false); } | |||||
| if (array[3] != 4) { EXPECT_TRUE(false); } | |||||
| std::free(data); | |||||
| } | |||||