
!585 Revert 'Pull Request !182 : Tuning mindrecord writer performance'

Merge pull request !585 from panfengfeng/revert-merge-182-master
Tag: v0.2.0-alpha
mindspore-ci-bot (Gitee) committed 5 years ago
commit 235c69979d
16 changed files with 35 additions and 668 deletions
  1. +0 −46    example/convert_to_mindrecord/README.md
  2. +0 −0     example/convert_to_mindrecord/imagenet/__init__.py
  3. +0 −122   example/convert_to_mindrecord/imagenet/mr_api.py
  4. +0 −8     example/convert_to_mindrecord/run_imagenet.sh
  5. +0 −6     example/convert_to_mindrecord/run_template.sh
  6. +0 −0     example/convert_to_mindrecord/template/__init__.py
  7. +0 −73    example/convert_to_mindrecord/template/mr_api.py
  8. +0 −149   example/convert_to_mindrecord/writer.py
  9. +6 −3     mindspore/ccsrc/mindrecord/common/shard_pybind.cc
  10. +0 −4    mindspore/ccsrc/mindrecord/include/shard_header.h
  11. +4 −33   mindspore/ccsrc/mindrecord/include/shard_writer.h
  12. +0 −3    mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
  13. +21 −167 mindspore/ccsrc/mindrecord/io/shard_writer.cc
  14. +0 −38   mindspore/ccsrc/mindrecord/meta/shard_header.cc
  15. +2 −13   mindspore/mindrecord/filewriter.py
  16. +2 −3    mindspore/mindrecord/shardwriter.py

example/convert_to_mindrecord/README.md  +0 −46

@@ -1,46 +0,0 @@
# MindRecord generating guidelines

<!-- TOC -->

- [MindRecord generating guidelines](#mindrecord-generating-guidelines)
    - [Create work space](#create-work-space)
    - [Implement data generator](#implement-data-generator)
    - [Run data generator](#run-data-generator)

<!-- /TOC -->

## Create work space

Assume the dataset name is 'xyz'
* Create work space from template
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
cp -r template xyz
```

## Implement data generator

Edit the dictionary data generator
* Edit the file
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
vi xyz/mr_api.py
```

Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented:
- 'mindrecord_task_number()' returns the number of tasks: 1 if data rows are generated serially, or N if the generator can be split into N parallel tasks.
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row; 'task_id' ranges from 0 to N-1, where N is the return value of mindrecord_task_number().


Tips for parallel runs
- For imagenet, one directory can be one task.
- For TFRecord with multiple files, each file can be one task.
- For TFRecord with a single file, the data can still be split into N tasks: task_id=K means a data row is picked only if (count % N == K). (A sketch follows this README.)


## Run data generator
* Run the python script
```shell
cd ${your_mindspore_home}/example/convert_to_mindrecord
python writer.py --mindrecord_script imagenet [...]
```
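For reference, a minimal mr_api.py satisfying the two-function contract above might look like the following hypothetical sketch (the schema and row contents are illustrative, not part of the removed example):

```python
# Minimal sketch of the removed mr_api.py contract.
# Schema and row contents are illustrative only.
mindrecord_schema = {"label": {"type": "int64"}}


def mindrecord_task_number():
    """Return 1: rows are generated serially by a single task."""
    return 1


def mindrecord_dict_data(task_id):
    """Yield one dict per data row; with one task, task_id is always 0."""
    for i in range(16):
        yield {"label": i}
```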

example/convert_to_mindrecord/imagenet/__init__.py  +0 −0


example/convert_to_mindrecord/imagenet/mr_api.py  +0 −122

@@ -1,122 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for the MindRecord writer.
Two APIs must be implemented:
1. mindrecord_task_number()
# Return the number of parallel tasks; return 1 if not parallel.
2. mindrecord_dict_data(task_id)
# Yield data for one task.
# task_id ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
"""
import argparse
import os
import pickle

######## mindrecord_schema begin ##########
mindrecord_schema = {"label": {"type": "int64"},
"data": {"type": "bytes"},
"file_name": {"type": "string"}}
######## mindrecord_schema end ##########

######## Frozen code begin ##########
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:
ARG_LIST = pickle.load(mindrecord_argument_file_handle)
######## Frozen code end ##########

parser = argparse.ArgumentParser(description='Mind record imagenet example')
parser.add_argument('--label_file', type=str, default="", help='label file')
parser.add_argument('--image_dir', type=str, default="", help='images directory')

######## Frozen code begin ##########
args = parser.parse_args(ARG_LIST)
print(args)
######## Frozen code end ##########


def _user_defined_private_func():
"""
Internal function for tasks list

Return:
tasks list
"""
if not os.path.exists(args.label_file):
raise IOError("map file {} does not exist".format(args.label_file))

label_dict = {}
with open(args.label_file) as file_handle:
line = file_handle.readline()
while line:
labels = line.split(" ")
label_dict[labels[1]] = labels[0]
line = file_handle.readline()
# get all the dirs, such as n02087046, n02094114, n02109525
dir_paths = {}
for item in label_dict:
real_path = os.path.join(args.image_dir, label_dict[item])
if not os.path.isdir(real_path):
print("{} dir is not exist".format(real_path))
continue
dir_paths[item] = real_path

if not dir_paths:
print("not valid image dir in {}".format(args.image_dir))
return {}, {}

dir_list = []
for label in dir_paths:
dir_list.append(label)
return dir_list, dir_paths


dir_list_global, dir_paths_global = _user_defined_private_func()

def mindrecord_task_number():
"""
Get task size.

Return:
number of tasks
"""
return len(dir_list_global)


def mindrecord_dict_data(task_id):
"""
Get data dict.

Yields:
data (dict): data row which is dict.
"""

# get the filename, label and image binary as a dict
label = dir_list_global[task_id]
for item in os.listdir(dir_paths_global[label]):
file_name = os.path.join(dir_paths_global[label], item)
if not item.endswith("JPEG") and not item.endswith(
"jpg") and not item.endswith("jpeg"):
print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name))
continue
data = {}
data["file_name"] = str(file_name)
data["label"] = int(label)

# get the image data
image_file = open(file_name, "rb")
image_bytes = image_file.read()
image_file.close()
data["data"] = image_bytes
yield data
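The generator above assigns one label directory per task. For the single-file case mentioned in the removed README, a modulo-based split might look like this hypothetical sketch (DATA_FILE and the "line" field are invented for illustration; the matching schema would be {"line": {"type": "string"}}):

```python
# Hypothetical sketch of the single-file split suggested in the README:
# task K yields only rows where (count % N == K).
NUM_TASKS = 4
DATA_FILE = "/tmp/data.txt"  # assumed input path


def mindrecord_task_number():
    return NUM_TASKS


def mindrecord_dict_data(task_id):
    with open(DATA_FILE) as f:
        for count, line in enumerate(f):
            if count % NUM_TASKS == task_id:
                yield {"line": line.rstrip("\n")}
```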

example/convert_to_mindrecord/run_imagenet.sh  +0 −8

@@ -1,8 +0,0 @@
#!/bin/bash
rm /tmp/imagenet/mr/*

python writer.py --mindrecord_script imagenet \
--mindrecord_file "/tmp/imagenet/mr/m" \
--mindrecord_partitions 16 \
--label_file "/tmp/imagenet/label.txt" \
--image_dir "/tmp/imagenet/jpeg"

example/convert_to_mindrecord/run_template.sh  +0 −6

@@ -1,6 +0,0 @@
#!/bin/bash
rm /tmp/template/*

python writer.py --mindrecord_script template \
--mindrecord_file "/tmp/template/m" \
--mindrecord_partitions 4

example/convert_to_mindrecord/template/__init__.py  +0 −0


example/convert_to_mindrecord/template/mr_api.py  +0 −73

@@ -1,73 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
User-defined API for the MindRecord writer.
Two APIs must be implemented:
1. mindrecord_task_number()
# Return the number of parallel tasks; return 1 if not parallel.
2. mindrecord_dict_data(task_id)
# Yield data for one task.
# task_id ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
"""
import argparse
import pickle

# ## Parse argument

with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line
ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line
parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line

# ## Your arguments below
# parser.add_argument(...)

args = parser.parse_args(ARG_LIST) # Do NOT change this line
print(args) # Do NOT change this line


# ## Default mindrecord vars. Leave them commented out unless a default value has to be changed.
# mindrecord_index_fields = ['label']
# mindrecord_header_size = 1 << 24
# mindrecord_page_size = 1 << 25


# define global vars here if necessary


# ####### Your code below ##########
mindrecord_schema = {"label": {"type": "int32"}}

def mindrecord_task_number():
"""
Get task size.

Return:
number of tasks
"""
return 1


def mindrecord_dict_data(task_id):
"""
Get data dict.

Yields:
data (dict): data row which is dict.
"""
print("task is {}".format(task_id))
for i in range(256):
data = {}
data['label'] = i
yield data
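The "frozen code" in both mr_api files exists because writer.py forwards unrecognized CLI flags through a pickle file. A minimal standalone sketch of that handoff (the file name matches the example; the --demo_flag argument is invented for illustration):

```python
# Minimal sketch of the pickle-based argument handoff used above.
import argparse
import pickle

# writer.py side: dump the unparsed extras...
parser = argparse.ArgumentParser()
parser.add_argument('--mindrecord_script', type=str, default="template")
args, other_args = parser.parse_known_args(['--demo_flag', '42'])
with open('mr_argument.pickle', 'wb') as f:
    pickle.dump(other_args, f)

# ...mr_api.py side: reload them and parse module-specific flags.
with open('mr_argument.pickle', 'rb') as f:
    ARG_LIST = pickle.load(f)
api_parser = argparse.ArgumentParser()
api_parser.add_argument('--demo_flag', type=int, default=0)  # invented flag
api_args = api_parser.parse_args(ARG_LIST)
print(api_args.demo_flag)  # -> 42
```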

example/convert_to_mindrecord/writer.py  +0 −149

@@ -1,149 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
######################## write mindrecord example ########################
Write mindrecord by data dictionary:
python writer.py --mindrecord_script /YourScriptPath ...
"""
import argparse
import os
import pickle
import time
from importlib import import_module
from multiprocessing import Pool

from mindspore.mindrecord import FileWriter


def _exec_task(task_id, parallel_writer=True):
"""
Execute task with specified task id
"""
print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
imagenet_iter = mindrecord_dict_data(task_id)
batch_size = 2048
transform_count = 0
while True:
data_list = []
try:
for _ in range(batch_size):
data_list.append(imagenet_iter.__next__())
transform_count += 1
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
print("transformed {} record...".format(transform_count))
except StopIteration:
if data_list:
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
print("transformed {} record...".format(transform_count))
break


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Mind record writer')
parser.add_argument('--mindrecord_script', type=str, default="template",
help='path where script is saved')

parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
help='written file name prefix')

parser.add_argument('--mindrecord_partitions', type=int, default=1,
help='number of written files')

parser.add_argument('--mindrecord_workers', type=int, default=8,
help='number of parallel workers')

args, other_args = parser.parse_known_args()

print(args)
print(other_args)

with open('mr_argument.pickle', 'wb') as file_handle:
pickle.dump(other_args, file_handle)

try:
mr_api = import_module(args.mindrecord_script + '.mr_api')
except ModuleNotFoundError:
raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))

num_tasks = mr_api.mindrecord_task_number()

print("Write mindrecord ...")

mindrecord_dict_data = mr_api.mindrecord_dict_data

# create the writer with the requested number of partitions
writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)

start_time = time.time()

# set the header size
try:
header_size = mr_api.mindrecord_header_size
writer.set_header_size(header_size)
except AttributeError:
print("Default header size: {}".format(1 << 24))

# set the page size
try:
page_size = mr_api.mindrecord_page_size
writer.set_page_size(page_size)
except AttributeError:
print("Default page size: {}".format(1 << 25))

# get schema
try:
mindrecord_schema = mr_api.mindrecord_schema
except AttributeError:
raise RuntimeError("mindrecord_schema is not defined in mr_api.py.")

# create the schema
writer.add_schema(mindrecord_schema, "mindrecord_schema")

# add the index
try:
index_fields = mr_api.mindrecord_index_fields
writer.add_index(index_fields)
except AttributeError:
print("Default index fields: all simple fields are indexes.")

writer.open_and_set_header()

task_list = list(range(num_tasks))

# set number of workers
num_workers = args.mindrecord_workers

if num_tasks < 1:
num_tasks = 1

if num_workers > num_tasks:
num_workers = num_tasks

if num_tasks > 1:
with Pool(num_workers) as p:
p.map(_exec_task, task_list)
else:
_exec_task(0, False)

ret = writer.commit()

os.remove("{}".format("mr_argument.pickle"))

end_time = time.time()
print("--------------------------------------------")
print("END. Total time: {}".format(end_time - start_time))
print("--------------------------------------------")

mindspore/ccsrc/mindrecord/common/shard_pybind.cc  +6 −3

@@ -75,9 +75,12 @@ void BindShardWriter(py::module *m) {
.def("set_header_size", &ShardWriter::set_header_size)
.def("set_page_size", &ShardWriter::set_page_size)
.def("set_shard_header", &ShardWriter::SetShardHeader)
.def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
vector<vector<uint8_t>> &, bool, bool)) &
ShardWriter::WriteRawData)
.def("write_raw_data",
(MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, vector<vector<uint8_t>> &, bool)) &
ShardWriter::WriteRawData)
.def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
std::map<uint64_t, std::vector<py::handle>> &, bool)) &
ShardWriter::WriteRawData)
.def("commit", &ShardWriter::Commit);
}



mindspore/ccsrc/mindrecord/include/shard_header.h  +0 −4

@@ -121,10 +121,6 @@ class ShardHeader {

std::vector<std::string> SerializeHeader();

MSRStatus PagesToFile(const std::string dump_file_name);

MSRStatus FileToPages(const std::string dump_file_name);

private:
MSRStatus InitializeHeader(const std::vector<json> &headers);



mindspore/ccsrc/mindrecord/include/shard_writer.h  +4 −33

@@ -18,7 +18,6 @@
#define MINDRECORD_INCLUDE_SHARD_WRITER_H_

#include <libgen.h>
#include <sys/file.h>
#include <unistd.h>
#include <algorithm>
#include <array>
@@ -88,7 +87,7 @@ class ShardWriter {
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);
bool sign = true);

/// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format
@@ -96,7 +95,7 @@ class ShardWriter {
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);
bool sign = true);

/// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format
@@ -104,8 +103,7 @@ class ShardWriter {
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true);

private:
/// \brief write shard header data to disk
@@ -203,34 +201,7 @@ class ShardWriter {
MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
std::map<int, std::string> &err_raw_data);

/// \brief Lock writer and save pages info
int LockWriter(bool parallel_writer = false);

/// \brief Unlock writer and save pages info
MSRStatus UnlockWriter(int fd, bool parallel_writer = false);

/// \brief Check raw data before writing
MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, int *schema_count, int *row_count);

/// \brief Get full path from file name
MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);

/// \brief Open files
MSRStatus OpenDataFiles(bool append);

/// \brief Remove lock file
MSRStatus RemoveLockFile();

/// \brief Remove lock file
MSRStatus InitLockFile();

private:
const std::string kLockFileSuffix = "_Locker";
const std::string kPageFileSuffix = "_Pages";
std::string lock_file_; // lock file for parallel run
std::string pages_file_; // temporary file of pages info for parallel run

int shard_count_; // number of files
uint64_t header_size_; // header size
uint64_t page_size_; // page size
@@ -240,7 +211,7 @@ class ShardWriter {
std::vector<uint64_t> raw_data_size_; // Raw data size
std::vector<uint64_t> blob_data_size_; // Blob data size

std::vector<std::string> file_paths_; // file paths
std::vector<string> file_paths_; // file paths
std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles
std::shared_ptr<ShardHeader> shard_header_; // shard headers



mindspore/ccsrc/mindrecord/io/shard_index_generator.cc  +0 −3

@@ -520,16 +520,13 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
for (int raw_page_id : raw_page_ids) {
auto sql = GenerateRawSQL(fields_);
if (sql.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw SQL failed";
return FAILED;
}
auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
if (data.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw data failed";
return FAILED;
}
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
MS_LOG(ERROR) << "Execute SQL failed";
return FAILED;
}
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";


mindspore/ccsrc/mindrecord/io/shard_writer.cc  +21 −167

@@ -40,7 +40,17 @@ ShardWriter::~ShardWriter() {
}
}

MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
shard_count_ = paths.size();
if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0.";
return FAILED;
}
if (schema_count_ > kMaxSchemaCount) {
MS_LOG(ERROR) << "The schema Count greater than max value.";
return FAILED;
}

// Get full path from file name
for (const auto &path : paths) {
if (!CheckIsValidUtf8(path)) {
@@ -50,7 +60,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p
char resolved_path[PATH_MAX] = {0};
char buf[PATH_MAX] = {0};
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
MS_LOG(ERROR) << "Secure func failed";
MS_LOG(ERROR) << "Securec func failed";
return FAILED;
}
#if defined(_WIN32) || defined(_WIN64)
@@ -72,10 +82,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p
#endif
file_paths_.emplace_back(string(resolved_path));
}
return SUCCESS;
}

MSRStatus ShardWriter::OpenDataFiles(bool append) {
// Open files
for (const auto &file : file_paths_) {
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
@@ -109,67 +116,6 @@ MSRStatus ShardWriter::OpenDataFiles(bool append) {
return SUCCESS;
}

MSRStatus ShardWriter::RemoveLockFile() {
// Remove temporary file
int ret = std::remove(pages_file_.c_str());
if (ret == 0) {
MS_LOG(DEBUG) << "Remove page file.";
}

ret = std::remove(lock_file_.c_str());
if (ret == 0) {
MS_LOG(DEBUG) << "Remove lock file.";
}
return SUCCESS;
}

MSRStatus ShardWriter::InitLockFile() {
if (file_paths_.size() == 0) {
MS_LOG(ERROR) << "File path not initialized.";
return FAILED;
}

lock_file_ = file_paths_[0] + kLockFileSuffix;
pages_file_ = file_paths_[0] + kPageFileSuffix;

if (RemoveLockFile() == FAILED) {
MS_LOG(ERROR) << "Remove file failed.";
return FAILED;
}
return SUCCESS;
}

MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
shard_count_ = paths.size();
if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0.";
return FAILED;
}
if (schema_count_ > kMaxSchemaCount) {
MS_LOG(ERROR) << "The schema Count greater than max value.";
return FAILED;
}

// Get full path from file name
if (GetFullPathFromFileName(paths) == FAILED) {
MS_LOG(ERROR) << "Get full path from file name failed.";
return FAILED;
}

// Open files
if (OpenDataFiles(append) == FAILED) {
MS_LOG(ERROR) << "Open data files failed.";
return FAILED;
}

// Init lock file
if (InitLockFile() == FAILED) {
MS_LOG(ERROR) << "Init lock file failed.";
return FAILED;
}
return SUCCESS;
}

MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (!IsLegalFile(path)) {
return FAILED;
@@ -197,28 +143,11 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
}

MSRStatus ShardWriter::Commit() {
// Read pages file
std::ifstream page_file(pages_file_.c_str());
if (page_file.good()) {
page_file.close();
if (shard_header_->FileToPages(pages_file_) == FAILED) {
MS_LOG(ERROR) << "Read pages from file failed";
return FAILED;
}
}

if (WriteShardHeader() == FAILED) {
MS_LOG(ERROR) << "Write metadata failed";
return FAILED;
}
MS_LOG(INFO) << "Write metadata successfully.";

// Remove lock file
if (RemoveLockFile() == FAILED) {
MS_LOG(ERROR) << "Remove lock file failed.";
return FAILED;
}

return SUCCESS;
}

@@ -526,65 +455,15 @@ void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>>
}
}

int ShardWriter::LockWriter(bool parallel_writer) {
if (!parallel_writer) {
return 0;
}
const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
if (fd >= 0) {
flock(fd, LOCK_EX);
} else {
MS_LOG(ERROR) << "Shard writer failed when locking file";
return -1;
}

// Open files
file_streams_.clear();
for (const auto &file : file_paths_) {
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary);
if (fs->fail()) {
MS_LOG(ERROR) << "File could not opened";
return -1;
}
file_streams_.push_back(fs);
}

if (shard_header_->FileToPages(pages_file_) == FAILED) {
MS_LOG(ERROR) << "Read pages from file failed";
return -1;
}
return fd;
}

MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
if (!parallel_writer) {
return SUCCESS;
}

if (shard_header_->PagesToFile(pages_file_) == FAILED) {
MS_LOG(ERROR) << "Write pages to file failed";
return FAILED;
}

for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
file_streams_[i]->close();
}

flock(fd, LOCK_UN);
close(fd);
return SUCCESS;
}

MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count,
int *row_count) {
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign) {
// check the free disk size
auto st_space = GetDiskSize(file_paths_[0], kFreeSize);
if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) {
MS_LOG(ERROR) << "IO error / there is no free disk to be used";
return FAILED;
}

// Add 4-bytes dummy blob data if no any blob fields
if (blob_data.size() == 0 && raw_data.size() > 0) {
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
@@ -600,29 +479,10 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG(ERROR) << "Validate raw data failed";
return FAILED;
}
*schema_count = std::get<1>(v);
*row_count = std::get<2>(v);
return SUCCESS;
}

MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
// Lock Writer if loading data parallel
int fd = LockWriter(parallel_writer);
if (fd < 0) {
MS_LOG(ERROR) << "Lock writer failed";
return FAILED;
}

// Get the count of schemas and rows
int schema_count = 0;
int row_count = 0;

// Serialize raw data
if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) {
MS_LOG(ERROR) << "Check raw data failed";
return FAILED;
}
int schema_count = std::get<1>(v);
int row_count = std::get<2>(v);

if (row_count == kInt0) {
MS_LOG(INFO) << "Raw data size is 0.";
@@ -656,17 +516,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_d
}
MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully.";

if (UnlockWriter(fd, parallel_writer) == FAILED) {
MS_LOG(ERROR) << "Unlock writer failed";
return FAILED;
}

return SUCCESS;
}

MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
bool parallel_writer) {
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign) {
std::map<uint64_t, std::vector<json>> raw_data_json;
std::map<uint64_t, std::vector<json>> blob_data_json;

@@ -700,11 +554,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
MS_LOG(ERROR) << "Serialize raw data failed in write raw data";
return FAILED;
}
return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
return WriteRawData(raw_data_json, bin_blob_data, sign);
}

MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
vector<vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
vector<vector<uint8_t>> &blob_data, bool sign) {
std::map<uint64_t, std::vector<json>> raw_data_json;
(void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
[](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
@@ -714,7 +568,7 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
[](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
return std::make_pair(pair.first, std::move(json_raw_data));
});
return WriteRawData(raw_data_json, blob_data, sign, parallel_writer);
return WriteRawData(raw_data_json, blob_data, sign);
}

MSRStatus ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,


mindspore/ccsrc/mindrecord/meta/shard_header.cc  +0 −38

@@ -677,43 +677,5 @@ std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(
}
return std::make_pair(statistics_.at(statistic_id), SUCCESS);
}

MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) {
// write header content to file, dump whatever is in the file before
std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out);
if (page_out_handle.fail()) {
MS_LOG(ERROR) << "Failed in opening page file";
return FAILED;
}

auto pages = SerializePage();
for (const auto &shard_pages : pages) {
page_out_handle << shard_pages << "\n";
}

page_out_handle.close();
return SUCCESS;
}

MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
for (auto &v : pages_) { // clean pages
v.clear();
}
// attempt to open the file contains the page in json
std::ifstream page_in_handle(dump_file_name.c_str());

if (!page_in_handle.good()) {
MS_LOG(INFO) << "No page file exists.";
return SUCCESS;
}

std::string line;
while (std::getline(page_in_handle, line)) {
ParsePage(json::parse(line));
}

page_in_handle.close();
return SUCCESS;
}
} // namespace mindrecord
} // namespace mindspore
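The removed PagesToFile/FileToPages pair serialized page metadata as one JSON document per line. An equivalent round trip sketched in Python (the C++ code used SerializePage()/ParsePage() on each line; this sketch assumes plain dicts):

```python
import json


def pages_to_file(pages, dump_file_name):
    # Truncate and rewrite: one JSON document per line, as PagesToFile() did.
    with open(dump_file_name, "w") as f:
        for page in pages:
            f.write(json.dumps(page) + "\n")


def file_to_pages(dump_file_name):
    # A missing file is not an error, matching FileToPages().
    try:
        with open(dump_file_name) as f:
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        return []
```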

mindspore/mindrecord/filewriter.py  +2 −13

@@ -200,24 +200,13 @@ class FileWriter:
raw_data.pop(i)
logger.warning(v)

def open_and_set_header(self):
"""
Open writer and set header

"""
if not self._writer.is_open:
self._writer.open(self._paths)
if not self._writer.get_shard_header():
self._writer.set_shard_header(self._header)

def write_raw_data(self, raw_data, parallel_writer=False):
def write_raw_data(self, raw_data):
"""
Write raw data and generate sequential pair of MindRecord File and \
validate data based on predefined schema by default.

Args:
raw_data (list[dict]): List of raw data.
parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).

Raises:
ParamTypeError: If index field is invalid.
@@ -236,7 +225,7 @@ class FileWriter:
if not isinstance(each_raw, dict):
raise ParamTypeError('raw_data item', 'dict')
self._verify_based_on_schema(raw_data)
return self._writer.write_raw_data(raw_data, True, parallel_writer)
return self._writer.write_raw_data(raw_data, True)

def set_header_size(self, header_size):
"""


mindspore/mindrecord/shardwriter.py  +2 −3

@@ -135,7 +135,7 @@ class ShardWriter:
def get_shard_header(self):
return self._header

def write_raw_data(self, data, validate=True, parallel_writer=False):
def write_raw_data(self, data, validate=True):
"""
Write raw data of cv dataset.

@@ -145,7 +145,6 @@ class ShardWriter:
Args:
data (list[dict]): List of raw data.
validate (bool, optional): verify data according schema if it equals to True.
parallel_writer (bool, optional): Load data parallel if it equals to True.

Returns:
MSRStatus, SUCCESS or FAILED.
@@ -166,7 +165,7 @@ class ShardWriter:
if row_raw:
raw_data.append(row_raw)
raw_data = {0: raw_data} if raw_data else {}
ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer)
ret = self._writer.write_raw_data(raw_data, blob_data, validate)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to write dataset.")
raise MRMWriteDatasetError
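With the revert applied, write_raw_data on FileWriter (and the underlying ShardWriter) no longer accepts parallel_writer, and open_and_set_header is gone. A minimal usage sketch, assuming the v0.2.0-alpha FileWriter API (the file path and schema are illustrative):

```python
from mindspore.mindrecord import FileWriter

# Post-revert usage sketch: no open_and_set_header(), no parallel_writer flag.
writer = FileWriter(file_name="/tmp/example.mindrecord", shard_num=1)
writer.add_schema({"label": {"type": "int64"}}, "example_schema")
writer.write_raw_data([{"label": 0}, {"label": 1}])
writer.commit()
```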

