Merge pull request !585 from panfengfeng/revert-merge-182-mastertags/v0.2.0-alpha
| @@ -1,46 +0,0 @@ | |||
| # MindRecord generating guidelines | |||
| <!-- TOC --> | |||
| - [MindRecord generating guidelines](#mindrecord-generating-guidelines) | |||
| - [Create work space](#create-work-space) | |||
| - [Implement data generator](#implement-data-generator) | |||
| - [Run data generator](#run-data-generator) | |||
| <!-- /TOC --> | |||
| ## Create work space | |||
| Assume the dataset name is 'xyz' | |||
| * Create work space from template | |||
| ```shell | |||
| cd ${your_mindspore_home}/example/convert_to_mindrecord | |||
| cp -r template xyz | |||
| ``` | |||
| ## Implement data generator | |||
| Edit dictionary data generator | |||
| * Edit file | |||
| ```shell | |||
| cd ${your_mindspore_home}/example/convert_to_mindrecord | |||
| vi xyz/mr_api.py | |||
| ``` | |||
| Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented: | |||
| - 'mindrecord_task_number()' returns the number of tasks. Return 1 if data rows are generated serially. Return N if the generator can be split into N tasks that run in parallel. | |||
| - 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' ranges over 0..N-1, where N is the return value of mindrecord_task_number(). | |||
| Tips for parallel runs: | |||
| - For imagenet, one directory can be a task. | |||
| - For TFRecord with multiple files, each file can be a task. | |||
| - For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K) | |||
| ## Run data generator | |||
| * run python script | |||
| ```shell | |||
| cd ${your_mindspore_home}/example/convert_to_mindrecord | |||
| python writer.py --mindrecord_script xyz [...] | |||
| ``` | |||
| @@ -1,122 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| User-defined API for MindRecord writer. | |||
| Two API must be implemented, | |||
| 1. mindrecord_task_number() | |||
| # Return number of parallel tasks. return 1 if no parallel | |||
| 2. mindrecord_dict_data(task_id) | |||
| # Yield data for one task | |||
| # task_id is 0..N-1, if N is return value of mindrecord_task_number() | |||
| """ | |||
| import argparse | |||
| import os | |||
| import pickle | |||
| ######## mindrecord_schema begin ########## | |||
| mindrecord_schema = {"label": {"type": "int64"}, | |||
| "data": {"type": "bytes"}, | |||
| "file_name": {"type": "string"}} | |||
| ######## mindrecord_schema end ########## | |||
| ######## Frozen code begin ########## | |||
| with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: | |||
| ARG_LIST = pickle.load(mindrecord_argument_file_handle) | |||
| ######## Frozen code end ########## | |||
| parser = argparse.ArgumentParser(description='Mind record imagenet example') | |||
| parser.add_argument('--label_file', type=str, default="", help='label file') | |||
| parser.add_argument('--image_dir', type=str, default="", help='images directory') | |||
| ######## Frozen code begin ########## | |||
| args = parser.parse_args(ARG_LIST) | |||
| print(args) | |||
| ######## Frozen code end ########## | |||
def _user_defined_private_func():
    """
    Build the task list from the label file and the image directory.

    Every non-blank line of ``args.label_file`` is expected to hold a directory
    name followed by a label, separated by whitespace (e.g. ``n02087046 0``).
    A label is kept only if its directory exists under ``args.image_dir``.

    Return:
        tuple (dir_list, dir_paths). ``dir_list`` is the list of kept labels,
        ``dir_paths`` maps each kept label to its directory path. Both are
        empty when no directory is found.

    Raises:
        IOError: if ``args.label_file`` does not exist.
    """
    if not os.path.exists(args.label_file):
        raise IOError("map file {} not exists".format(args.label_file))

    label_dict = {}
    with open(args.label_file) as file_handle:
        for line in file_handle:
            # split() without argument also strips the trailing newline, so the
            # label key is "0" rather than "0\n".
            fields = line.split()
            if not fields:
                # blank line, nothing to record
                continue
            if len(fields) < 2:
                print("invalid line in label file, skip it: {}".format(line.strip()))
                continue
            label_dict[fields[1]] = fields[0]

    # get all the dir which are n02087046, n02094114, n02109525
    dir_paths = {}
    for label, dir_name in label_dict.items():
        real_path = os.path.join(args.image_dir, dir_name)
        if not os.path.isdir(real_path):
            print("{} dir is not exist".format(real_path))
            continue
        dir_paths[label] = real_path

    if not dir_paths:
        print("not valid image dir in {}".format(args.image_dir))
        # same types as the regular return value: (list, dict)
        return [], {}

    return list(dir_paths), dir_paths
| dir_list_global, dir_paths_global = _user_defined_private_func() | |||
def mindrecord_task_number():
    """
    Report how many tasks the generator is split into.

    Return:
        int, the number of entries in dir_list_global.
    """
    task_count = len(dir_list_global)
    return task_count
def mindrecord_dict_data(task_id):
    """
    Yield the data rows of one task.

    Args:
        task_id (int): index into dir_list_global; selects the label whose
            directory (looked up in dir_paths_global) is read.

    Yields:
        data (dict): data row with "file_name" (str), "label" (int) and
        "data" (bytes, the content of the image file).
    """
    label = dir_list_global[task_id]
    label_dir = dir_paths_global[label]
    for item in os.listdir(label_dir):
        file_name = os.path.join(label_dir, item)
        # endswith accepts a tuple, one call covers all accepted suffixes
        if not item.endswith(("JPEG", "jpg", "jpeg")):
            print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name))
            continue

        # the context manager closes the file even if read() raises
        with open(file_name, "rb") as image_file:
            image_bytes = image_file.read()

        yield {"file_name": str(file_name), "label": int(label), "data": image_bytes}
| @@ -1,8 +0,0 @@ | |||
#!/bin/bash
# Run writer.py with the imagenet script: images are read from
# /tmp/imagenet/jpeg, labels from /tmp/imagenet/label.txt, and the output
# files use the name prefix /tmp/imagenet/mr/m.

# Make sure the output directory exists, then drop the files of a previous run.
# -f keeps rm from failing when the directory is empty.
mkdir -p /tmp/imagenet/mr
rm -f /tmp/imagenet/mr/*

python writer.py --mindrecord_script imagenet \
--mindrecord_file "/tmp/imagenet/mr/m" \
--mindrecord_partitions 16 \
--label_file "/tmp/imagenet/label.txt" \
--image_dir "/tmp/imagenet/jpeg"
| @@ -1,6 +0,0 @@ | |||
#!/bin/bash
# Run writer.py with the template script; the output files use the name
# prefix /tmp/template/m.

# Make sure the output directory exists, then drop the files of a previous run.
# -f keeps rm from failing when the directory is empty.
mkdir -p /tmp/template
rm -f /tmp/template/*

python writer.py --mindrecord_script template \
--mindrecord_file "/tmp/template/m" \
--mindrecord_partitions 4
| @@ -1,73 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| User-defined API for MindRecord writer. | |||
| Two API must be implemented, | |||
| 1. mindrecord_task_number() | |||
| # Return number of parallel tasks. return 1 if no parallel | |||
| 2. mindrecord_dict_data(task_id) | |||
| # Yield data for one task | |||
| # task_id is 0..N-1, if N is return value of mindrecord_task_number() | |||
| """ | |||
| import argparse | |||
| import pickle | |||
| # ## Parse argument | |||
| with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line | |||
| ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line | |||
| parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line | |||
| # ## Your arguments below | |||
| # parser.add_argument(...) | |||
| args = parser.parse_args(ARG_LIST) # Do NOT change this line | |||
| print(args) # Do NOT change this line | |||
| # ## Default mindrecord vars. Comment them unless default value has to be changed. | |||
| # mindrecord_index_fields = ['label'] | |||
| # mindrecord_header_size = 1 << 24 | |||
| # mindrecord_page_size = 1 << 25 | |||
| # define global vars here if necessary | |||
| # ####### Your code below ########## | |||
| mindrecord_schema = {"label": {"type": "int32"}} | |||
def mindrecord_task_number():
    """
    Tell the writer how many tasks to run.

    Return:
        int, always 1 in this template.
    """
    task_count = 1
    return task_count
def mindrecord_dict_data(task_id):
    """
    Produce the sample rows of this template.

    Args:
        task_id: id of the task, only printed here.

    Yields:
        data (dict): one row per value 0..255, stored under the key 'label'.
    """
    print("task is {}".format(task_id))
    for value in range(256):
        yield {'label': value}
| @@ -1,149 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| ######################## write mindrecord example ######################## | |||
| Write mindrecord by data dictionary: | |||
| python writer.py --mindrecord_script /YourScriptPath ... | |||
| """ | |||
| import argparse | |||
| import os | |||
| import pickle | |||
| import time | |||
| from importlib import import_module | |||
| from multiprocessing import Pool | |||
| from mindspore.mindrecord import FileWriter | |||
def _exec_task(task_id, parallel_writer=True, batch_size=2048):
    """
    Execute task with specified task id.

    Rows are pulled from ``mindrecord_dict_data(task_id)`` and passed to
    ``writer.write_raw_data`` in batches. Both names are module-level globals.

    Args:
        task_id (int): id handed to mindrecord_dict_data.
        parallel_writer (bool): forwarded to writer.write_raw_data (default=True).
        batch_size (int): maximum number of rows per write_raw_data call
            (default=2048).
    """
    from itertools import islice

    # NOTE(review): mindrecord_dict_data and writer are assigned in the
    # __main__ section; pool workers presumably inherit them through fork --
    # confirm before running with a spawn start method.
    print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
    rows = iter(mindrecord_dict_data(task_id))
    transform_count = 0
    while True:
        # islice yields at most batch_size rows and an empty list once the
        # generator is exhausted.
        data_list = list(islice(rows, batch_size))
        if not data_list:
            break
        transform_count += len(data_list)
        writer.write_raw_data(data_list, parallel_writer=parallel_writer)
        print("transformed {} record...".format(transform_count))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Mind record writer')
    parser.add_argument('--mindrecord_script', type=str, default="template",
                        help='path where script is saved')

    parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
                        help='written file name prefix')

    parser.add_argument('--mindrecord_partitions', type=int, default=1,
                        help='number of written files')

    parser.add_argument('--mindrecord_workers', type=int, default=8,
                        help='number of parallel workers')

    # Options this parser does not know are returned in other_args; they are
    # handed over to the user script through mr_argument.pickle.
    args, other_args = parser.parse_known_args()

    print(args)
    print(other_args)

    with open('mr_argument.pickle', 'wb') as file_handle:
        pickle.dump(other_args, file_handle)

    mr_api_name = args.mindrecord_script + '.mr_api'
    try:
        mr_api = import_module(mr_api_name)
    except ModuleNotFoundError as err:
        # keep the original error as the cause for easier debugging
        raise RuntimeError("Unknown module path: {}".format(mr_api_name)) from err

    num_tasks = mr_api.mindrecord_task_number()

    print("Write mindrecord ...")

    # module-level name, read by _exec_task
    mindrecord_dict_data = mr_api.mindrecord_dict_data

    # get number of files
    # module-level name, read by _exec_task
    writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

    start_time = time.time()

    # The optional settings below are looked up with hasattr so that an
    # AttributeError raised inside the writer is not mistaken for
    # "setting not defined".

    # set the header size
    if hasattr(mr_api, 'mindrecord_header_size'):
        writer.set_header_size(mr_api.mindrecord_header_size)
    else:
        print("Default header size: {}".format(1 << 24))

    # set the page size
    if hasattr(mr_api, 'mindrecord_page_size'):
        writer.set_page_size(mr_api.mindrecord_page_size)
    else:
        print("Default page size: {}".format(1 << 25))

    # get schema
    try:
        mindrecord_schema = mr_api.mindrecord_schema
    except AttributeError as err:
        raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") from err

    # create the schema
    writer.add_schema(mindrecord_schema, "mindrecord_schema")

    # add the index
    if hasattr(mr_api, 'mindrecord_index_fields'):
        writer.add_index(mr_api.mindrecord_index_fields)
    else:
        print("Default index fields: all simple fields are indexes.")

    writer.open_and_set_header()

    task_list = list(range(num_tasks))

    # set number of workers
    num_workers = args.mindrecord_workers

    if num_tasks < 1:
        num_tasks = 1

    if num_workers > num_tasks:
        num_workers = num_tasks

    if num_tasks > 1:
        with Pool(num_workers) as p:
            p.map(_exec_task, task_list)
    else:
        # single task: run in this process without the parallel writer
        _exec_task(0, False)

    writer.commit()

    # the argument file is only needed while the user script is running
    os.remove("mr_argument.pickle")

    end_time = time.time()
    print("--------------------------------------------")
    print("END. Total time: {}".format(end_time - start_time))
    print("--------------------------------------------")
| @@ -75,9 +75,12 @@ void BindShardWriter(py::module *m) { | |||
| .def("set_header_size", &ShardWriter::set_header_size) | |||
| .def("set_page_size", &ShardWriter::set_page_size) | |||
| .def("set_shard_header", &ShardWriter::SetShardHeader) | |||
| .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, | |||
| vector<vector<uint8_t>> &, bool, bool)) & | |||
| ShardWriter::WriteRawData) | |||
| .def("write_raw_data", | |||
| (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, vector<vector<uint8_t>> &, bool)) & | |||
| ShardWriter::WriteRawData) | |||
| .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, | |||
| std::map<uint64_t, std::vector<py::handle>> &, bool)) & | |||
| ShardWriter::WriteRawData) | |||
| .def("commit", &ShardWriter::Commit); | |||
| } | |||
| @@ -121,10 +121,6 @@ class ShardHeader { | |||
| std::vector<std::string> SerializeHeader(); | |||
| MSRStatus PagesToFile(const std::string dump_file_name); | |||
| MSRStatus FileToPages(const std::string dump_file_name); | |||
| private: | |||
| MSRStatus InitializeHeader(const std::vector<json> &headers); | |||
| @@ -18,7 +18,6 @@ | |||
| #define MINDRECORD_INCLUDE_SHARD_WRITER_H_ | |||
| #include <libgen.h> | |||
| #include <sys/file.h> | |||
| #include <unistd.h> | |||
| #include <algorithm> | |||
| #include <array> | |||
| @@ -88,7 +87,7 @@ class ShardWriter { | |||
| /// \param[in] sign validate data or not | |||
| /// \return MSRStatus the status of MSRStatus to judge if write successfully | |||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign = true, bool parallel_writer = false); | |||
| bool sign = true); | |||
| /// \brief write raw data by group size for call from python | |||
| /// \param[in] raw_data the vector of raw json data, python-handle format | |||
| @@ -96,7 +95,7 @@ class ShardWriter { | |||
| /// \param[in] sign validate data or not | |||
| /// \return MSRStatus the status of MSRStatus to judge if write successfully | |||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign = true, bool parallel_writer = false); | |||
| bool sign = true); | |||
| /// \brief write raw data by group size for call from python | |||
| /// \param[in] raw_data the vector of raw json data, python-handle format | |||
| @@ -104,8 +103,7 @@ class ShardWriter { | |||
| /// \param[in] sign validate data or not | |||
| /// \return MSRStatus the status of MSRStatus to judge if write successfully | |||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, | |||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true, | |||
| bool parallel_writer = false); | |||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true); | |||
| private: | |||
| /// \brief write shard header data to disk | |||
| @@ -203,34 +201,7 @@ class ShardWriter { | |||
| MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, | |||
| std::map<int, std::string> &err_raw_data); | |||
| /// \brief Lock writer and save pages info | |||
| int LockWriter(bool parallel_writer = false); | |||
| /// \brief Unlock writer and save pages info | |||
| MSRStatus UnlockWriter(int fd, bool parallel_writer = false); | |||
| /// \brief Check raw data before writing | |||
| MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign, int *schema_count, int *row_count); | |||
| /// \brief Get full path from file name | |||
| MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths); | |||
| /// \brief Open files | |||
| MSRStatus OpenDataFiles(bool append); | |||
| /// \brief Remove lock file | |||
| MSRStatus RemoveLockFile(); | |||
| /// \brief Remove lock file | |||
| MSRStatus InitLockFile(); | |||
| private: | |||
| const std::string kLockFileSuffix = "_Locker"; | |||
| const std::string kPageFileSuffix = "_Pages"; | |||
| std::string lock_file_; // lock file for parallel run | |||
| std::string pages_file_; // temporary file of pages info for parallel run | |||
| int shard_count_; // number of files | |||
| uint64_t header_size_; // header size | |||
| uint64_t page_size_; // page size | |||
| @@ -240,7 +211,7 @@ class ShardWriter { | |||
| std::vector<uint64_t> raw_data_size_; // Raw data size | |||
| std::vector<uint64_t> blob_data_size_; // Blob data size | |||
| std::vector<std::string> file_paths_; // file paths | |||
| std::vector<string> file_paths_; // file paths | |||
| std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles | |||
| std::shared_ptr<ShardHeader> shard_header_; // shard headers | |||
| @@ -520,16 +520,13 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std | |||
| for (int raw_page_id : raw_page_ids) { | |||
| auto sql = GenerateRawSQL(fields_); | |||
| if (sql.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Generate raw SQL failed"; | |||
| return FAILED; | |||
| } | |||
| auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); | |||
| if (data.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Generate raw data failed"; | |||
| return FAILED; | |||
| } | |||
| if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { | |||
| MS_LOG(ERROR) << "Execute SQL failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; | |||
| @@ -40,7 +40,17 @@ ShardWriter::~ShardWriter() { | |||
| } | |||
| } | |||
| MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) { | |||
| MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) { | |||
| shard_count_ = paths.size(); | |||
| if (shard_count_ > kMaxShardCount || shard_count_ == 0) { | |||
| MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; | |||
| return FAILED; | |||
| } | |||
| if (schema_count_ > kMaxSchemaCount) { | |||
| MS_LOG(ERROR) << "The schema Count greater than max value."; | |||
| return FAILED; | |||
| } | |||
| // Get full path from file name | |||
| for (const auto &path : paths) { | |||
| if (!CheckIsValidUtf8(path)) { | |||
| @@ -50,7 +60,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p | |||
| char resolved_path[PATH_MAX] = {0}; | |||
| char buf[PATH_MAX] = {0}; | |||
| if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { | |||
| MS_LOG(ERROR) << "Secure func failed"; | |||
| MS_LOG(ERROR) << "Securec func failed"; | |||
| return FAILED; | |||
| } | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| @@ -72,10 +82,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p | |||
| #endif | |||
| file_paths_.emplace_back(string(resolved_path)); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::OpenDataFiles(bool append) { | |||
| // Open files | |||
| for (const auto &file : file_paths_) { | |||
| std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>(); | |||
| @@ -109,67 +116,6 @@ MSRStatus ShardWriter::OpenDataFiles(bool append) { | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::RemoveLockFile() { | |||
| // Remove temporary file | |||
| int ret = std::remove(pages_file_.c_str()); | |||
| if (ret == 0) { | |||
| MS_LOG(DEBUG) << "Remove page file."; | |||
| } | |||
| ret = std::remove(lock_file_.c_str()); | |||
| if (ret == 0) { | |||
| MS_LOG(DEBUG) << "Remove lock file."; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::InitLockFile() { | |||
| if (file_paths_.size() == 0) { | |||
| MS_LOG(ERROR) << "File path not initialized."; | |||
| return FAILED; | |||
| } | |||
| lock_file_ = file_paths_[0] + kLockFileSuffix; | |||
| pages_file_ = file_paths_[0] + kPageFileSuffix; | |||
| if (RemoveLockFile() == FAILED) { | |||
| MS_LOG(ERROR) << "Remove file failed."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) { | |||
| shard_count_ = paths.size(); | |||
| if (shard_count_ > kMaxShardCount || shard_count_ == 0) { | |||
| MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; | |||
| return FAILED; | |||
| } | |||
| if (schema_count_ > kMaxSchemaCount) { | |||
| MS_LOG(ERROR) << "The schema Count greater than max value."; | |||
| return FAILED; | |||
| } | |||
| // Get full path from file name | |||
| if (GetFullPathFromFileName(paths) == FAILED) { | |||
| MS_LOG(ERROR) << "Get full path from file name failed."; | |||
| return FAILED; | |||
| } | |||
| // Open files | |||
| if (OpenDataFiles(append) == FAILED) { | |||
| MS_LOG(ERROR) << "Open data files failed."; | |||
| return FAILED; | |||
| } | |||
| // Init lock file | |||
| if (InitLockFile() == FAILED) { | |||
| MS_LOG(ERROR) << "Init lock file failed."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::OpenForAppend(const std::string &path) { | |||
| if (!IsLegalFile(path)) { | |||
| return FAILED; | |||
| @@ -197,28 +143,11 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { | |||
| } | |||
| MSRStatus ShardWriter::Commit() { | |||
| // Read pages file | |||
| std::ifstream page_file(pages_file_.c_str()); | |||
| if (page_file.good()) { | |||
| page_file.close(); | |||
| if (shard_header_->FileToPages(pages_file_) == FAILED) { | |||
| MS_LOG(ERROR) << "Read pages from file failed"; | |||
| return FAILED; | |||
| } | |||
| } | |||
| if (WriteShardHeader() == FAILED) { | |||
| MS_LOG(ERROR) << "Write metadata failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Write metadata successfully."; | |||
| // Remove lock file | |||
| if (RemoveLockFile() == FAILED) { | |||
| MS_LOG(ERROR) << "Remove lock file failed."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -526,65 +455,15 @@ void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>> | |||
| } | |||
| } | |||
| int ShardWriter::LockWriter(bool parallel_writer) { | |||
| if (!parallel_writer) { | |||
| return 0; | |||
| } | |||
| const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); | |||
| if (fd >= 0) { | |||
| flock(fd, LOCK_EX); | |||
| } else { | |||
| MS_LOG(ERROR) << "Shard writer failed when locking file"; | |||
| return -1; | |||
| } | |||
| // Open files | |||
| file_streams_.clear(); | |||
| for (const auto &file : file_paths_) { | |||
| std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>(); | |||
| fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); | |||
| if (fs->fail()) { | |||
| MS_LOG(ERROR) << "File could not opened"; | |||
| return -1; | |||
| } | |||
| file_streams_.push_back(fs); | |||
| } | |||
| if (shard_header_->FileToPages(pages_file_) == FAILED) { | |||
| MS_LOG(ERROR) << "Read pages from file failed"; | |||
| return -1; | |||
| } | |||
| return fd; | |||
| } | |||
| MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { | |||
| if (!parallel_writer) { | |||
| return SUCCESS; | |||
| } | |||
| if (shard_header_->PagesToFile(pages_file_) == FAILED) { | |||
| MS_LOG(ERROR) << "Write pages to file failed"; | |||
| return FAILED; | |||
| } | |||
| for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) { | |||
| file_streams_[i]->close(); | |||
| } | |||
| flock(fd, LOCK_UN); | |||
| close(fd); | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, | |||
| std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count, | |||
| int *row_count) { | |||
| MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, | |||
| std::vector<std::vector<uint8_t>> &blob_data, bool sign) { | |||
| // check the free disk size | |||
| auto st_space = GetDiskSize(file_paths_[0], kFreeSize); | |||
| if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { | |||
| MS_LOG(ERROR) << "IO error / there is no free disk to be used"; | |||
| return FAILED; | |||
| } | |||
| // Add 4-bytes dummy blob data if no any blob fields | |||
| if (blob_data.size() == 0 && raw_data.size() > 0) { | |||
| blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0)); | |||
| @@ -600,29 +479,10 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json> | |||
| MS_LOG(ERROR) << "Validate raw data failed"; | |||
| return FAILED; | |||
| } | |||
| *schema_count = std::get<1>(v); | |||
| *row_count = std::get<2>(v); | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, | |||
| std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) { | |||
| // Lock Writer if loading data parallel | |||
| int fd = LockWriter(parallel_writer); | |||
| if (fd < 0) { | |||
| MS_LOG(ERROR) << "Lock writer failed"; | |||
| return FAILED; | |||
| } | |||
| // Get the count of schemas and rows | |||
| int schema_count = 0; | |||
| int row_count = 0; | |||
| // Serialize raw data | |||
| if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { | |||
| MS_LOG(ERROR) << "Check raw data failed"; | |||
| return FAILED; | |||
| } | |||
| int schema_count = std::get<1>(v); | |||
| int row_count = std::get<2>(v); | |||
| if (row_count == kInt0) { | |||
| MS_LOG(INFO) << "Raw data size is 0."; | |||
| @@ -656,17 +516,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_d | |||
| } | |||
| MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; | |||
| if (UnlockWriter(fd, parallel_writer) == FAILED) { | |||
| MS_LOG(ERROR) << "Unlock writer failed"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, | |||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign, | |||
| bool parallel_writer) { | |||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign) { | |||
| std::map<uint64_t, std::vector<json>> raw_data_json; | |||
| std::map<uint64_t, std::vector<json>> blob_data_json; | |||
| @@ -700,11 +554,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> | |||
| MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; | |||
| return FAILED; | |||
| } | |||
| return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); | |||
| return WriteRawData(raw_data_json, bin_blob_data, sign); | |||
| } | |||
| MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, | |||
| vector<vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) { | |||
| vector<vector<uint8_t>> &blob_data, bool sign) { | |||
| std::map<uint64_t, std::vector<json>> raw_data_json; | |||
| (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), | |||
| [](const std::pair<uint64_t, std::vector<py::handle>> &pair) { | |||
| @@ -714,7 +568,7 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> | |||
| [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); | |||
| return std::make_pair(pair.first, std::move(json_raw_data)); | |||
| }); | |||
| return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); | |||
| return WriteRawData(raw_data_json, blob_data, sign); | |||
| } | |||
| MSRStatus ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data, | |||
| @@ -677,43 +677,5 @@ std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID( | |||
| } | |||
| return std::make_pair(statistics_.at(statistic_id), SUCCESS); | |||
| } | |||
| MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { | |||
| // write header content to file, dump whatever is in the file before | |||
| std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); | |||
| if (page_out_handle.fail()) { | |||
| MS_LOG(ERROR) << "Failed in opening page file"; | |||
| return FAILED; | |||
| } | |||
| auto pages = SerializePage(); | |||
| for (const auto &shard_pages : pages) { | |||
| page_out_handle << shard_pages << "\n"; | |||
| } | |||
| page_out_handle.close(); | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { | |||
| for (auto &v : pages_) { // clean pages | |||
| v.clear(); | |||
| } | |||
| // attempt to open the file contains the page in json | |||
| std::ifstream page_in_handle(dump_file_name.c_str()); | |||
| if (!page_in_handle.good()) { | |||
| MS_LOG(INFO) << "No page file exists."; | |||
| return SUCCESS; | |||
| } | |||
| std::string line; | |||
| while (std::getline(page_in_handle, line)) { | |||
| ParsePage(json::parse(line)); | |||
| } | |||
| page_in_handle.close(); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -200,24 +200,13 @@ class FileWriter: | |||
| raw_data.pop(i) | |||
| logger.warning(v) | |||
| def open_and_set_header(self): | |||
| """ | |||
| Open writer and set header | |||
| """ | |||
| if not self._writer.is_open: | |||
| self._writer.open(self._paths) | |||
| if not self._writer.get_shard_header(): | |||
| self._writer.set_shard_header(self._header) | |||
| def write_raw_data(self, raw_data, parallel_writer=False): | |||
| def write_raw_data(self, raw_data): | |||
| """ | |||
| Write raw data and generate sequential pair of MindRecord File and \ | |||
| validate data based on predefined schema by default. | |||
| Args: | |||
| raw_data (list[dict]): List of raw data. | |||
| parallel_writer (bool, optional): Load data parallel if it equals to True (default=False). | |||
| Raises: | |||
| ParamTypeError: If index field is invalid. | |||
| @@ -236,7 +225,7 @@ class FileWriter: | |||
| if not isinstance(each_raw, dict): | |||
| raise ParamTypeError('raw_data item', 'dict') | |||
| self._verify_based_on_schema(raw_data) | |||
| return self._writer.write_raw_data(raw_data, True, parallel_writer) | |||
| return self._writer.write_raw_data(raw_data, True) | |||
| def set_header_size(self, header_size): | |||
| """ | |||
| @@ -135,7 +135,7 @@ class ShardWriter: | |||
| def get_shard_header(self): | |||
| return self._header | |||
| def write_raw_data(self, data, validate=True, parallel_writer=False): | |||
| def write_raw_data(self, data, validate=True): | |||
| """ | |||
| Write raw data of cv dataset. | |||
| @@ -145,7 +145,6 @@ class ShardWriter: | |||
| Args: | |||
| data (list[dict]): List of raw data. | |||
| validate (bool, optional): verify data according schema if it equals to True. | |||
| parallel_writer (bool, optional): Load data parallel if it equals to True. | |||
| Returns: | |||
| MSRStatus, SUCCESS or FAILED. | |||
| @@ -166,7 +165,7 @@ class ShardWriter: | |||
| if row_raw: | |||
| raw_data.append(row_raw) | |||
| raw_data = {0: raw_data} if raw_data else {} | |||
| ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer) | |||
| ret = self._writer.write_raw_data(raw_data, blob_data, validate) | |||
| if ret != ms.MSRStatus.SUCCESS: | |||
| logger.error("Failed to write dataset.") | |||
| raise MRMWriteDatasetError | |||