Browse Source

refactor mindrecord

tags/v1.5.0-rc1
liyong 4 years ago
parent
commit
c257cf36e2
51 changed files with 2020 additions and 2864 deletions
  1. +8
    -22
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  2. +21
    -38
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
  3. +6
    -9
      mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc
  4. +3
    -5
      mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc
  5. +0
    -83
      mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc
  6. +171
    -40
      mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
  7. +32
    -33
      mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc
  8. +13
    -8
      mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
  9. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_category.h
  10. +19
    -20
      mindspore/ccsrc/minddata/mindrecord/include/shard_column.h
  11. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
  12. +41
    -51
      mindspore/ccsrc/minddata/mindrecord/include/shard_error.h
  13. +33
    -33
      mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
  14. +28
    -28
      mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h
  15. +16
    -21
      mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h
  16. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h
  17. +62
    -80
      mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
  18. +3
    -3
      mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
  19. +0
    -9
      mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h
  20. +19
    -20
      mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h
  21. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h
  22. +4
    -4
      mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h
  23. +0
    -9
      mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h
  24. +61
    -71
      mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h
  25. +204
    -329
      mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc
  26. +374
    -549
      mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
  27. +122
    -181
      mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc
  28. +272
    -496
      mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
  29. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc
  30. +74
    -100
      mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
  31. +5
    -11
      mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
  32. +166
    -316
      mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
  33. +3
    -5
      mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc
  34. +10
    -17
      mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
  35. +0
    -12
      mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc
  36. +4
    -5
      mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc
  37. +14
    -29
      mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
  38. +0
    -18
      mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc
  39. +2
    -8
      mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc
  40. +1
    -1
      mindspore/mindrecord/shardreader.py
  41. +4
    -26
      mindspore/mindrecord/shardsegment.py
  42. +2
    -2
      tests/ut/cpp/mindrecord/ut_common.cc
  43. +8
    -3
      tests/ut/cpp/mindrecord/ut_shard.cc
  44. +19
    -18
      tests/ut/cpp/mindrecord/ut_shard_header_test.cc
  45. +8
    -8
      tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
  46. +82
    -45
      tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
  47. +49
    -29
      tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
  48. +5
    -5
      tests/ut/python/dataset/test_minddataset_exception.py
  49. +4
    -5
      tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py
  50. +5
    -5
      tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py
  51. +38
    -49
      tests/ut/python/mindrecord/test_mindrecord_exception.py

+ 8
- 22
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -256,9 +256,7 @@ Status SaveToDisk::Save() {
auto mr_header = std::make_shared<mindrecord::ShardHeader>();
auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
std::vector<std::string> blob_fields;
if (mindrecord::SUCCESS != mindrecord::ShardWriter::Initialize(&mr_writer, file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter, please check above `ERROR` level message.");
}
RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names));

std::unordered_map<std::string, int32_t> column_name_id_map;
for (auto el : tree_adapter_->GetColumnNameMap()) {
@@ -286,22 +284,16 @@ Status SaveToDisk::Save() {
std::vector<std::string> index_fields;
RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields));
MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump();
if (mindrecord::SUCCESS !=
mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
}
if (mindrecord::SUCCESS != mr_writer->SetShardHeader(mr_header)) {
RETURN_STATUS_UNEXPECTED("Error: failed to set header of ShardWriter.");
}
RETURN_IF_NOT_OK(
mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id));
RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header));
first_loop = false;
}
// construct data
if (!row.empty()) { // write data
RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data));
std::shared_ptr<std::vector<uint8_t>> output_bin_data;
if (mindrecord::SUCCESS != mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)) {
RETURN_STATUS_UNEXPECTED("Error: failed to merge blob data of ShardWriter.");
}
RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data));
std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data;
raw_data.insert(
std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data}));
@@ -309,18 +301,12 @@ Status SaveToDisk::Save() {
if (output_bin_data != nullptr) {
bin_data.emplace_back(*output_bin_data);
}
if (mindrecord::SUCCESS != mr_writer->WriteRawData(raw_data, bin_data)) {
RETURN_STATUS_UNEXPECTED("Error: failed to write raw data to ShardWriter.");
}
RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data));
}
} while (!row.empty());

if (mindrecord::SUCCESS != mr_writer->Commit()) {
RETURN_STATUS_UNEXPECTED("Error: failed to commit ShardWriter.");
}
if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::Finalize(file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
}
RETURN_IF_NOT_OK(mr_writer->Commit());
RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names));
return Status::OK();
}



+ 21
- 38
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc View File

@@ -23,6 +23,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/include/dataset/constants.h"
@@ -63,10 +64,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str

// Private helper method to encapsulate some common construction/reset tasks
Status MindRecordOp::Init() {
auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
num_padded_);

CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed, " + ErrnoToMessage(rc));
RETURN_IF_NOT_OK(shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_,
operators_, num_padded_));

data_schema_ = std::make_unique<DataSchema>();

@@ -206,7 +205,9 @@ Status MindRecordOp::GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, i
fetched_row->setPath(file_path);
fetched_row->setId(row_id);
}
if (tupled_buffer.empty()) return Status::OK();
if (tupled_buffer.empty()) {
return Status::OK();
}
if (task_type == mindrecord::TaskType::kCommonTask) {
for (const auto &tupled_row : tupled_buffer) {
std::vector<uint8_t> columns_blob = std::get<0>(tupled_row);
@@ -237,20 +238,15 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
// Get column data
auto shard_column = shard_reader_->GetShardColumn();
if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) {
auto rc =
shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape);
if (rc.first != MSRStatus::SUCCESS) {
RETURN_STATUS_UNEXPECTED("Invalid parameter, column_name: " + column_name + "does not exist in dataset.");
}
if (rc.second == mindrecord::ColumnInRaw) {
auto column_in_raw = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes);
if (column_in_raw == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve raw data from padding sample.");
}
} else if (rc.second == mindrecord::ColumnInBlob) {
if (sample_bytes_.find(column_name) == sample_bytes_.end()) {
RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve blob data from padding sample.");
}
mindrecord::ColumnCategory category;
RETURN_IF_NOT_OK(shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size,
&column_shape, &category));
if (category == mindrecord::ColumnInRaw) {
RETURN_IF_NOT_OK(shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes));
} else if (category == mindrecord::ColumnInBlob) {
CHECK_FAIL_RETURN_UNEXPECTED(sample_bytes_.find(column_name) != sample_bytes_.end(),
"Invalid data, failed to retrieve blob data from padding sample.");

std::string ss(sample_bytes_[column_name]);
n_bytes = ss.size();
data_ptr = std::make_unique<unsigned char[]>(n_bytes);
@@ -262,12 +258,9 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
data = reinterpret_cast<const unsigned char *>(data_ptr.get());
}
} else {
auto has_column =
shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes,
&column_data_type, &column_data_type_size, &column_shape);
if (has_column == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve data from mindrecord reader.");
}
RETURN_IF_NOT_OK(shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr,
&n_bytes, &column_data_type, &column_data_type_size,
&column_shape));
}

std::shared_ptr<Tensor> tensor;
@@ -309,15 +302,10 @@ Status MindRecordOp::Reset() {
}

Status MindRecordOp::LaunchThreadsAndInitOp() {
if (tree_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
}

RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
if (shard_reader_->Launch(true) == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed.");
}
RETURN_IF_NOT_OK(shard_reader_->Launch(true));
// Launch main workers that load TensorRows by reading all images
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id()));
@@ -330,12 +318,7 @@ Status MindRecordOp::LaunchThreadsAndInitOp() {
Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded) {
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded);
if (rc == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, MindRecordOp failed to count total rows. Check whether there are corresponding .db files "
"and the value of dataset_file parameter is given correctly.");
}
RETURN_IF_NOT_OK(shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded));
return Status::OK();
}



+ 6
- 9
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc View File

@@ -38,9 +38,8 @@ Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape));
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
@@ -57,9 +56,8 @@ Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, con
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape));
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor));
@@ -81,9 +79,8 @@ Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::v
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape));

if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);



+ 3
- 5
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc View File

@@ -94,10 +94,9 @@ Status GraphLoader::InitAndLoad() {
TaskGroup vg;

shard_reader_ = std::make_unique<ShardReader>();
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
"Fail to open" + mr_path_);
RETURN_IF_NOT_OK(shard_reader_->Open({mr_path_}, true, num_workers_));
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
RETURN_IF_NOT_OK(shard_reader_->Launch(true));

graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
mindrecord::json schema = graph_impl_->data_schema_["schema"];
@@ -116,8 +115,7 @@ Status GraphLoader::InitAndLoad() {
if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64)
int64_t total_blob_size = 0;
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS,
"failed to get total blob size");
RETURN_IF_NOT_OK(shard_reader_->GetTotalBlobSize(&total_blob_size));
graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_);
RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory());
#endif


+ 0
- 83
mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc View File

@@ -1,83 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/mindrecord/include/shard_error.h"

namespace mindspore {
namespace mindrecord {
// Maps MSRStatus codes to the human-readable message that ErrnoToMessage() returns for them.
// Codes without an entry here are reported by ErrnoToMessage() as "invalid error no".
static const std::map<MSRStatus, std::string> kErrnoToMessage = {
{FAILED, "operator failed"},
{SUCCESS, "operator success"},
{OPEN_FILE_FAILED, "open file failed"},
{CLOSE_FILE_FAILED, "close file failed"},
{WRITE_METADATA_FAILED, "write metadata failed"},
{WRITE_RAWDATA_FAILED, "write rawdata failed"},
{GET_SCHEMA_FAILED, "get schema failed"},
{ILLEGAL_RAWDATA, "illegal raw data"},
{PYTHON_TO_JSON_FAILED, "pybind: python object to json failed"},
{DIR_CREATE_FAILED, "directory create failed"},
{OPEN_DIR_FAILED, "open directory failed"},
{INVALID_STATISTICS, "invalid statistics object"},
{OPEN_DATABASE_FAILED, "open database failed"},
{CLOSE_DATABASE_FAILED, "close database failed"},
{DATABASE_OPERATE_FAILED, "database operate failed"},
{BUILD_SCHEMA_FAILED, "build schema failed"},
{DIVISOR_IS_ILLEGAL, "divisor is illegal"},
{INVALID_FILE_PATH, "file path is invalid"},
{SECURE_FUNC_FAILED, "secure function failed"},
{ALLOCATE_MEM_FAILED, "allocate memory failed"},
{ILLEGAL_FIELD_NAME, "illegal field name"},
{ILLEGAL_FIELD_TYPE, "illegal field type"},
{SET_METADATA_FAILED, "set metadata failed"},
{ILLEGAL_SCHEMA_DEFINITION, "illegal schema definition"},
{ILLEGAL_COLUMN_LIST, "illegal column list"},
{SQL_ERROR, "sql error"},
{ILLEGAL_SHARD_COUNT, "illegal shard count"},
{ILLEGAL_SCHEMA_COUNT, "illegal schema count"},
{VERSION_ERROR, "data version is not matched"},
{ADD_SCHEMA_FAILED, "add schema failed"},
{ILLEGAL_Header_SIZE, "illegal header size"},
{ILLEGAL_Page_SIZE, "illegal page size"},
{ILLEGAL_SIZE_VALUE, "illegal size value"},
{INDEX_FIELD_ERROR, "add index fields failed"},
{GET_CANDIDATE_CATEGORYFIELDS_FAILED, "get candidate category fields failed"},
{GET_CATEGORY_INFO_FAILED, "get category information failed"},
{ILLEGAL_CATEGORY_ID, "illegal category id"},
{ILLEGAL_ROWNUMBER_OF_PAGE, "illegal row number of page"},
{ILLEGAL_SCHEMA_ID, "illegal schema id"},
{DESERIALIZE_SCHEMA_FAILED, "deserialize schema failed"},
{DESERIALIZE_STATISTICS_FAILED, "deserialize statistics failed"},
{ILLEGAL_DB_FILE, "illegal db file"},
{OVERWRITE_DB_FILE, "overwrite db file"},
{OVERWRITE_MINDRECORD_FILE, "overwrite mindrecord file"},
{ILLEGAL_MINDRECORD_FILE, "illegal mindrecord file"},
{PARSE_JSON_FAILED, "parse json failed"},
{ILLEGAL_PARAMETERS, "illegal parameters"},
{GET_PAGE_BY_GROUP_ID_FAILED, "get page by group id failed"},
{GET_SYSTEM_STATE_FAILED, "get system state failed"},
{IO_FAILED, "io operate failed"},
{MATCH_HEADER_FAILED, "match header failed"}};

// Translate a mindrecord status code into its human-readable message.
//
// @param status Status code to look up in kErrnoToMessage.
// @return The message registered for `status`, or "invalid error no" when the
//         code has no entry in the table.
std::string ErrnoToMessage(MSRStatus status) {
  // Reuse the iterator returned by find() instead of searching the map a
  // second time with at().
  const auto iter = kErrnoToMessage.find(status);
  if (iter == kErrnoToMessage.end()) {
    return "invalid error no";
  }
  return iter->second;
}
} // namespace mindrecord
} // namespace mindspore

+ 171
- 40
mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc View File

@@ -36,20 +36,42 @@ using mindspore::MsLogLevel::ERROR;

namespace mindspore {
namespace mindrecord {
// Evaluates `s` into a local Status and, if that status reports an error, throws
// std::runtime_error carrying the status's ToString() text.
// The do { ... } while (false) wrapper makes the expansion behave as a single statement.
// NOTE(review): presumably this lets the pybind lambdas below surface failures to Python
// as exceptions instead of return codes — confirm against the Python callers.
#define THROW_IF_ERROR(s) \
do { \
Status rc = std::move(s); \
if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
} while (false)

void BindSchema(py::module *m) {
(void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
.def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build)
.def_static("build",
[](const std::string &desc, const pybind11::handle &schema) {
json schema_json = nlohmann::detail::ToJsonImpl(schema);
return Schema::Build(std::move(desc), schema_json);
})
.def("get_desc", &Schema::GetDesc)
.def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython)
.def("get_schema_content",
[](Schema &s) {
json schema_json = s.GetSchema();
return nlohmann::detail::FromJsonImpl(schema_json);
})
.def("get_blob_fields", &Schema::GetBlobFields)
.def("get_schema_id", &Schema::GetSchemaID);
}

void BindStatistics(const py::module *m) {
(void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
.def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build)
.def_static("build",
[](const std::string desc, const pybind11::handle &statistics) {
json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
return Statistics::Build(std::move(desc), statistics_json);
})
.def("get_desc", &Statistics::GetDesc)
.def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython)
.def("get_statistics",
[](Statistics &s) {
json statistics_json = s.GetStatistics();
return nlohmann::detail::FromJsonImpl(statistics_json);
})
.def("get_statistics_id", &Statistics::GetStatisticsID);
}

@@ -59,70 +81,179 @@ void BindShardHeader(const py::module *m) {
.def("add_schema", &ShardHeader::AddSchema)
.def("add_statistics", &ShardHeader::AddStatistic)
.def("add_index_fields",
(MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields)
[](ShardHeader &s, const std::vector<std::string> &fields) {
THROW_IF_ERROR(s.AddIndexFields(fields));
return SUCCESS;
})
.def("get_meta", &ShardHeader::GetSchemas)
.def("get_statistics", &ShardHeader::GetStatistics)
.def("get_fields", &ShardHeader::GetFields)
.def("get_schema_by_id", &ShardHeader::GetSchemaByID)
.def("get_statistic_by_id", &ShardHeader::GetStatisticByID);
.def("get_schema_by_id",
[](ShardHeader &s, int64_t schema_id) {
std::shared_ptr<Schema> schema_ptr;
THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr));
return schema_ptr;
})
.def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) {
std::shared_ptr<Statistics> statistics_ptr;
THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr));
return statistics_ptr;
});
}

void BindShardWriter(py::module *m) {
(void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
.def(py::init<>())
.def("open", &ShardWriter::Open)
.def("open_for_append", &ShardWriter::OpenForAppend)
.def("set_header_size", &ShardWriter::SetHeaderSize)
.def("set_page_size", &ShardWriter::SetPageSize)
.def("set_shard_header", &ShardWriter::SetShardHeader)
.def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
vector<vector<uint8_t>> &, bool, bool)) &
ShardWriter::WriteRawData)
.def("commit", &ShardWriter::Commit);
.def("open",
[](ShardWriter &s, const std::vector<std::string> &paths, bool append) {
THROW_IF_ERROR(s.Open(paths, append));
return SUCCESS;
})
.def("open_for_append",
[](ShardWriter &s, const std::string &path) {
THROW_IF_ERROR(s.OpenForAppend(path));
return SUCCESS;
})
.def("set_header_size",
[](ShardWriter &s, const uint64_t &header_size) {
THROW_IF_ERROR(s.SetHeaderSize(header_size));
return SUCCESS;
})
.def("set_page_size",
[](ShardWriter &s, const uint64_t &page_size) {
THROW_IF_ERROR(s.SetPageSize(page_size));
return SUCCESS;
})
.def("set_shard_header",
[](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) {
THROW_IF_ERROR(s.SetShardHeader(header_data));
return SUCCESS;
})
.def("write_raw_data",
[](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, bool parallel_writer) {
std::map<uint64_t, std::vector<json>> raw_data_json;
(void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
[](const std::pair<uint64_t, std::vector<py::handle>> &p) {
auto &py_raw_data = p.second;
std::vector<json> json_raw_data;
(void)std::transform(
py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
[](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
return std::make_pair(p.first, std::move(json_raw_data));
});
THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer));
return SUCCESS;
})
.def("commit", [](ShardWriter &s) {
THROW_IF_ERROR(s.Commit());
return SUCCESS;
});
}

void BindShardReader(const py::module *m) {
(void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
.def(py::init<>())
.def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &,
const std::vector<std::string> &,
const std::vector<std::shared_ptr<ShardOperator>> &)) &
ShardReader::OpenPy)
.def("launch", &ShardReader::Launch)
.def("open",
[](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
const std::vector<std::string> &selected_columns,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
return SUCCESS;
})
.def("launch",
[](ShardReader &s) {
THROW_IF_ERROR(s.Launch(false));
return SUCCESS;
})
.def("get_header", &ShardReader::GetShardHeader)
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy)
.def("get_next",
[](ShardReader &s) {
auto data = s.GetNext();
vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> res;
std::transform(data.begin(), data.end(), std::back_inserter(res),
[&s](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
auto blob_data_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
(void)s.UnCompressBlob(std::get<0>(item), &blob_data_ptr);
return std::make_tuple(*blob_data_ptr, std::move(obj));
});
return res;
})
.def("close", &ShardReader::Close);
}

void BindShardIndexGenerator(const py::module *m) {
(void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local())
.def(py::init<const std::string &, bool>())
.def("build", &ShardIndexGenerator::Build)
.def("write_to_db", &ShardIndexGenerator::WriteToDatabase);
.def("build",
[](ShardIndexGenerator &s) {
THROW_IF_ERROR(s.Build());
return SUCCESS;
})
.def("write_to_db", [](ShardIndexGenerator &s) {
THROW_IF_ERROR(s.WriteToDatabase());
return SUCCESS;
});
}

void BindShardSegment(py::module *m) {
(void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
.def(py::init<>())
.def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &,
const std::vector<std::string> &,
const std::vector<std::shared_ptr<ShardOperator>> &)) &
ShardSegment::OpenPy)
.def("open",
[](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
const std::vector<std::string> &selected_columns,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
return SUCCESS;
})
.def("get_category_fields",
(std::pair<MSRStatus, vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetCategoryFields)
.def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField)
.def("read_category_info", (std::pair<MSRStatus, std::string>(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo)
.def("read_at_page_by_id", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>(
ShardSegment::*)(int64_t, int64_t, int64_t)) &
ShardSegment::ReadAtPageByIdPy)
.def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>(
ShardSegment::*)(std::string, int64_t, int64_t)) &
ShardSegment::ReadAtPageByNamePy)
[](ShardSegment &s) {
auto fields_ptr = std::make_shared<vector<std::string>>();
THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr));
return *fields_ptr;
})
.def("set_category_field",
[](ShardSegment &s, const std::string &category_field) {
THROW_IF_ERROR(s.SetCategoryField(category_field));
return SUCCESS;
})
.def("read_category_info",
[](ShardSegment &s) {
std::shared_ptr<std::string> category_ptr;
THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr));
return *category_ptr;
})
.def("read_at_page_by_id",
[](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
auto pages_load_ptr = std::make_shared<PAGES_LOAD>();
auto pages_ptr = std::make_shared<PAGES>();
THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr));
(void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
});
return *pages_load_ptr;
})
.def("read_at_page_by_name",
[](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
auto pages_load_ptr = std::make_shared<PAGES_LOAD>();
auto pages_ptr = std::make_shared<PAGES>();
THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr));
(void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
});
return *pages_load_ptr;
})
.def("get_header", &ShardSegment::GetShardHeader)
.def("get_blob_fields",
(std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetBlobFields);
.def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); });
}

void BindGlobalParams(py::module *m) {


+ 32
- 33
mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc View File

@@ -57,26 +57,24 @@ bool ValidateFieldName(const std::string &str) {
return true;
}

std::pair<MSRStatus, std::string> GetFileName(const std::string &path) {
Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr) {
RETURN_UNEXPECTED_IF_NULL(fn_ptr);
char real_path[PATH_MAX] = {0};
char buf[PATH_MAX] = {0};
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path);
}
char tmp[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf));
}
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully";
}
#else
if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf);
}
if (realpath(common::SafeCStr(path), real_path) == nullptr) {
MS_LOG(DEBUG) << "Path: " << path << "check successfully";
@@ -87,32 +85,32 @@ std::pair<MSRStatus, std::string> GetFileName(const std::string &path) {
size_t i = s.rfind(sep, s.length());
if (i != std::string::npos) {
if (i + 1 < s.size()) {
return {SUCCESS, s.substr(i + 1)};
*fn_ptr = std::make_shared<std::string>(s.substr(i + 1));
return Status::OK();
}
}
return {SUCCESS, s};
*fn_ptr = std::make_shared<std::string>(s);
return Status::OK();
}

std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) {
Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr) {
RETURN_UNEXPECTED_IF_NULL(pd_ptr);
char real_path[PATH_MAX] = {0};
char buf[PATH_MAX] = {0};
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path);
}
char tmp[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf));
}
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully";
}
#else
if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf);
}
if (realpath(common::SafeCStr(path), real_path) == nullptr) {
MS_LOG(DEBUG) << "Path: " << path << "check successfully";
@@ -120,9 +118,11 @@ std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) {
#endif
std::string s = real_path;
if (s.rfind('/') + 1 <= s.size()) {
return {SUCCESS, s.substr(0, s.rfind('/') + 1)};
*pd_ptr = std::make_shared<std::string>(s.substr(0, s.rfind('/') + 1));
return Status::OK();
}
return {SUCCESS, "/"};
*pd_ptr = std::make_shared<std::string>("/");
return Status::OK();
}

bool CheckIsValidUtf8(const std::string &str) {
@@ -163,15 +163,16 @@ bool IsLegalFile(const std::string &path) {
return false;
}

std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) {
Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size_ptr) {
RETURN_UNEXPECTED_IF_NULL(size_ptr);
#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__)
return {SUCCESS, 100};
*size_ptr = std::make_shared<uint64_t>(100);
return Status::OK();
#else
uint64_t ll_count = 0;
struct statfs disk_info;
if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) {
MS_LOG(ERROR) << "Get disk size error";
return {FAILED, 0};
RETURN_STATUS_UNEXPECTED("Get disk size error.");
}

switch (disk_type) {
@@ -187,8 +188,8 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis
ll_count = 0;
break;
}
return {SUCCESS, ll_count};
*size_ptr = std::make_shared<uint64_t>(ll_count);
return Status::OK();
#endif
}

@@ -201,17 +202,15 @@ uint32_t GetMaxThreadNum() {
return thread_num;
}

std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses) {
auto ret = GetParentDir(path);
if (SUCCESS != ret.first) {
return {FAILED, {}};
}
std::vector<std::string> abs_addresses;
Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds) {
RETURN_UNEXPECTED_IF_NULL(ds);
std::shared_ptr<std::string> parent_dir;
RETURN_IF_NOT_OK(GetParentDir(path, &parent_dir));
for (const auto &p : addresses) {
std::string abs_path = ret.second + std::string(p);
abs_addresses.emplace_back(abs_path);
std::string abs_path = *parent_dir + std::string(p);
(*ds)->emplace_back(abs_path);
}
return {SUCCESS, abs_addresses};
return Status::OK();
}
} // namespace mindrecord
} // namespace mindspore

+ 13
- 8
mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h View File

@@ -33,6 +33,7 @@
#include <future>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <set>
#include <sstream>
@@ -159,13 +160,15 @@ bool ValidateFieldName(const std::string &str);

/// \brief get the filename by the path
/// \param s file path
/// \return
std::pair<MSRStatus, std::string> GetFileName(const std::string &s);
/// \param fn_ptr shared ptr of file name
/// \return Status
Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr);

/// \brief get parent dir
/// \param path file path
/// \return parent path
std::pair<MSRStatus, std::string> GetParentDir(const std::string &path);
/// \param pd_ptr shared ptr of parent path
/// \return Status
Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr);

bool CheckIsValidUtf8(const std::string &str);

@@ -179,8 +182,9 @@ enum DiskSizeType { kTotalSize = 0, kFreeSize };
/// \brief get the free space about the disk
/// \param str_dir file path
/// \param disk_type: kTotalSize / kFreeSize
/// \return size in Megabytes
std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type);
/// \param size: shared ptr of size in Megabytes
/// \return Status
Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size);

/// \brief get the max hardware concurrency
/// \return max concurrency
@@ -189,8 +193,9 @@ uint32_t GetMaxThreadNum();
/// \brief get absolute path of all mindrecord files
/// \param path path to one of the mindrecord files
/// \param addresses relative path of all mindrecord files
/// \return vector of absolute path
std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses);
/// \param ds shared ptr of vector of absolute path
/// \return Status
Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds);
} // namespace mindrecord
} // namespace mindspore



+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_category.h View File

@@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato

bool GetReplacement() const { return replacement_; }

MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;

int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;



+ 19
- 20
mindspore/ccsrc/minddata/mindrecord/include/shard_column.h View File

@@ -65,11 +65,11 @@ class __attribute__((visibility("default"))) ShardColumn {
~ShardColumn() = default;

/// \brief get column value by column name
MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
Status GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);

/// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size);
@@ -90,19 +90,18 @@ class __attribute__((visibility("default"))) ShardColumn {
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }

/// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes);
Status GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes);

/// \brief get column type
std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name,
ColumnDataType *column_data_type,
uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
Status GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type,
uint64_t *column_data_type_size, std::vector<int64_t> *column_shape,
ColumnCategory *column_category);

/// \brief get column value from json
MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
Status GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);

private:
/// \brief initialization
@@ -110,15 +109,15 @@ class __attribute__((visibility("default"))) ShardColumn {

/// \brief get float value from json
template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
Status GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);

/// \brief get integer value from json
template <typename T>
MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);
Status GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);

/// \brief get column offset address and size from blob
MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);
Status GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);

/// \brief check if column name is available
ColumnCategory CheckColumnName(const std::string &column_name);
@@ -128,8 +127,8 @@ class __attribute__((visibility("default"))) ShardColumn {

/// \brief uncompress integer array column
template <typename T>
static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);
static Status UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);

/// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h View File

@@ -39,7 +39,7 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha

~ShardDistributedSample() override{};

MSRStatus PreExecute(ShardTaskList &tasks) override;
Status PreExecute(ShardTaskList &tasks) override;

int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;



+ 41
- 51
mindspore/ccsrc/minddata/mindrecord/include/shard_error.h View File

@@ -19,65 +19,55 @@

#include <map>
#include <string>
#include "include/api/status.h"

namespace mindspore {
namespace mindrecord {
// Evaluate the Status expression _s once; if it reports an error, return that Status
// from the enclosing function, otherwise fall through to the next statement.
#define RETURN_IF_NOT_OK(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
return __rc; \
} \
} while (false)

// Like RETURN_IF_NOT_OK, but releases resources before bailing out: when _s reports an error,
// call sqlite3_close(_db) if _db is non-null, call (_in).close(), then return the failing Status.
// NOTE(review): _in is presumably the std::fstream of the shard being read - confirm at call sites.
#define RELEASE_AND_RETURN_IF_NOT_OK(_s, _db, _in) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
if ((_db) != nullptr) { \
sqlite3_close(_db); \
} \
(_in).close(); \
return __rc; \
} \
} while (false)

// If _condition is false, return a Status built from StatusCode::kMDUnexpectedError,
// the current __LINE__ / __FILE__ and the message _e from the enclosing function.
#define CHECK_FAIL_RETURN_UNEXPECTED(_condition, _e) \
do { \
if (!(_condition)) { \
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \
} \
} while (false)

// If _ptr compares equal to nullptr, return an unexpected-error Status (via RETURN_STATUS_UNEXPECTED)
// whose message is "The pointer[<expr>] is null.", where <expr> is the stringified argument (#_ptr).
#define RETURN_UNEXPECTED_IF_NULL(_ptr) \
do { \
if ((_ptr) == nullptr) { \
std::string err_msg = "The pointer[" + std::string(#_ptr) + "] is null."; \
RETURN_STATUS_UNEXPECTED(err_msg); \
} \
} while (false)

// Unconditionally return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e)
// from the enclosing function; _e is the error message.
#define RETURN_STATUS_UNEXPECTED(_e) \
do { \
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \
} while (false)

enum MSRStatus {
SUCCESS = 0,
FAILED = 1,
OPEN_FILE_FAILED,
CLOSE_FILE_FAILED,
WRITE_METADATA_FAILED,
WRITE_RAWDATA_FAILED,
GET_SCHEMA_FAILED,
ILLEGAL_RAWDATA,
PYTHON_TO_JSON_FAILED,
DIR_CREATE_FAILED,
OPEN_DIR_FAILED,
INVALID_STATISTICS,
OPEN_DATABASE_FAILED,
CLOSE_DATABASE_FAILED,
DATABASE_OPERATE_FAILED,
BUILD_SCHEMA_FAILED,
DIVISOR_IS_ILLEGAL,
INVALID_FILE_PATH,
SECURE_FUNC_FAILED,
ALLOCATE_MEM_FAILED,
ILLEGAL_FIELD_NAME,
ILLEGAL_FIELD_TYPE,
SET_METADATA_FAILED,
ILLEGAL_SCHEMA_DEFINITION,
ILLEGAL_COLUMN_LIST,
SQL_ERROR,
ILLEGAL_SHARD_COUNT,
ILLEGAL_SCHEMA_COUNT,
VERSION_ERROR,
ADD_SCHEMA_FAILED,
ILLEGAL_Header_SIZE,
ILLEGAL_Page_SIZE,
ILLEGAL_SIZE_VALUE,
INDEX_FIELD_ERROR,
GET_CANDIDATE_CATEGORYFIELDS_FAILED,
GET_CATEGORY_INFO_FAILED,
ILLEGAL_CATEGORY_ID,
ILLEGAL_ROWNUMBER_OF_PAGE,
ILLEGAL_SCHEMA_ID,
DESERIALIZE_SCHEMA_FAILED,
DESERIALIZE_STATISTICS_FAILED,
ILLEGAL_DB_FILE,
OVERWRITE_DB_FILE,
OVERWRITE_MINDRECORD_FILE,
ILLEGAL_MINDRECORD_FILE,
PARSE_JSON_FAILED,
ILLEGAL_PARAMETERS,
GET_PAGE_BY_GROUP_ID_FAILED,
GET_SYSTEM_STATE_FAILED,
IO_FAILED,
MATCH_HEADER_FAILED
};

// convert error no to string message
std::string __attribute__((visibility("default"))) ErrnoToMessage(MSRStatus status);
} // namespace mindrecord
} // namespace mindspore



+ 33
- 33
mindspore/ccsrc/minddata/mindrecord/include/shard_header.h View File

@@ -37,9 +37,9 @@ class __attribute__((visibility("default"))) ShardHeader {

~ShardHeader() = default;

MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);

static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path);
static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr);
/// \brief add the schema and save it
/// \param[in] schema the schema needs to be added
/// \return the last schema's id
@@ -53,9 +53,9 @@ class __attribute__((visibility("default"))) ShardHeader {
/// \brief create index and add fields which from schema for each schema
/// \param[in] fields the index fields needs to be added
/// \return Status OK if the fields are added successfully, an error Status if not
MSRStatus AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields);
Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields);

MSRStatus AddIndexFields(const std::vector<std::string> &fields);
Status AddIndexFields(const std::vector<std::string> &fields);

/// \brief get the schema
/// \return the schema
@@ -79,9 +79,10 @@ class __attribute__((visibility("default"))) ShardHeader {
std::shared_ptr<Index> GetIndex();

/// \brief get the schema by schemaid
/// \param[in] schemaId the id of schema needs to be got
/// \return the schema obtained by schemaId
std::pair<std::shared_ptr<Schema>, MSRStatus> GetSchemaByID(int64_t schema_id);
/// \param[in] schema_id the id of schema needs to be got
/// \param[out] schema_ptr the schema obtained by schema_id
/// \return Status
Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr);

/// \brief get the filepath to shard by shardID
/// \param[in] shardID the id of shard which filepath needs to be obtained
@@ -89,25 +90,26 @@ class __attribute__((visibility("default"))) ShardHeader {
std::string GetShardAddressByID(int64_t shard_id);

/// \brief get the statistic by statistic id
/// \param[in] statisticId the id of statistic needs to be get
/// \return the statistics obtained by statistic id
std::pair<std::shared_ptr<Statistics>, MSRStatus> GetStatisticByID(int64_t statistic_id);
/// \param[in] statistic_id the id of statistic needs to be get
/// \param[out] statistics_ptr the statistics obtained by statistic_id
/// \return Status
Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr);

MSRStatus InitByFiles(const std::vector<std::string> &file_paths);
Status InitByFiles(const std::vector<std::string> &file_paths);

void SetIndex(Index index) { index_ = std::make_shared<Index>(index); }

std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id);
Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr);

MSRStatus SetPage(const std::shared_ptr<Page> &new_page);
Status SetPage(const std::shared_ptr<Page> &new_page);

MSRStatus AddPage(const std::shared_ptr<Page> &new_page);
Status AddPage(const std::shared_ptr<Page> &new_page);

int64_t GetLastPageId(const int &shard_id);

int GetLastPageIdByType(const int &shard_id, const std::string &page_type);

const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id);
Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr);

std::vector<std::string> GetShardAddresses() const { return shard_addresses_; }

@@ -129,43 +131,41 @@ class __attribute__((visibility("default"))) ShardHeader {

std::vector<std::string> SerializeHeader();

MSRStatus PagesToFile(const std::string dump_file_name);
Status PagesToFile(const std::string dump_file_name);

MSRStatus FileToPages(const std::string dump_file_name);
Status FileToPages(const std::string dump_file_name);

static MSRStatus Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id);
static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id);

private:
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
Status InitializeHeader(const std::vector<json> &headers, bool load_dataset);

/// \brief get the headers from all the shard data
/// \param[in] the shard data real path
/// \param[in] the headers which read from the shard data
/// \return Status
MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);
Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);

MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);

/// \brief check the binary file status
static MSRStatus CheckFileStatus(const std::string &path);
static Status CheckFileStatus(const std::string &path);

static std::pair<MSRStatus, json> ValidateHeader(const std::string &path);

void ParseHeader(const json &header);
static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr);

void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses);

MSRStatus ParseIndexFields(const json &index_fields);
Status ParseIndexFields(const json &index_fields);

MSRStatus CheckIndexField(const std::string &field, const json &schema);
Status CheckIndexField(const std::string &field, const json &schema);

MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset);
Status ParsePage(const json &page, int shard_index, bool load_dataset);

MSRStatus ParseStatistics(const json &statistics);
Status ParseStatistics(const json &statistics);

MSRStatus ParseSchema(const json &schema);
Status ParseSchema(const json &schema);

void ParseShardAddress(const json &address);

@@ -181,7 +181,7 @@ class __attribute__((visibility("default"))) ShardHeader {

std::shared_ptr<Index> InitIndexPtr();

MSRStatus GetAllSchemaID(std::set<uint64_t> &bucket_count);
Status GetAllSchemaID(std::set<uint64_t> &bucket_count);

uint32_t shard_count_;
uint64_t header_size_;


+ 28
- 28
mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h View File

@@ -30,23 +30,24 @@

namespace mindspore {
namespace mindrecord {
using INDEX_FIELDS = std::pair<MSRStatus, std::vector<std::tuple<std::string, std::string, std::string>>>;
using ROW_DATA = std::pair<MSRStatus, std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>>;
using INDEX_FIELDS = std::vector<std::tuple<std::string, std::string, std::string>>;
using ROW_DATA = std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>;
class __attribute__((visibility("default"))) ShardIndexGenerator {
public:
explicit ShardIndexGenerator(const std::string &file_path, bool append = false);

MSRStatus Build();
Status Build();

static std::pair<MSRStatus, std::string> GenerateFieldName(const std::pair<uint64_t, std::string> &field);
static Status GenerateFieldName(const std::pair<uint64_t, std::string> &field, std::shared_ptr<std::string> *fn_ptr);

~ShardIndexGenerator() {}

/// \brief fetch value in json by field name
/// \param[in] field
/// \param[in] input
/// \return pair<MSRStatus, value>
std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input);
/// \param[out] value the value fetched for the field
/// \return Status
Status GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value);

/// \brief fetch field type in schema n by field path
/// \param[in] field_path
@@ -55,55 +56,54 @@ class __attribute__((visibility("default"))) ShardIndexGenerator {
static std::string TakeFieldType(const std::string &field_path, json schema);

/// \brief create databases for indexes
MSRStatus WriteToDatabase();
Status WriteToDatabase();

static MSRStatus Finalize(const std::vector<std::string> file_names);
static Status Finalize(const std::vector<std::string> file_names);

private:
static int Callback(void *not_used, int argc, char **argv, char **az_col_name);

static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = "");
static Status ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = "");

static std::string ConvertJsonToSQL(const std::string &json);

std::pair<MSRStatus, sqlite3 *> CreateDatabase(int shard_no);
Status CreateDatabase(int shard_no, sqlite3 **db);

std::pair<MSRStatus, std::vector<json>> GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in);
Status GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
std::shared_ptr<std::vector<json>> *detail_ptr);

static std::pair<MSRStatus, std::string> GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields);
static Status GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
std::shared_ptr<std::string> *sql_ptr);

std::pair<MSRStatus, sqlite3 *> CheckDatabase(const std::string &shard_address);
Status CheckDatabase(const std::string &shard_address, sqlite3 **db);

///
/// \param shard_no
/// \param blob_id_to_page_id
/// \param raw_page_id
/// \param in
/// \return field name, db type, field value
ROW_DATA GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
std::fstream &in);
/// \return Status
Status GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, std::fstream &in,
std::shared_ptr<ROW_DATA> *row_data_ptr);
///
/// \param db
/// \param sql
/// \param data
/// \return
MSRStatus BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data);
Status BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data);

INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail);
Status GenerateIndexFields(const std::vector<json> &schema_detail, std::shared_ptr<INDEX_FIELDS> *index_fields_ptr);

MSRStatus ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
Status ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id);

MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name);
Status CreateShardNameTable(sqlite3 *db, const std::string &shard_name);

MSRStatus AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
std::fstream &in);
Status AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, std::fstream &in);

void AddIndexFieldByRawData(const std::vector<json> &schema_detail,
std::vector<std::tuple<std::string, std::string, std::string>> &row_data);
Status AddIndexFieldByRawData(const std::vector<json> &schema_detail,
std::vector<std::tuple<std::string, std::string, std::string>> &row_data);

void DatabaseWriter(); // worker thread



+ 16
- 21
mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h View File

@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_

@@ -28,33 +27,29 @@ class __attribute__((visibility("default"))) ShardOperator {
public:
virtual ~ShardOperator() = default;

MSRStatus operator()(ShardTaskList &tasks) {
if (SUCCESS != this->PreExecute(tasks)) {
return FAILED;
}
if (SUCCESS != this->Execute(tasks)) {
return FAILED;
}
if (SUCCESS != this->SufExecute(tasks)) {
return FAILED;
}
return SUCCESS;
Status operator()(ShardTaskList &tasks) {
RETURN_IF_NOT_OK(this->PreExecute(tasks));
RETURN_IF_NOT_OK(this->Execute(tasks));
RETURN_IF_NOT_OK(this->SufExecute(tasks));
return Status::OK();
}

virtual bool HasChildOp() { return child_op_ != nullptr; }

virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) {
if (child_op != nullptr) child_op_ = child_op;
return SUCCESS;
virtual Status SetChildOp(std::shared_ptr<ShardOperator> child_op) {
if (child_op != nullptr) {
child_op_ = child_op;
}
return Status::OK();
}

virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }

virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; }
virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); }

virtual MSRStatus Execute(ShardTaskList &tasks) = 0;
virtual Status Execute(ShardTaskList &tasks) = 0;

virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; }
virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); }

virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }

@@ -72,9 +67,9 @@ class __attribute__((visibility("default"))) ShardOperator {
std::shared_ptr<ShardOperator> child_op_ = nullptr;

// indicate shard_id : inc_count
// 0 : 15 - shard0 has 15 samples
// 1 : 41 - shard1 has 26 samples
// 2 : 58 - shard2 has 17 samples
// // 0 : 15 - shard0 has 15 samples
// // 1 : 41 - shard1 has 26 samples
// // 2 : 58 - shard2 has 17 samples
std::vector<uint32_t> shard_sample_count_;

dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal;


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h View File

@@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor

~ShardPkSample() override{};

MSRStatus SufExecute(ShardTaskList &tasks) override;
Status SufExecute(ShardTaskList &tasks) override;

int64_t GetNumSamples() const { return num_samples_; }



+ 62
- 80
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h View File

@@ -59,12 +59,9 @@

namespace mindspore {
namespace mindrecord {
using ROW_GROUPS =
std::tuple<MSRStatus, std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>;
using ROW_GROUP_BRIEF =
std::tuple<MSRStatus, std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>;
using TASK_RETURN_CONTENT =
std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>;
using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>;
using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>;
using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>;
const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode

class API_PUBLIC ShardReader {
@@ -82,21 +79,10 @@ class API_PUBLIC ShardReader {
/// \param[in] num_padded the number of padded samples
/// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization
/// \return Status the status of the operation
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
bool lazy_load = false);

/// \brief open files and initialize reader, python API
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \return MSRStatus the status of MSRStatus
MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
bool lazy_load = false);

/// \brief close reader
/// \return null
@@ -104,16 +90,16 @@ class API_PUBLIC ShardReader {

/// \brief read the file, get schema meta,statistics and index, single-thread mode
/// \return Status the status of the operation
MSRStatus Open();
Status Open();

/// \brief read the file, get schema meta,statistics and index, multiple-thread mode
/// \return MSRStatus the status of MSRStatus
MSRStatus Open(int n_consumer);
Status Open(int n_consumer);

/// \brief launch threads to get batches
/// \param[in] is_simple_reader trigger threads if false; do nothing if true
/// \return Status the status of the operation
MSRStatus Launch(bool is_simple_reader = false);
Status Launch(bool is_simple_reader = false);

/// \brief aim to get the meta data
/// \return the metadata
@@ -133,8 +119,8 @@ class API_PUBLIC ShardReader {
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
/// \param[out] count # of rows
/// \return Status the status of the operation
MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded);
Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded);

/// \brief shuffle task with incremental seed
/// \return void
@@ -162,8 +148,8 @@ class API_PUBLIC ShardReader {
/// 3. Offset address of row group in file
/// 4. The list of image offset in page [startOffset, endOffset)
/// 5. The list of columns data
ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id,
const std::vector<std::string> &columns = std::vector<std::string>());
Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);

/// \brief Read 1 row group data, excluding images, following an index field criteria
/// \param[in] groupID row group ID
@@ -176,8 +162,9 @@ class API_PUBLIC ShardReader {
/// 3. Offset address of row group in file
/// 4. The list of image offset in page [startOffset, endOffset)
/// 5. The list of columns data
ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns = std::vector<std::string>());
Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns,
std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);

/// \brief return a batch, given that one is ready
/// \return a batch of images and image data
@@ -185,13 +172,7 @@ class API_PUBLIC ShardReader {

/// \brief return a row by id
/// \return a batch of images and image data
std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>> GetNextById(const int64_t &task_id,
const int32_t &consumer_id);

/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy();

TASK_CONTENT GetNextById(const int64_t &task_id, const int32_t &consumer_id);
/// \brief get blob field list
/// \return blob field list
std::pair<ShardType, std::vector<std::string>> GetBlobFields();
@@ -205,83 +186,86 @@ class API_PUBLIC ShardReader {
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }

/// \brief get all classes
MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);

/// \brief get the size of blob data
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);

/// \brief get a read-only ptr to the sampled ids for this epoch
const std::vector<int> *GetSampleIds();

/// \brief get the size of blob data
Status GetTotalBlobSize(int64_t *total_blob_size);

/// \brief extract uncompressed data based on column list
Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr);

protected:
/// \brief sqlite call back function
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);

private:
/// \brief wrap up labels to json format
MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id,
const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);

/// \brief read all rows for specified columns
ROW_GROUPS ReadAllRowGroup(const std::vector<std::string> &columns);
Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr);

/// \brief read row meta by shard_id and sample_id
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
const uint32_t &sample_id);
Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
const uint32_t &sample_id, std::shared_ptr<ROW_GROUPS> *row_group_ptr);

/// \brief read all rows in one shard
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
Status ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);

/// \brief initialize reader
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
Status Init(const std::vector<std::string> &file_paths, bool load_dataset);

/// \brief validate column list
MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
Status CheckColumnList(const std::vector<std::string> &selected_columns);

/// \brief populate one row by task list in row-reader mode
MSRStatus ConsumerByRow(int consumer_id);
void ConsumerByRow(int consumer_id);

/// \brief get offset address of images within page
std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
const std::pair<std::string, std::string> &criteria = {"", ""});

/// \brief get page id by category
std::pair<MSRStatus, std::vector<uint64_t>> GetPagesByCategory(int shard_id,
const std::pair<std::string, std::string> &criteria);
Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
std::shared_ptr<std::vector<uint64_t>> *pages_ptr);
/// \brief execute sqlite query with prepare statement
MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
Status QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
/// \brief verify the validity of dataset
MSRStatus VerifyDataset(sqlite3 **db, const string &file);
Status VerifyDataset(sqlite3 **db, const string &file);

/// \brief get column values
std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria = {"", ""});
Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria, std::shared_ptr<std::vector<json>> *labels_ptr);

/// \brief get column values from raw data page
std::pair<MSRStatus, std::vector<json>> GetLabelsFromPage(int group_id, int shard_id,
const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria = {"",
""});
Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria,
std::shared_ptr<std::vector<json>> *labels_ptr);

/// \brief create category-applied task list
MSRStatus CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);
Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);

/// \brief create task list in row-reader mode
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);

/// \brief create task list in row-reader mode and lazy mode
MSRStatus CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);

/// \brief create task list
MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);

/// \brief check if all specified columns are in index table
void CheckIfColumnInIndex(const std::vector<std::string> &columns);
@@ -290,11 +274,12 @@ class API_PUBLIC ShardReader {
void FileStreamsOperator();

/// \brief read one row by one task
TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id);
Status ConsumerOneTask(int task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt);

/// \brief get labels from binary file
std::pair<MSRStatus, std::vector<json>> GetLabelsFromBinaryFile(
int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets);
Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
const std::vector<std::vector<std::string>> &label_offsets,
std::shared_ptr<std::vector<json>> *labels_ptr);

/// \brief get classes in one shard
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
@@ -304,11 +289,8 @@ class API_PUBLIC ShardReader {
int64_t GetNumClasses(const std::string &category_field);

/// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path,
std::shared_ptr<json> meta_data_ptr);

/// \brief extract uncompressed data based on column list
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
std::shared_ptr<std::vector<std::string>> *addresses_ptr);

protected:
uint64_t header_size_; // header size


+ 3
- 3
mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h View File

@@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator

~ShardSample() override{};

MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;

MSRStatus UpdateTasks(ShardTaskList &tasks, int taking);
Status UpdateTasks(ShardTaskList &tasks, int taking);

MSRStatus SufExecute(ShardTaskList &tasks) override;
Status SufExecute(ShardTaskList &tasks) override;

int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;



+ 0
- 9
mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h View File

@@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Schema {
/// \param[in] schema the schema's json
static std::shared_ptr<Schema> Build(std::string desc, const json &schema);

/// \brief obtain the json schema and its description for python
/// \param[in] desc the description of the schema
/// \param[in] schema the schema's json
static std::shared_ptr<Schema> Build(std::string desc, pybind11::handle schema);

/// \brief compare two schema to judge if they are equal
/// \param b another schema to be judged
/// \return true if they are equal,false if not
@@ -57,10 +52,6 @@ class __attribute__((visibility("default"))) Schema {
/// \return the json format of the schema and its description
json GetSchema() const;

/// \brief get the schema and its description for python method
/// \return the python object of the schema and its description
pybind11::object GetSchemaForPython() const;

/// set the schema id
/// \param[in] id the id need to be set
void SetSchemaID(int64_t id);


+ 19
- 20
mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_

#include <memory>
#include <string>
#include <tuple>
#include <utility>
@@ -25,6 +26,10 @@

namespace mindspore {
namespace mindrecord {
using CATEGORY_INFO = std::vector<std::tuple<int, std::string, int>>;
using PAGES = std::vector<std::tuple<std::vector<uint8_t>, json>>;
using PAGES_LOAD = std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>;

class __attribute__((visibility("default"))) ShardSegment : public ShardReader {
public:
ShardSegment();
@@ -33,12 +38,12 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader {

/// \brief Get candidate category fields
/// \return a list of fields names which are the candidates of category
std::pair<MSRStatus, vector<std::string>> GetCategoryFields();
Status GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr);

/// \brief Set category field
/// \param[in] category_field category name
/// \return Status, OK if the category name exists
MSRStatus SetCategoryField(std::string category_field);
Status SetCategoryField(std::string category_field);

/// \brief Thread-safe implementation of ReadCategoryInfo
/// \return statistics data in json format with 2 field: "key" and "categories".
@@ -50,47 +55,41 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader {
/// { "key": "label",
/// "categories": [ { "count": 3, "id": 0, "name": "sport", },
/// { "count": 3, "id": 1, "name": "finance", } ] }
std::pair<MSRStatus, std::string> ReadCategoryInfo();
Status ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr);

/// \brief Thread-safe implementation of ReadAtPageById
/// \param[in] category_id category ID
/// \param[in] page_no page number
/// \param[in] n_rows_of_page rows number in one page
/// \return images array, image is a vector of uint8_t
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageById(int64_t category_id, int64_t page_no,
int64_t n_rows_of_page);
Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr);

/// \brief Thread-safe implementation of ReadAtPageByName
/// \param[in] category_name category Name
/// \param[in] page_no page number
/// \param[in] n_rows_of_page rows number in one page
/// \return images array, image is a vector of uint8_t
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageByName(std::string category_name, int64_t page_no,
int64_t n_rows_of_page);

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageById(int64_t category_id,
int64_t page_no,
int64_t n_rows_of_page);

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageByName(
std::string category_name, int64_t page_no, int64_t n_rows_of_page);
Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr);

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByIdPy(
int64_t category_id, int64_t page_no, int64_t n_rows_of_page);
Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr);

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy(
std::string category_name, int64_t page_no, int64_t n_rows_of_page);
Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr);

std::pair<ShardType, std::vector<std::string>> GetBlobFields();

private:
std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo();
Status WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr);

std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec);

std::string CleanUp(std::string fieldName);

std::pair<MSRStatus, std::vector<uint8_t>> PackImages(int group_id, int shard_id, std::vector<uint64_t> offset);
Status PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
std::shared_ptr<std::vector<uint8_t>> *images_ptr);

std::vector<std::string> candidate_category_fields_;
std::string current_category_field_;


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h View File

@@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar

~ShardSequentialSample() override{};

MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;

int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;



+ 4
- 4
mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h View File

@@ -31,19 +31,19 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator

~ShardShuffle() override{};

MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;

int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;

private:
// Private helper function
MSRStatus CategoryShuffle(ShardTaskList &tasks);
Status CategoryShuffle(ShardTaskList &tasks);

// Keep the file sequence the same but shuffle the data within each file
MSRStatus ShuffleInfile(ShardTaskList &tasks);
Status ShuffleInfile(ShardTaskList &tasks);

// Shuffle the file sequence but keep the order of data within each file
MSRStatus ShuffleFiles(ShardTaskList &tasks);
Status ShuffleFiles(ShardTaskList &tasks);

uint32_t shuffle_seed_;
int64_t no_of_samples_;


+ 0
- 9
mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h View File

@@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Statistics {
/// \param[in] statistics the statistic needs to be saved
static std::shared_ptr<Statistics> Build(std::string desc, const json &statistics);

/// \brief save the statistic from python and its description
/// \param[in] desc the statistic's description
/// \param[in] statistics the statistic needs to be saved
static std::shared_ptr<Statistics> Build(std::string desc, pybind11::handle statistics);

~Statistics() = default;

/// \brief compare two statistics to judge if they are equal
@@ -59,10 +54,6 @@ class __attribute__((visibility("default"))) Statistics {
/// \return json format of the statistic
json GetStatistics() const;

/// \brief get the statistic for python
/// \return the python object of statistics
pybind11::object GetStatisticsForPython() const;

/// \brief decode the bson statistics to json
/// \param[in] encodedStatistics the bson type of statistics
/// \return json type of statistic


+ 61
- 71
mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h View File

@@ -55,69 +55,60 @@ class __attribute__((visibility("default"))) ShardWriter {
/// \brief Open file at the beginning
/// \param[in] paths the file names list
/// \param[in] append new data at the end of file if true, otherwise overwrite file
/// \return MSRStatus the status of MSRStatus
MSRStatus Open(const std::vector<std::string> &paths, bool append = false);
/// \return Status
Status Open(const std::vector<std::string> &paths, bool append = false);

/// \brief Open file at the ending
/// \param[in] path the file name
/// \return Status
MSRStatus OpenForAppend(const std::string &path);
Status OpenForAppend(const std::string &path);

/// \brief Write header to disk
/// \return Status
MSRStatus Commit();
Status Commit();

/// \brief Set header size
/// \param[in] header_size the size of header, only (1<<N) is accepted
/// \return Status
MSRStatus SetHeaderSize(const uint64_t &header_size);
Status SetHeaderSize(const uint64_t &header_size);

/// \brief Set page size
/// \param[in] page_size the size of page, only (1<<N) is accepted
/// \return Status
MSRStatus SetPageSize(const uint64_t &page_size);
Status SetPageSize(const uint64_t &page_size);

/// \brief Set shard header
/// \param[in] header_data the info of header
/// WARNING, only called when file is empty
/// \return Status
MSRStatus SetShardHeader(std::shared_ptr<ShardHeader> header_data);
Status SetShardHeader(std::shared_ptr<ShardHeader> header_data);

/// \brief write raw data by group size
/// \param[in] raw_data the vector of raw json data, vector format
/// \param[in] blob_data the vector of image data
/// \param[in] sign validate data or not
/// \return Status indicating whether the write succeeded
MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);

/// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format
/// \param[in] blob_data the vector of image data
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);
Status WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);

/// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format
/// \param[in] blob_data the vector of blob json data, python-handle format
/// \param[in] sign validate data or not
/// \return Status indicating whether the write succeeded
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);
Status WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);

MSRStatus MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output);
Status MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output);

static MSRStatus Initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
const std::vector<std::string> &file_names);
static Status Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, const std::vector<std::string> &file_names);

private:
/// \brief write shard header data to disk
MSRStatus WriteShardHeader();
Status WriteShardHeader();

/// \brief erase error data
void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data);
@@ -130,108 +121,107 @@ class __attribute__((visibility("default"))) ShardWriter {
std::map<int, std::string> &err_raw_data);

/// \brief validate raw data and blob data
std::tuple<MSRStatus, int, int> ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign);
Status ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data,
bool sign, std::shared_ptr<std::pair<int, int>> *count_ptr);

/// \brief fill data array in multiple thread run
void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &bin_data);

/// \brief serialized raw data
MSRStatus SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count);
Status SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &bin_data,
uint32_t row_count);

/// \brief write all data parallel
MSRStatus ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);

/// \brief write data shard by shard
MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);

/// \brief break image data up into multiple row groups
MSRStatus CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
const std::shared_ptr<Page> &last_blob_page);
Status CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
const std::shared_ptr<Page> &last_blob_page);

/// \brief append partial blob data to previous page
MSRStatus AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_blob_page);

/// \brief write new blob data page to disk
MSRStatus NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
Status AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_blob_page);

/// \brief write new blob data page to disk
Status NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_blob_page);

/// \brief shift last row group to next raw page for new appending
MSRStatus ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page);
Status ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page);

/// \brief write raw data page to disk
MSRStatus WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data);

/// \brief generate empty raw data page
void EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
Status EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);

/// \brief append a row group at the end of raw page
MSRStatus AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
const int &chunk_id, int &last_row_groupId, std::shared_ptr<Page> last_raw_page,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
int &last_row_groupId, std::shared_ptr<Page> last_raw_page,
const std::vector<std::vector<uint8_t>> &bin_raw_data);

/// \brief write blob chunk to disk
MSRStatus FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
const std::pair<int, int> &blob_row);
Status FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
const std::pair<int, int> &blob_row);

/// \brief write raw chunk to disk
MSRStatus FlushRawChunk(const std::shared_ptr<std::fstream> &out,
const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status FlushRawChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::pair<int, int>> &rows_in_group,
const int &chunk_id, const std::vector<std::vector<uint8_t>> &bin_raw_data);

/// \brief break up into tasks by shard
std::vector<std::pair<int, int>> BreakIntoShards();

/// \brief calculate raw data size row by row
MSRStatus SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);

/// \brief calculate blob data size row by row
MSRStatus SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);
Status SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);

/// \brief populate last raw page pointer
void SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
Status SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);

/// \brief populate last blob page pointer
void SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);
Status SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);

/// \brief check the data by schema
MSRStatus CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);
Status CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);

/// \brief check the data and type
MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
std::map<int, std::string> &err_raw_data);
Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
std::map<int, std::string> &err_raw_data);

/// \brief Lock writer and save pages info
int LockWriter(bool parallel_writer = false);
Status LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr);

/// \brief Unlock writer and save pages info
MSRStatus UnlockWriter(int fd, bool parallel_writer = false);
Status UnlockWriter(int fd, bool parallel_writer = false);

/// \brief Check raw data before writing
MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, int *schema_count, int *row_count);
Status WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, int *schema_count, int *row_count);

/// \brief Get full path from file name
MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);
Status GetFullPathFromFileName(const std::vector<std::string> &paths);

/// \brief Open files
MSRStatus OpenDataFiles(bool append);
Status OpenDataFiles(bool append);

/// \brief Remove lock file
MSRStatus RemoveLockFile();
Status RemoveLockFile();

/// \brief Init lock file
MSRStatus InitLockFile();
Status InitLockFile();

private:
const std::string kLockFileSuffix = "_Locker";


+ 204
- 329
mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc View File

@@ -37,70 +37,48 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
task_(0),
write_success_(true) {}

MSRStatus ShardIndexGenerator::Build() {
auto ret = ShardHeader::BuildSingleHeader(file_path_);
if (ret.first != SUCCESS) {
return FAILED;
}
auto json_header = ret.second;

auto ret2 = GetDatasetFiles(file_path_, json_header["shard_addresses"]);
if (SUCCESS != ret2.first) {
return FAILED;
}
Status ShardIndexGenerator::Build() {
std::shared_ptr<json> header_ptr;
RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path_, &header_ptr));
auto ds = std::make_shared<std::vector<std::string>>();
RETURN_IF_NOT_OK(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds));
ShardHeader header = ShardHeader();
auto addresses = ret2.second;
if (header.BuildDataset(addresses) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(header.BuildDataset(*ds));
shard_header_ = header;
MS_LOG(INFO) << "Init header from mindrecord file for index successfully.";
return SUCCESS;
return Status::OK();
}

std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
if (field.empty()) {
MS_LOG(ERROR) << "The input field is None.";
return {FAILED, ""};
}

if (input.empty()) {
MS_LOG(ERROR) << "The input json is None.";
return {FAILED, ""};
}
// Extracts the value stored under `field` in the record `input` as a string.
// Params:
//   field - name of the field to look up; must be non-empty.
//   input - json record to read from (taken by value); must be non-empty and contain `field`.
//   value - out-parameter; must not be null. On success it is set to a new string holding
//           input[field].dump() for number-typed fields, or the plain string otherwise.
// Returns an error through the CHECK_FAIL_RETURN_UNEXPECTED / RETURN_UNEXPECTED_IF_NULL macros
// when a precondition fails, when the schema lacks `field`, when the field type is not in
// kScalarFieldTypeSet, or when a number-typed field declares a "shape" option.
Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value) {
RETURN_UNEXPECTED_IF_NULL(value);
CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty.");
CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty.");

// parameter input does not contain the field
if (input.find(field) == input.end()) {
MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(),
"The field " + field + " is not found in json " + input.dump());

// schema does not contain the field
// NOTE(review): only the first schema (GetSchemas()[0]) is consulted, whichever schema the
// field belongs to - confirm this is intended.
auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
if (schema.find(field) == schema.end()) {
MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
"The field " + field + " is not found in schema " + schema.dump());

// field should be scalar type
if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable";
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(
kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(),
"The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");

// Number-typed fields are serialised with dump(); a "shape" option on such a field is rejected.
if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
auto schema_field_options = schema[field];
if (schema_field_options.find("shape") == schema_field_options.end()) {
return {SUCCESS, input[field].dump()};
} else {
// field with shape option
MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable";
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(
schema_field_options.find("shape") == schema_field_options.end(),
"The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable.");
*value = std::make_shared<std::string>(input[field].dump());
} else {
// the field type is string in here
*value = std::make_shared<std::string>(input[field].get<std::string>());
}

// the field type is string in here
return {SUCCESS, input[field].get<std::string>()};
return Status::OK();
}

std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
@@ -150,24 +128,28 @@ int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **
return 0;
}

MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
// Runs `sql` on `db` through sqlite3_exec, with Callback as the row callback.
// On failure: the sqlite error text is logged at DEBUG level, the message buffer is freed, `db`
// is closed, and an error Status carrying the message is returned.
// On success: `success_msg` is logged at DEBUG level when non-empty and Status::OK() is returned.
Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
char *z_err_msg = nullptr;
int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg);
if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Sql error: " << z_err_msg;
std::ostringstream oss;
// NOTE(review): z_err_msg is streamed without a null check.
oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg;
MS_LOG(DEBUG) << oss.str();
sqlite3_free(z_err_msg);
return FAILED;
// NOTE(review): this closes the caller's connection; a caller that also closes `db` on error
// would close it twice - confirm against the call sites.
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED(oss.str());
} else {
if (!success_msg.empty()) {
MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg;
// NOTE(review): "Suceess" in the log text below is a typo for "Success".
MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg;
}
sqlite3_free(z_err_msg);
return SUCCESS;
return Status::OK();
}
}

std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName(
const std::pair<uint64_t, std::string> &field) {
// Produces the SQL column name for an index field: "<field name>_<field.first>".
// Params:
//   field  - (id, name) pair; the name is copied from field.second.
//   fn_ptr - out-parameter, must not be null; receives the generated name on success.
// Returns an error when the name contains a character outside [0-9A-Za-z_].
// Part of the body is elided in this view, so the name normalisation step is not shown.
Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field,
std::shared_ptr<std::string> *fn_ptr) {
RETURN_UNEXPECTED_IF_NULL(fn_ptr);
// Replaces dots and dashes with underscores for SQL use
std::string field_name = field.second;
// white list to avoid sql injection
@@ -176,95 +158,71 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName(
// `pos` is the first character that is not a letter, digit or underscore (end() if none).
auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) {
return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9');
});
if (pos != field_name.end()) {
MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name;
return {FAILED, ""};
}
return {SUCCESS, field_name + "_" + std::to_string(field.first)};
CHECK_FAIL_RETURN_UNEXPECTED(
pos == field_name.end(),
"Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name);
*fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first));
return Status::OK();
}

std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CheckDatabase(const std::string &shard_address) {
// Opens (creating if needed) the sqlite database file `shard_address` and hands the connection
// back through `db`.
// Returns an error when the real path cannot be resolved, when the file already exists while
// append_ is false, or when sqlite3_open_v2 returns a non-zero code.
// NOTE(review): `db` is passed to sqlite3_open_v2 and dereferenced without a null check.
Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) {
auto realpath = Common::GetRealPath(shard_address);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << shard_address;
return {FAILED, nullptr};
}

sqlite3 *db = nullptr;
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
// Probe for an existing file: without append_ an existing database is rejected.
std::ifstream fin(realpath.value());
if (!append_ && fin.good()) {
MS_LOG(ERROR) << "Invalid file, DB file already exist: " << shard_address;
fin.close();
return {FAILED, nullptr};
RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address);
}
fin.close();
int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
if (rc) {
MS_LOG(ERROR) << "Invalid file, failed to open database: " << shard_address << ", error" << sqlite3_errmsg(db);
return {FAILED, nullptr};
} else {
MS_LOG(DEBUG) << "Opened database successfully";
return {SUCCESS, db};
// NOTE(review): the database is opened with shard_address, not the resolved realpath.value().
// NOTE(review): when the open fails, the handle written to *db is not closed before returning;
// the SQLite documentation asks for sqlite3_close() even after a failed open.
if (sqlite3_open_v2(common::SafeCStr(shard_address), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" +
std::string(sqlite3_errmsg(*db)));
}
MS_LOG(DEBUG) << "Opened database successfully";
return Status::OK();
}

MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
// Recreates the SHARD_NAME table in `db` and inserts `shard_name` as a row.
// The table is dropped if present, created again, and filled through a prepared INSERT whose
// :SHARD_NAME parameter is bound to `shard_name`.
// Returns Status::OK() on success, otherwise the error of the failing step.
Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
// create shard_name table
std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;";
if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) {
return FAILED;
}

RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "drop table successfully."));
sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);";
if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) {
return FAILED;
}

RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "create table successfully."));
sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);";
sqlite3_stmt *stmt = nullptr;
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
if (stmt != nullptr) {
(void)sqlite3_finalize(stmt);
}
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
}

// The text is bound with length -1 and SQLITE_STATIC, i.e. sqlite reads shard_name's buffer
// in place rather than copying it.
int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME");
if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << shard_name;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(shard_name));
}

// NOTE(review): the prepare and bind failure paths above close `db`; this step failure path
// finalizes the statement but leaves `db` open - confirm which behaviour callers expect.
if (sqlite3_step(stmt) != SQLITE_DONE) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
return FAILED;
RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
}

(void)sqlite3_finalize(stmt);
return SUCCESS;
return Status::OK();
}

std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) {
// Creates the index database for shard `shard_no` and returns the open connection through `db`.
// The database file is the shard address with ".db" appended. The INDEXES table is dropped and
// recreated with, for every entry of fields_, an INC_<n> column plus a column named by
// GenerateFieldName; the primary key is ROW_ID followed by all INC_<n> columns. Finally the
// SHARD_NAME table is filled with the name obtained from GetFileName.
// Part of the CREATE TABLE text is elided in this view.
Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) {
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
if (shard_address.empty()) {
MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no;
return {FAILED, nullptr};
}

string shard_name = GetFileName(shard_address).second;
// NOTE(review): `"..." + shard_no` adds an int to a string literal (pointer arithmetic); it does
// not append the number. std::to_string(shard_no) is presumably what was intended.
CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + shard_no);
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr));
shard_address += ".db";
auto ret1 = CheckDatabase(shard_address);
if (ret1.first != SUCCESS) {
return {FAILED, nullptr};
}
sqlite3 *db = ret1.second;
RETURN_IF_NOT_OK(CheckDatabase(shard_address, db));
std::string sql = "DROP TABLE IF EXISTS INDEXES;";
if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) {
return {FAILED, nullptr};
}
RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "drop table successfully."));
sql =
"CREATE TABLE INDEXES("
" ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL"
@@ -273,95 +231,79 @@ std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no
", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL";

int field_no = 0;
std::shared_ptr<std::string> field_ptr;
// One "INC_<n> INT, <generated name> <sql type>" column pair per index field.
for (const auto &field : fields_) {
uint64_t schema_id = field.first;
auto result = shard_header_.GetSchemaByID(schema_id);
if (result.second != SUCCESS) {
return {FAILED, nullptr};
}
json json_schema = (result.first->GetSchema())["schema"];
std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(schema_id, &schema_ptr));
json json_schema = (schema_ptr->GetSchema())["schema"];
std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema));
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, nullptr};
}
sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type;
RETURN_IF_NOT_OK(GenerateFieldName(field, &field_ptr));
sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type;
}
sql += ", PRIMARY KEY(ROW_ID";
for (uint64_t i = 0; i < fields_.size(); ++i) {
sql += ",INC_" + std::to_string(i);
}
sql += "));";
if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) {
return {FAILED, nullptr};
}

if (CreateShardNameTable(db, shard_name) != SUCCESS) {
return {FAILED, nullptr};
}
return {SUCCESS, db};
RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "create table successfully."));
RETURN_IF_NOT_OK(CreateShardNameTable(*db, *fn_ptr));
return Status::OK();
}

std::pair<MSRStatus, std::vector<json>> ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens,
std::fstream &in) {
std::vector<json> schema_details;
// Reads schema_count_ records from the current position of `in`, schema_lens[sc] bytes each,
// decodes every record from msgpack to json and appends it to the vector behind *detail_ptr.
// Nothing is read when schema_count_ exceeds kMaxSchemaCount. A failed read closes `in` and
// returns an error.
// NOTE(review): only `detail_ptr` itself is null-checked; *detail_ptr is assumed to already own
// a vector, and schema_lens is assumed to hold at least schema_count_ entries.
Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
std::shared_ptr<std::vector<json>> *detail_ptr) {
RETURN_UNEXPECTED_IF_NULL(detail_ptr);
if (schema_count_ <= kMaxSchemaCount) {
for (int sc = 0; sc < schema_count_; ++sc) {
std::vector<char> schema_detail(schema_lens[sc]);

auto &io_read = in.read(&schema_detail[0], schema_lens[sc]);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
in.close();
return {FAILED, {}};
RETURN_STATUS_UNEXPECTED("Failed to read file.");
}
schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end())));
auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()));
(*detail_ptr)->emplace_back(j);
}
}

return {SUCCESS, schema_details};
return Status::OK();
}

std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL(
const std::vector<std::pair<uint64_t, std::string>> &fields) {
// Builds the parameterised "INSERT INTO INDEXES (...) VALUES( ... )" statement for `fields`.
// Besides the fixed row/page columns, every field contributes an INC_<n> column and its
// generated column name, with matching :INC_<n> and :<name> placeholders in the VALUES list.
// The finished text is returned through *sql_ptr; a GenerateFieldName failure is propagated.
// NOTE(review): sql_ptr is not null-checked here, unlike the other out-parameters in this file.
Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
std::shared_ptr<std::string> *sql_ptr) {
std::string sql =
"INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END,"
"PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END";

// First pass: column list.
int field_no = 0;
for (const auto &field : fields) {
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, ""};
}
sql += ",INC_" + std::to_string(field_no++) + "," + ret.second;
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr;
}
sql +=
") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB,"
":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END";
// Second pass: placeholder list, in the same order as the columns.
field_no = 0;
for (const auto &field : fields) {
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, ""};
}
sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second;
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr;
}
sql += " )";
return {SUCCESS, sql};

*sql_ptr = std::make_shared<std::string>(sql);
return Status::OK();
}

MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
// Prepares `sql` once on `db` and executes it once per row of `data`.
// Each row is a list of tuples; presumably (placeholder name, SQL type, value) as built in
// GenerateRowData - the unpacking lines are elided in this view. The type text selects the bind
// call: "INTEGER" -> sqlite3_bind_int64, "NUMERIC" -> sqlite3_bind_double, "NULL" ->
// sqlite3_bind_null, anything else -> sqlite3_bind_text.
// Returns Status::OK() when every row was stepped, otherwise an error for the failing call.
// NOTE(review): prepare and bind failures close `db`, but the step failure path does not.
// NOTE(review): ExecuteTransaction wraps this call in RELEASE_AND_RETURN_IF_NOT_OK(..., db, in);
// if that macro also closes `db`, a failure here would close the connection twice - confirm.
Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) {
sqlite3_stmt *stmt = nullptr;
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
if (stmt != nullptr) {
(void)sqlite3_finalize(stmt);
}
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
}
for (auto &row : data) {
for (auto &field : row) {
@@ -373,45 +315,47 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
// NOTE(review): std::stoll / std::stold throw when field_value is not a number; nothing in
// the lines shown catches that.
if (field_type == "INTEGER") {
if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stoll(field_value);
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(field_value));
}
} else if (field_type == "NUMERIC") {
if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stold(field_value);
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(field_value));
}
} else if (field_type == "NULL") {
if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL";
return FAILED;

sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: NULL");
}
} else {
if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(field_value));
}
}
}
// Execute the statement for this row, then reset it so it can be re-bound for the next row.
if (sqlite3_step(stmt) != SQLITE_DONE) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
return FAILED;
RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
}
(void)sqlite3_reset(stmt);
}
(void)sqlite3_finalize(stmt);
return SUCCESS;
return Status::OK();
}

MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page,
uint64_t &cur_blob_page_offset, std::fstream &in) {
// Appends the blob-page entries of one row to `row_data` and advances `cur_blob_page_offset`.
// Adds :PAGE_ID_BLOB, seeks `in` to the blob record (page_size_ * page id + header_size_ +
// current offset), reads kInt64Len bytes into image_size, moves the offset past the length
// prefix and the blob, and adds :PAGE_OFFSET_BLOB_END with the new offset.
// A failed seek or read closes `in` and returns an error.
// The lines between the first emplace_back and the seek are elided in this view.
Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
std::fstream &in) {
row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));

// blob data start
@@ -419,89 +363,71 @@ MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::strin
auto &io_seekg_blob =
in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
MS_LOG(ERROR) << "File seekg failed";
in.close();
return FAILED;
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
}

uint64_t image_size = 0;

auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
in.close();
return FAILED;
RETURN_STATUS_UNEXPECTED("Failed to read file.");
}

cur_blob_page_offset += (kInt64Len + image_size);
row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset));

return SUCCESS;
return Status::OK();
}

void ShardIndexGenerator::AddIndexFieldByRawData(
// Appends the index-field entries of one row to `row_data`.
// For every tuple produced by GenerateIndexFields(schema_detail, ...) an ":INC_<n>" entry of
// type "INTEGER" with value "0" is added, followed by the field's own ":<name>" entry with its
// type and value. A GenerateIndexFields failure is returned to the caller.
Status ShardIndexGenerator::AddIndexFieldByRawData(
const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) {
auto result = GenerateIndexFields(schema_detail);
if (result.first == SUCCESS) {
int index = 0;
for (const auto &field : result.second) {
// assume simple field: string , number etc.
row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
}
}
auto index_fields_ptr = std::make_shared<INDEX_FIELDS>();
RETURN_IF_NOT_OK(GenerateIndexFields(schema_detail, &index_fields_ptr));
int index = 0;
for (const auto &field : *index_fields_ptr) {
// assume simple field: string , number etc.
row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
}
return Status::OK();
}

ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id,
int raw_page_id, std::fstream &in) {
std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data;

// Builds the index rows for one raw-data page of shard `shard_no`.
// For every row group on the page the matching blob page is looked up through
// blob_id_to_page_id, and for every row id of that blob page one row is assembled: row/page ids,
// raw offsets (computed by reading the per-schema length prefixes from `in`), blob page info and
// the index field values. Each finished row is appended to the container behind *row_data_ptr.
// Returns an error when a page lookup, a seek/read on `in`, or a helper call fails.
// `row_data_ptr` must not be null; *row_data_ptr is assumed to already own a container.
// Two spans of the inner loop are elided in this view.
Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) {
RETURN_UNEXPECTED_IF_NULL(row_data_ptr);
// current raw data page
auto ret1 = shard_header_.GetPage(shard_no, raw_page_id);
if (ret1.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_raw_page = ret1.first;

std::shared_ptr<Page> page_ptr;
RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr));
// related blob page
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
vector<pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds();

// pair: row_group id, offset in raw data page
// NOTE(review): each element is converted from pair<int, uint64_t> to pair<int, int>, so the
// offset passes through int before the static_cast<uint64_t> below.
for (pair<int, int> blob_ids : row_group_list) {
// get blob data page according to row_group id
auto iter = blob_id_to_page_id.find(blob_ids.first);
if (iter == blob_id_to_page_id.end()) {
MS_LOG(ERROR) << "Convert blob id failed";
return {FAILED, {}};
}
auto ret2 = shard_header_.GetPage(shard_no, iter->second);
if (ret2.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_blob_page = ret2.first;

CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id.");
std::shared_ptr<Page> blob_page_ptr;
RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr));
// offset in current raw data page
auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
uint64_t cur_blob_page_offset = 0;
for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) {
for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) {
std::vector<std::tuple<std::string, std::string, std::string>> row_data;
row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID()));
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID()));
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID()));
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID()));

// raw data start
row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));

// calculate raw data end
auto &io_seekg =
in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
return {FAILED, {}};
in.close();
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
}

std::vector<uint64_t> schema_lens;
if (schema_count_ <= kMaxSchemaCount) {
for (int sc = 0; sc < schema_count_; sc++) {
@@ -509,8 +435,8 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,

auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
return {FAILED, {}};
in.close();
RETURN_STATUS_UNEXPECTED("Failed to read file.");
}

cur_raw_page_offset += (kInt64Len + schema_size);
@@ -520,122 +446,79 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset));

// Getting schema for getting data for fields
auto st_schema_detail = GetSchemaDetails(schema_lens, in);
if (st_schema_detail.first != SUCCESS) {
return {FAILED, {}};
}

auto detail_ptr = std::make_shared<std::vector<json>>();
RETURN_IF_NOT_OK(GetSchemaDetails(schema_lens, in, &detail_ptr));
// start blob page info
if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) {
return {FAILED, {}};
}
RETURN_IF_NOT_OK(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in));

// start index field
AddIndexFieldByRawData(st_schema_detail.second, row_data);
full_data.push_back(std::move(row_data));
// NOTE(review): AddIndexFieldByRawData returns a Status, but the result of the call below is
// discarded, so a failure still pushes the (incomplete) row - RETURN_IF_NOT_OK is presumably
// intended.
AddIndexFieldByRawData(*detail_ptr, row_data);
(*row_data_ptr)->push_back(std::move(row_data));
}
}
return {SUCCESS, full_data};
return Status::OK();
}

INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) {
std::vector<std::tuple<std::string, std::string, std::string>> fields;
// Builds one (column name, SQL type, value) tuple per index field returned by
// shard_header_.GetFields() and appends it to the container behind *index_fields_ptr.
// For a field (id, name): the value is read from schema_detail[id] via GetValueByField, the SQL
// type comes from the schema with that id, and the column name from GenerateFieldName.
// Returns an error when `index_fields_ptr` is null, when a field id is not a valid index into
// schema_detail, or when any helper fails.
// *index_fields_ptr is assumed to already own a container; that is not checked here.
Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail,
std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) {
RETURN_UNEXPECTED_IF_NULL(index_fields_ptr);
// index fields
std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
for (const auto &field : index_fields) {
if (field.first >= schema_detail.size()) {
return {FAILED, {}};
}
auto field_value = GetValueByField(field.second, schema_detail[field.first]);
if (field_value.first != SUCCESS) {
MS_LOG(ERROR) << "Get value from json by field name failed";
return {FAILED, {}};
}

auto result = shard_header_.GetSchemaByID(field.first);
if (result.second != SUCCESS) {
return {FAILED, {}};
}

std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"]));
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, {}};
}

fields.emplace_back(ret.second, field_type, field_value.second);
}
return {SUCCESS, std::move(fields)};
CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range.");
std::shared_ptr<std::string> field_val_ptr;
RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr));
std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(field.first, &schema_ptr));
std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"]));
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
(*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr);
}
return Status::OK();
}

MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) {
// Writes the index rows of shard `shard_no` into `db` inside one sqlite transaction, then closes
// the shard file and the database.
// Params:
//   shard_no           - shard whose file is read.
//   db                 - open index database connection; closed by this function at the end.
//   raw_page_ids       - ids of the raw-data pages to index.
//   blob_id_to_page_id - maps a row-group (blob) id to the id of its blob page.
// For every raw page the INSERT text and the row data are generated and executed; a failure in
// the loop goes through RELEASE_AND_RETURN_IF_NOT_OK with `db` and `in`.
// NOTE(review): the early returns before the loop (empty address, unresolved path, failed open)
// do not close `db`.
Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) {
// Add index data to database
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
if (shard_address.empty()) {
MS_LOG(ERROR) << "Invalid data, shard address is null";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty.");

auto realpath = Common::GetRealPath(shard_address);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << shard_address;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
std::fstream in;
in.open(realpath.value(), std::ios::in | std::ios::binary);
if (!in.good()) {
MS_LOG(ERROR) << "Invalid file, failed to open file: " << shard_address;
in.close();
return FAILED;
RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address);
}
(void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
// NOTE(review): the results of the BEGIN/END TRANSACTION statements are deliberately discarded.
(void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
for (int raw_page_id : raw_page_ids) {
auto sql = GenerateRawSQL(fields_);
if (sql.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw SQL failed";
in.close();
sqlite3_close(db.second);
return FAILED;
}
auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
if (data.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw data failed";
in.close();
sqlite3_close(db.second);
return FAILED;
}
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
MS_LOG(ERROR) << "Execute SQL failed";
in.close();
sqlite3_close(db.second);
return FAILED;
}
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
}
(void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr);
std::shared_ptr<std::string> sql_ptr;
RELEASE_AND_RETURN_IF_NOT_OK(GenerateRawSQL(fields_, &sql_ptr), db, in);
auto row_data_ptr = std::make_shared<ROW_DATA>();
RELEASE_AND_RETURN_IF_NOT_OK(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db, in);
RELEASE_AND_RETURN_IF_NOT_OK(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in);
MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db.";
}
(void)sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
in.close();

// Close database
if (sqlite3_close(db.second) != SQLITE_OK) {
MS_LOG(ERROR) << "Close database failed";
return FAILED;
}
db.second = nullptr;
return SUCCESS;
// NOTE(review): the result of sqlite3_close is not checked, and `db` is a by-value parameter, so
// the assignment to nullptr below does not affect the caller's pointer.
sqlite3_close(db);
db = nullptr;
return Status::OK();
}

MSRStatus ShardIndexGenerator::WriteToDatabase() {
// Generates the index databases for all shards described by shard_header_.
// Caches fields, page size, header size and schema count from the header, rejects headers with
// more than kMaxShardCount shards, resets task_ and write_success_, and joins the entries of
// `threads`. Returns Status::OK() only when write_success_ is still true afterwards.
// The lines that create and start `threads` are elided in this view.
Status ShardIndexGenerator::WriteToDatabase() {
fields_ = shard_header_.GetFields();
page_size_ = shard_header_.GetPageSize();
header_size_ = shard_header_.GetHeaderSize();
schema_count_ = shard_header_.GetSchemaCount();
if (shard_header_.GetShardCount() > kMaxShardCount) {
MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount;
return FAILED;
}
// NOTE(review): the check compares against kMaxShardCount but the message prints
// kMaxSchemaCount - presumably kMaxShardCount was meant in the message.
CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount,
"num shards: " + std::to_string(shard_header_.GetShardCount()) +
" exceeds max count:" + std::to_string(kMaxSchemaCount));
task_ = 0; // set two atomic vars to initial value
write_success_ = true;

@@ -653,40 +536,41 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
// NOTE(review): the join loop is bounded by threads.capacity(), not size(); the construction of
// `threads` is not visible here - confirm capacity() equals the number of started threads.
for (size_t t = 0; t < threads.capacity(); t++) {
threads[t].join();
}
return write_success_ ? SUCCESS : FAILED;
CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db.");
return Status::OK();
}

// Worker routine (presumably run on the threads joined in WriteToDatabase - their creation is not
// visible here). Repeatedly claims the next shard number from task_ and, for that shard, creates
// the index database, sorts the shard's pages into raw-data page ids and a blob-id -> page-id
// map, and writes the index rows through ExecuteTransaction.
// On any failure write_success_ is set to false and the routine returns.
// NOTE(review): in the lines shown, write_success_ is only written, never read, so a worker keeps
// taking shards after another worker has recorded a failure - confirm this is intended.
void ShardIndexGenerator::DatabaseWriter() {
int shard_no = task_++;
while (shard_no < shard_header_.GetShardCount()) {
auto db = CreateDatabase(shard_no);
if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
sqlite3 *db = nullptr;
if (CreateDatabase(shard_no, &db).IsError()) {
MS_LOG(ERROR) << "Failed to create Generate database.";
write_success_ = false;
return;
}

MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";

// Pre-processing page information
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;

std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) {
auto ret = shard_header_.GetPage(shard_no, i);
if (ret.second != SUCCESS) {
std::shared_ptr<Page> page_ptr;
// NOTE(review): this error path returns without closing the `db` opened above.
if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) {
MS_LOG(ERROR) << "Failed to get page.";
write_success_ = false;
return;
}
std::shared_ptr<Page> cur_page = ret.first;
if (cur_page->GetPageType() == "RAW_DATA") {
if (page_ptr->GetPageType() == "RAW_DATA") {
raw_page_ids.push_back(i);
} else if (cur_page->GetPageType() == "BLOB_DATA") {
blob_id_to_page_id[cur_page->GetPageTypeID()] = i;
} else if (page_ptr->GetPageType() == "BLOB_DATA") {
blob_id_to_page_id[page_ptr->GetPageTypeID()] = i;
}
}

if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) {
MS_LOG(ERROR) << "Failed to execute transaction.";
write_success_ = false;
return;
}
@@ -694,21 +578,12 @@ void ShardIndexGenerator::DatabaseWriter() {
shard_no = task_++;
}
}
MSRStatus ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
if (file_names.empty()) {
MS_LOG(ERROR) << "Mindrecord files is empty.";
return FAILED;
}
// Builds the index databases for a set of mindrecord files.
// A generator is constructed from file_names[0] only, then Build() and WriteToDatabase() are run
// on it. Returns an error when `file_names` is empty or when either step fails.
// Note that `file_names` is taken by value, so the vector is copied on every call.
Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty.");
ShardIndexGenerator sg{file_names[0]};
if (SUCCESS != sg.Build()) {
MS_LOG(ERROR) << "Failed to build index generator.";
return FAILED;
}
if (SUCCESS != sg.WriteToDatabase()) {
MS_LOG(ERROR) << "Failed to write to database.";
return FAILED;
}
return SUCCESS;
RETURN_IF_NOT_OK(sg.Build());
RETURN_IF_NOT_OK(sg.WriteToDatabase());
return Status::OK();
}
} // namespace mindrecord
} // namespace mindspore

+ 374
- 549
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
File diff suppressed because it is too large
View File


+ 122
- 181
mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc View File

@@ -30,9 +30,13 @@ namespace mindspore {
namespace mindrecord {
ShardSegment::ShardSegment() { SetAllInIndex(false); }

std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
// Returns the candidate category fields through *fields_ptr (a new vector copied from
// candidate_category_fields_).
// When the cached list is empty it is filled first by running "PRAGMA table_info(INDEXES);" on
// database_paths_[0] and taking column 1 of the result rows, advancing the row index by 2.
// On a query error or an out-of-range index the first database is closed, its slot is set to
// nullptr, and an error is returned. `fields_ptr` must not be null.
// Parts of the body (including the loop header) are elided in this view.
Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) {
RETURN_UNEXPECTED_IF_NULL(fields_ptr);
// Skip if already populated
if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_};
if (!candidate_category_fields_.empty()) {
*fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
return Status::OK();
}

std::string sql = "PRAGMA table_info(INDEXES);";
std::vector<std::vector<std::string>> field_names;
@@ -40,11 +44,12 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
char *errmsg = nullptr;
int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg);
if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
std::ostringstream oss;
// NOTE(review): errmsg is streamed without a null check.
oss << "Error in select statement, sql: " << sql + ", error: " << errmsg;
sqlite3_free(errmsg);
sqlite3_close(database_paths_[0]);
database_paths_[0] = nullptr;
return {FAILED, vector<std::string>{}};
RETURN_STATUS_UNEXPECTED(oss.str());
} else {
MS_LOG(INFO) << "Get " << static_cast<int>(field_names.size()) << " records from index.";
}
@@ -55,53 +60,46 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
sqlite3_free(errmsg);
sqlite3_close(database_paths_[0]);
database_paths_[0] = nullptr;
return {FAILED, vector<std::string>{}};
RETURN_STATUS_UNEXPECTED("idx is out of range.");
}
candidate_category_fields_.push_back(field_names[idx][1]);
idx += 2;
}
sqlite3_free(errmsg);
return {SUCCESS, candidate_category_fields_};
*fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
return Status::OK();
}

// Selects the category field used by subsequent category queries.
// The caller-supplied name gets the suffix "_0" appended and is accepted only if the
// suffixed name equals one of the entries in candidate_category_fields_; on success it
// is stored in current_category_field_. Any other name yields an unexpected-status error.
MSRStatus ShardSegment::SetCategoryField(std::string category_field) {
if (GetCategoryFields().first != SUCCESS) {
MS_LOG(ERROR) << "Get candidate category field failed";
return FAILED;
}
Status ShardSegment::SetCategoryField(std::string category_field) {
// fields_ptr is not read afterwards; the call is made for its status and the membership
// test below reads candidate_category_fields_ directly.
std::shared_ptr<vector<std::string>> fields_ptr;
RETURN_IF_NOT_OK(GetCategoryFields(&fields_ptr));
category_field = category_field + "_0";
if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_),
[category_field](std::string x) { return x == category_field; })) {
current_category_field_ = category_field;
return SUCCESS;
return Status::OK();
}
MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field.";
return FAILED;
RETURN_STATUS_UNEXPECTED("Field " + category_field + " is not a candidate category field.");
}

// Produces the category summary as a JSON string.
// category_ptr must be non-null; on success *category_ptr is replaced with a newly
// allocated string holding ToJsonForCategory() applied to the result of WrapCategoryInfo().
// A failure from WrapCategoryInfo() is returned unchanged and *category_ptr is left untouched.
std::pair<MSRStatus, std::string> ShardSegment::ReadCategoryInfo() {
Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) {
RETURN_UNEXPECTED_IF_NULL(category_ptr);
MS_LOG(INFO) << "Read category begin";
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info failed";
return {FAILED, ""};
}
// WrapCategoryInfo resizes and fills the container it is handed, so it is allocated first.
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
// Convert category info to json string
auto category_json_string = ToJsonForCategory(ret.second);
*category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr));

MS_LOG(INFO) << "Read category end";

return {SUCCESS, category_json_string};
return Status::OK();
}

std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegment::WrapCategoryInfo() {
Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) {
RETURN_UNEXPECTED_IF_NULL(category_info_ptr);
std::map<std::string, int> counter;

if (!ValidateFieldName(current_category_field_)) {
MS_LOG(ERROR) << "category field error from index, it is: " << current_category_field_;
return {FAILED, std::vector<std::tuple<int, std::string, int>>()};
}

CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_),
"Category field error from index, it is: " + current_category_field_);
std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ +
") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";";

@@ -109,13 +107,13 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen
std::vector<std::vector<std::string>> field_count;

char *errmsg = nullptr;
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg);
if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) {
std::ostringstream oss;
oss << "Error in select statement, sql: " << sql + ", error: " << errmsg;
sqlite3_free(errmsg);
sqlite3_close(db);
db = nullptr;
return {FAILED, std::vector<std::tuple<int, std::string, int>>()};
RETURN_STATUS_UNEXPECTED(oss.str());
} else {
MS_LOG(INFO) << "Get " << static_cast<int>(field_count.size()) << " records from index.";
}
@@ -127,14 +125,14 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen
}

int idx = 0;
std::vector<std::tuple<int, std::string, int>> category_vec(counter.size());
(void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple<std::string, int> item) {
return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item));
});
return {SUCCESS, std::move(category_vec)};
(*category_info_ptr)->resize(counter.size());
(void)std::transform(
counter.begin(), counter.end(), (*category_info_ptr)->begin(),
[&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); });
return Status::OK();
}

std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec) {
std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) {
std::vector<json> category_json_vec;
for (auto q : tri_vec) {
json j;
@@ -152,27 +150,20 @@ std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, st
return category_info.dump();
}

std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageById(int64_t category_id,
int64_t page_no,
int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
if (category_id >= static_cast<int>(ret.second.size()) || category_id < 0) {
MS_LOG(ERROR) << "Illegal category id, id: " << category_id;
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
int total_rows_in_category = std::get<2>(ret.second[category_id]);
Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) {
RETURN_UNEXPECTED_IF_NULL(page_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0,
"Invalid category id, id: " + std::to_string(category_id));
int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
// Quit if category not found or page number is out of range
if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 ||
page_no * n_rows_of_page >= total_rows_in_category) {
MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page;
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
page_no * n_rows_of_page < total_rows_in_category,
"Invalid page no / page size, page no: " + std::to_string(page_no) +
", page size: " + std::to_string(n_rows_of_page));

std::vector<std::vector<uint8_t>> page;
auto row_group_summary = ReadRowGroupSummary();

uint64_t i_start = page_no * n_rows_of_page;
@@ -183,12 +174,12 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage

auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg);
auto details = ReadRowGroupCriteria(
group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id])));
if (SUCCESS != std::get<0>(details)) {
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
auto offsets = std::get<4>(details);
std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
RETURN_IF_NOT_OK(ReadRowGroupCriteria(
group_id, shard_id,
std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
&row_group_brief_ptr));
auto offsets = std::get<3>(*row_group_brief_ptr);
uint64_t number_of_rows = offsets.size();
if (idx + number_of_rows < i_start) {
idx += number_of_rows;
@@ -197,131 +188,116 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage

for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) {
if (idx >= i_start && idx < i_end) {
auto ret1 = PackImages(group_id, shard_id, offsets[i]);
if (SUCCESS != ret1.first) {
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
page.push_back(std::move(ret1.second));
auto images_ptr = std::make_shared<std::vector<uint8_t>>();
RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
(*page_ptr)->push_back(std::move(*images_ptr));
}
}
}

return {SUCCESS, std::move(page)};
return Status::OK();
}

// Reads one record's raw bytes from a shard file into **images_ptr.
// offset is a [begin, end) byte range relative to the start of the page returned by
// GetPageByGroupId(group_id, shard_id); the absolute file position is
// header_size_ + page_size_ * page_id + offset[0] and offset[1] - offset[0] bytes are read.
// images_ptr must be non-null and *images_ptr must already point to a vector, because the
// function resizes it in place and never assigns *images_ptr.
// On a seek or read failure the shard stream is closed and an unexpected-status error is returned.
std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id, int shard_id,
std::vector<uint64_t> offset) {
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
if (SUCCESS != ret.first) {
return {FAILED, std::vector<uint8_t>()};
}
const std::shared_ptr<Page> &blob_page = ret.second;

Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
std::shared_ptr<std::vector<uint8_t>> *images_ptr) {
RETURN_UNEXPECTED_IF_NULL(images_ptr);
// Reject a malformed range up front: indexing a short vector is undefined, and an inverted
// range would wrap the unsigned subtraction into an enormous allocation request.
CHECK_FAIL_RETURN_UNEXPECTED(offset.size() >= 2 && offset[0] <= offset[1], "Invalid blob offset range.");
std::shared_ptr<Page> page_ptr;
RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
// Pack image list
std::vector<uint8_t> images(offset[1] - offset[0]);
auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0];
(*images_ptr)->resize(offset[1] - offset[0]);

auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0];
auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
file_streams_random_[0][shard_id]->close();
return {FAILED, {}};
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
}

auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&images[0]), offset[1] - offset[0]);
// data() is valid even for a zero-length buffer, whereas operator[](0) on an empty vector is not.
auto &io_read =
file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>((*images_ptr)->data()), offset[1] - offset[0]);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
file_streams_random_[0][shard_id]->close();
return {FAILED, {}};
RETURN_STATUS_UNEXPECTED("Failed to read file.");
}

return {SUCCESS, std::move(images)};
return Status::OK();
}

// Name-based front end for ReadAtPageById.
// Scans the tuples produced by WrapCategoryInfo() for the first one whose name (element 1)
// equals category_name and forwards to ReadAtPageById with that tuple's id (element 0),
// passing page_no, n_rows_of_page and pages_ptr through unchanged.
// Returns an unexpected-status error when no category carries the requested name.
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageByName(std::string category_name,
int64_t page_no,
int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
for (const auto &categories : ret.second) {
Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) {
RETURN_UNEXPECTED_IF_NULL(pages_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
for (const auto &categories : *category_info_ptr) {
if (std::get<1>(categories) == category_name) {
auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page);
return result;
RETURN_IF_NOT_OK(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr));
return Status::OK();
}
}

return {FAILED, std::vector<std::vector<uint8_t>>()};
RETURN_STATUS_UNEXPECTED("Category name can not match.");
}

// Reads one page of (blob bytes, label json) pairs for the category with the given id.
// category_id indexes the list produced by WrapCategoryInfo(); page_no is zero-based and
// n_rows_of_page is the page length, so the page covers rows
// [page_no * n_rows_of_page, min(total rows in category, (page_no + 1) * n_rows_of_page)).
// Rows are appended to **pages_ptr; the function never assigns *pages_ptr, so the caller
// must pass a pointer to an already-allocated container.
// Returns an unexpected-status error for an out-of-range category id, an invalid page
// number / page size, or a row group that reports more offsets than labels.
std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageById(
int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS || category_id >= static_cast<int>(ret.second.size())) {
MS_LOG(ERROR) << "Illegal category id, id: " << category_id;
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
int total_rows_in_category = std::get<2>(ret.second[category_id]);
// Quit if category not found or page number is out of range
if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) {
MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page;
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr) {
RETURN_UNEXPECTED_IF_NULL(pages_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
// A negative id must be rejected as well: it is used as a vector index below.
CHECK_FAIL_RETURN_UNEXPECTED(category_id >= 0 && category_id < static_cast<int64_t>(category_info_ptr->size()),
"Invalid category id: " + std::to_string(category_id));

std::vector<std::tuple<std::vector<uint8_t>, json>> page;
int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
// Quit if category not found or page number is out of range
CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
page_no * n_rows_of_page < total_rows_in_category,
"Invalid page no / page size / total size, page no: " + std::to_string(page_no) +
", page size of page: " + std::to_string(n_rows_of_page) +
", total size: " + std::to_string(total_rows_in_category));
auto row_group_summary = ReadRowGroupSummary();

int i_start = page_no * n_rows_of_page;
int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
int idx = 0;
for (const auto &rg : row_group_summary) {
if (idx >= i_end) break;
if (idx >= i_end) {
break;
}

auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg);
auto details = ReadRowGroupCriteria(
group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id])));
if (SUCCESS != std::get<0>(details)) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
auto offsets = std::get<4>(details);
auto labels = std::get<5>(details);
std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
RETURN_IF_NOT_OK(ReadRowGroupCriteria(
group_id, shard_id,
std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
&row_group_brief_ptr));
auto offsets = std::get<3>(*row_group_brief_ptr);
auto labels = std::get<4>(*row_group_brief_ptr);

int number_of_rows = offsets.size();
if (idx + number_of_rows < i_start) {
idx += number_of_rows;
continue;
}

if (number_of_rows > static_cast<int>(labels.size())) {
MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows;
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
// std::to_string is required here: adding an int to a string literal offsets the
// literal's pointer instead of appending the number.
CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast<int>(labels.size()),
"Invalid row number of page: " + std::to_string(number_of_rows));
for (int i = 0; i < number_of_rows; ++i, ++idx) {
if (idx >= i_start && idx < i_end) {
auto ret1 = PackImages(group_id, shard_id, offsets[i]);
if (SUCCESS != ret1.first) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
page.emplace_back(std::move(ret1.second), std::move(labels[i]));
auto images_ptr = std::make_shared<std::vector<uint8_t>>();
RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
(*pages_ptr)->emplace_back(std::move(*images_ptr), std::move(labels[i]));
}
}
}
return {SUCCESS, std::move(page)};
return Status::OK();
}

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageByName(
std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}

Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr) {
RETURN_UNEXPECTED_IF_NULL(pages_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
// category_name to category_id
int64_t category_id = -1;
for (const auto &categories : ret.second) {
for (const auto &categories : *category_info_ptr) {
std::string categories_name = std::get<1>(categories);

if (categories_name == category_name) {
@@ -329,45 +305,8 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
break;
}
}

if (category_id == -1) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}

return ReadAllAtPageById(category_id, page_no, n_rows_of_page);
}

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy(
int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page);
if (res.first != SUCCESS) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}};
}

vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data;
std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
});
return {SUCCESS, std::move(json_data)};
}

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByNamePy(
std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page);
if (res.first != SUCCESS) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}};
}
vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data;
std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
});
return {SUCCESS, std::move(json_data)};
CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "Invalid category name.");
return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr);
}

std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
@@ -382,7 +321,9 @@ std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
}

// Strips a trailing run of decimal digits from field_name and then one more character
// (the separator in front of the digits, e.g. "label_0" -> "label").
// An empty or all-digit input yields an empty string instead of calling back()/pop_back()
// on an empty string, which is undefined behaviour.
std::string ShardSegment::CleanUp(std::string field_name) {
while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back();
while (!field_name.empty() && field_name.back() >= '0' && field_name.back() <= '9') {
field_name.pop_back();
}
if (!field_name.empty()) {
field_name.pop_back();
}
return field_name;
}


+ 272
- 496
mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
File diff suppressed because it is too large
View File


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc View File

@@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
num_categories_(num_categories),
replacement_(replacement) {}

// No-op: reports success without reading or modifying tasks.
MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
Status ShardCategory::Execute(ShardTaskList &tasks) { return Status::OK(); }

int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size;


+ 74
- 100
mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc View File

@@ -72,36 +72,36 @@ void ShardColumn::Init(const json &schema_json, bool compress_integer) {
num_blob_column_ = blob_column_.size();
}

// Looks up the metadata of a column by name.
// All four output pointers must be non-null. On success:
//   *column_category       - result of CheckColumnName(column_name)
//   *column_data_type      - entry of column_data_type_ for the column's id
//   *column_data_type_size - ColumnDataTypeSize entry for that data type
//   *column_shape          - entry of column_shape_ for the column's id
// Returns an unexpected-status error when CheckColumnName reports ColumnNotFound; in that
// case *column_category has already been written and the other outputs are left untouched.
std::pair<MSRStatus, ColumnCategory> ShardColumn::GetColumnTypeByName(const std::string &column_name,
ColumnDataType *column_data_type,
uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type,
uint64_t *column_data_type_size, std::vector<int64_t> *column_shape,
ColumnCategory *column_category) {
RETURN_UNEXPECTED_IF_NULL(column_data_type);
RETURN_UNEXPECTED_IF_NULL(column_data_type_size);
RETURN_UNEXPECTED_IF_NULL(column_shape);
RETURN_UNEXPECTED_IF_NULL(column_category);
// Skip if column not found
auto column_category = CheckColumnName(column_name);
if (column_category == ColumnNotFound) {
return {FAILED, ColumnNotFound};
}
*column_category = CheckColumnName(column_name);
CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid column category.");

// Get data type and size
auto column_id = column_name_id_[column_name];
*column_data_type = column_data_type_[column_id];
*column_data_type_size = ColumnDataTypeSize[*column_data_type];
*column_shape = column_shape_[column_id];

return {SUCCESS, column_category};
return Status::OK();
}

MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
Status ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
RETURN_UNEXPECTED_IF_NULL(column_data_type);
RETURN_UNEXPECTED_IF_NULL(column_data_type_size);
RETURN_UNEXPECTED_IF_NULL(column_shape);
// Skip if column not found
auto column_category = CheckColumnName(column_name);
if (column_category == ColumnNotFound) {
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid column category.");
// Get data type and size
auto column_id = column_name_id_[column_name];
*column_data_type = column_data_type_[column_id];
@@ -110,37 +110,31 @@ MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, cons

// Retrieve value from json
if (column_category == ColumnInRaw) {
if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << ".";
return FAILED;
}
RETURN_IF_NOT_OK(GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes));
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
return SUCCESS;
return Status::OK();
}

// Retrieve value from blob
if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << ".";
return FAILED;
}
RETURN_IF_NOT_OK(GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes));
if (*data == nullptr) {
*data = reinterpret_cast<const unsigned char *>(data_ptr->get());
}
return SUCCESS;
return Status::OK();
}

MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
RETURN_UNEXPECTED_IF_NULL(n_bytes);
RETURN_UNEXPECTED_IF_NULL(data_ptr);
auto column_id = column_name_id_[column_name];
auto column_data_type = column_data_type_[column_id];

// Initialize num bytes
*n_bytes = ColumnDataTypeSize[column_data_type];
auto json_column_value = columns_json[column_name];
if (!json_column_value.is_string() && !json_column_value.is_number()) {
MS_LOG(ERROR) << "Conversion failed (" << json_column_value << ").";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(json_column_value.is_string() || json_column_value.is_number(),
"Conversion to string or number failed (" + json_column_value.dump() + ").");
switch (column_data_type) {
case ColumnFloat32: {
return GetFloat<float>(data_ptr, json_column_value, false);
@@ -171,12 +165,13 @@ MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const j
break;
}
}
return SUCCESS;
return Status::OK();
}

template <typename T>
MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
bool use_double) {
Status ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
bool use_double) {
RETURN_UNEXPECTED_IF_NULL(data_ptr);
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
if (json_column_value.is_number()) {
array_data[0] = json_column_value;
@@ -189,8 +184,7 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons
array_data[0] = json_column_value.get<float>();
}
} catch (json::exception &e) {
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
return FAILED;
RETURN_STATUS_UNEXPECTED("Conversion to float failed (" + json_column_value.dump() + ").");
}
}

@@ -199,54 +193,43 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons
for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i);
}

return SUCCESS;
return Status::OK();
}

template <typename T>
MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
Status ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
RETURN_UNEXPECTED_IF_NULL(data_ptr);
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
int64_t temp_value;
bool less_than_zero = false;

if (json_column_value.is_number_integer()) {
const json json_zero = 0;
if (json_column_value < json_zero) less_than_zero = true;
if (json_column_value < json_zero) {
less_than_zero = true;
}
temp_value = json_column_value;
} else if (json_column_value.is_string()) {
std::string string_value = json_column_value;

if (!string_value.empty() && string_value[0] == '-') {
try {
try {
if (!string_value.empty() && string_value[0] == '-') {
temp_value = std::stoll(string_value);
less_than_zero = true;
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
}
} else {
try {
} else {
temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
}
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what()));
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what()));
}
} else {
MS_LOG(ERROR) << "Conversion to int failed.";
return FAILED;
RETURN_STATUS_UNEXPECTED("Conversion to int failed.");
}

if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
MS_LOG(ERROR) << "Conversion to int failed. Out of range";
return FAILED;
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
}
array_data[0] = static_cast<T>(temp_value);

@@ -255,33 +238,26 @@ MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const
for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i);
}

return SUCCESS;
return Status::OK();
}

// Extracts one column's bytes from a blob.
// GetColumnAddressInBlock supplies the column's byte count (*n_bytes) and its offset in
// columns_blob. Two outcomes follow:
//   - has_compress_blob_ is set and the column is int32/int64: UncompressInt is called with
//     data_ptr and n_bytes; *data is not assigned in this branch.
//   - otherwise: *data is set to the address of columns_blob[offset], i.e. it points into
//     the caller's blob rather than at a copy, and data_ptr is not touched.
// Errors from GetColumnAddressInBlock / UncompressInt are returned unchanged.
MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes) {
Status ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes) {
RETURN_UNEXPECTED_IF_NULL(data);
uint64_t offset_address = 0;
auto column_id = column_name_id_[column_name];
if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) {
return FAILED;
}

RETURN_IF_NOT_OK(GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address));
auto column_data_type = column_data_type_[column_id];
if (has_compress_blob_ && column_data_type == ColumnInt32) {
if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address));
} else if (has_compress_blob_ && column_data_type == ColumnInt64) {
if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address));
} else {
*data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
}

return SUCCESS;
return Status::OK();
}

ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
@@ -296,7 +272,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) {
// Skip if no compress columns
*compression_size = 0;
if (!CheckCompressBlob()) return blob;
if (!CheckCompressBlob()) {
return blob;
}

std::vector<uint8_t> dst_blob;
uint64_t i_src = 0;
@@ -380,12 +358,14 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
return dst_bytes;
}

MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx) {
Status ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx) {
RETURN_UNEXPECTED_IF_NULL(num_bytes);
RETURN_UNEXPECTED_IF_NULL(shift_idx);
if (num_blob_column_ == 1) {
*num_bytes = columns_blob.size();
*shift_idx = 0;
return SUCCESS;
return Status::OK();
}
auto blob_id = blob_column_id_[column_name_[column_id]];

@@ -396,13 +376,14 @@ MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const

(*shift_idx) += kInt64Len;

return SUCCESS;
return Status::OK();
}

template <typename T>
MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes,
uint64_t shift_idx) {
Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx) {
RETURN_UNEXPECTED_IF_NULL(data_ptr);
RETURN_UNEXPECTED_IF_NULL(num_bytes);
auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
*num_bytes = sizeof(T) * num_elements;

@@ -421,19 +402,12 @@ MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<

auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(*num_bytes);

// field is none. for example: numpy is null
if (*num_bytes == 0) {
return SUCCESS;
return Status::OK();
}

int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes);
if (ret_code != 0) {
MS_LOG(ERROR) << "Failed to copy data!";
return FAILED;
}

return SUCCESS;
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data!");
return Status::OK();
}

uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,


+ 5
- 11
mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc View File

@@ -55,15 +55,11 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
return 0;
}

MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
Status ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
auto total_no = tasks.Size();
if (no_of_padded_samples_ > 0 && first_epoch_) {
if (total_no % denominator_ != 0) {
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. "
<< "task size: " << total_no << ", number padded: " << no_of_padded_samples_
<< ", denominator: " << denominator_;
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(total_no % denominator_ == 0,
"Dataset size plus number of padded samples is not divisible by number of shards.");
}
if (first_epoch_) {
first_epoch_ = false;
@@ -74,11 +70,9 @@ MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
if (shuffle_ == true) {
shuffle_op_->SetShardSampleCount(GetShardSampleCount());
shuffle_op_->UpdateShuffleMode(GetShuffleMode());
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
}
RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
}
return SUCCESS;
return Status::OK();
}
} // namespace mindrecord
} // namespace mindspore

+ 166
- 316
mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc View File

@@ -38,104 +38,74 @@ ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), co
index_ = std::make_shared<Index>();
}

MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
shard_count_ = headers.size();
int shard_index = 0;
bool first = true;
for (const auto &header : headers) {
if (first) {
first = false;
if (ParseSchema(header["schema"]) != SUCCESS) {
return FAILED;
}
if (ParseIndexFields(header["index_fields"]) != SUCCESS) {
return FAILED;
}
if (ParseStatistics(header["statistics"]) != SUCCESS) {
return FAILED;
}
RETURN_IF_NOT_OK(ParseSchema(header["schema"]));
RETURN_IF_NOT_OK(ParseIndexFields(header["index_fields"]));
RETURN_IF_NOT_OK(ParseStatistics(header["statistics"]));
ParseShardAddress(header["shard_addresses"]);
header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>();
compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
}
if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
return FAILED;
}
RETURN_IF_NOT_OK(ParsePage(header["page"], shard_index, load_dataset));
shard_index++;
}
return SUCCESS;
return Status::OK();
}

MSRStatus ShardHeader::CheckFileStatus(const std::string &path) {
Status ShardHeader::CheckFileStatus(const std::string &path) {
auto realpath = Common::GetRealPath(path);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << path;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path: " + path);
std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
if (!fin) {
MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path;
return FAILED;
}
if (fin.fail()) {
MS_LOG(ERROR) << "Failed to open file. path: " << path;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file. path: " + path);
// fetch file size
auto &io_seekg = fin.seekg(0, std::ios::end);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
fin.close();
MS_LOG(ERROR) << "File seekg failed. path: " << path;
return FAILED;
RETURN_STATUS_UNEXPECTED("File seekg failed. path: " + path);
}

size_t file_size = fin.tellg();
if (file_size < kMinFileSize) {
fin.close();
MS_LOG(ERROR) << "Invalid file. path: " << path;
return FAILED;
RETURN_STATUS_UNEXPECTED("Invalid file. path: " + path);
}
fin.close();
return SUCCESS;
return Status::OK();
}

std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path) {
if (CheckFileStatus(path) != SUCCESS) {
return {FAILED, {}};
}

Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) {
RETURN_UNEXPECTED_IF_NULL(header_ptr);
RETURN_IF_NOT_OK(CheckFileStatus(path));
// read header size
json json_header;
std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary);
if (!fin.is_open()) {
MS_LOG(ERROR) << "File seekg failed. path: " << path;
return {FAILED, json_header};
}
CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file. path: " + path);

uint64_t header_size = 0;
auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
fin.close();
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("File read failed");
}

if (header_size > kMaxHeaderSize) {
fin.close();
MS_LOG(ERROR) << "Invalid file content. path: " << path;
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("Invalid file content. path: " + path);
}

// read header content
std::vector<uint8_t> header_content(header_size);
auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size);
if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) {
MS_LOG(ERROR) << "File read failed. path: " << path;
fin.close();
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("File read failed. path: " + path);
}

fin.close();
@@ -144,34 +114,35 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
try {
json_header = json::parse(raw_header_content);
} catch (json::parse_error &e) {
MS_LOG(ERROR) << "Json parse error: " << e.what();
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("Json parse error: " + std::string(e.what()));
}
return {SUCCESS, json_header};
*header_ptr = std::make_shared<json>(json_header);
return Status::OK();
}

std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) {
auto ret = ValidateHeader(file_path);
if (SUCCESS != ret.first) {
return {FAILED, json()};
}
json raw_header = ret.second;
Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) {
RETURN_UNEXPECTED_IF_NULL(header_ptr);
std::shared_ptr<json> raw_header;
RETURN_IF_NOT_OK(ValidateHeader(file_path, &raw_header));
uint64_t compression_size =
raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0;
json header = {{"shard_addresses", raw_header["shard_addresses"]},
{"header_size", raw_header["header_size"]},
{"page_size", raw_header["page_size"]},
raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0;
json header = {{"shard_addresses", (*raw_header)["shard_addresses"]},
{"header_size", (*raw_header)["header_size"]},
{"page_size", (*raw_header)["page_size"]},
{"compression_size", compression_size},
{"index_fields", raw_header["index_fields"]},
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
{"schema", raw_header["schema"][0]["schema"]},
{"version", raw_header["version"]}};
return {SUCCESS, header};
{"index_fields", (*raw_header)["index_fields"]},
{"blob_fields", (*raw_header)["schema"][0]["blob_fields"]},
{"schema", (*raw_header)["schema"][0]["schema"]},
{"version", (*raw_header)["version"]}};
*header_ptr = std::make_shared<json>(header);
return Status::OK();
}

MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
uint32_t thread_num = std::thread::hardware_concurrency();
if (thread_num == 0) thread_num = kThreadNumber;
if (thread_num == 0) {
thread_num = kThreadNumber;
}
uint32_t work_thread_num = 0;
uint32_t shard_count = file_paths.size();
int group_num = ceil(shard_count * 1.0 / thread_num);
@@ -194,12 +165,10 @@ MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths,
}
if (thread_status) {
thread_status = false;
return FAILED;
RETURN_STATUS_UNEXPECTED("Error occurred in GetHeadersOneTask thread.");
}
if (SUCCESS != InitializeHeader(headers, load_dataset)) {
return FAILED;
}
return SUCCESS;
RETURN_IF_NOT_OK(InitializeHeader(headers, load_dataset));
return Status::OK();
}

void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers,
@@ -208,48 +177,39 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
return;
}
for (int x = start; x < end; ++x) {
auto ret = ValidateHeader(realAddresses[x]);
if (SUCCESS != ret.first) {
std::shared_ptr<json> header;
if (ValidateHeader(realAddresses[x], &header).IsError()) {
thread_status = true;
return;
}
json header;
header = ret.second;
header["shard_addresses"] = realAddresses;
if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump()
(*header)["shard_addresses"] = realAddresses;
if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) ==
kSupportedVersion.end()) {
MS_LOG(ERROR) << "Version wrong, file version is: " << (*header)["version"].dump()
<< ", lib version is: " << kVersion;
thread_status = true;
return;
}
headers[x] = header;
headers[x] = *header;
}
}

MSRStatus ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
std::vector<std::string> file_names(file_paths.size());
std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string {
if (GetFileName(fp).first == SUCCESS) {
return GetFileName(fp).second;
}
std::shared_ptr<std::string> fn;
return GetFileName(fp, &fn).IsOk() ? *fn : "";
});

shard_addresses_ = std::move(file_names);
shard_count_ = file_paths.size();
if (shard_count_ == 0) {
return FAILED;
}
if (shard_count_ <= kMaxShardCount) {
pages_.resize(shard_count_);
} else {
return FAILED;
}
return SUCCESS;
CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount),
"shard count is invalid. shard count: " + std::to_string(shard_count_));
pages_.resize(shard_count_);
return Status::OK();
}

// No-op: the body is empty and `header` is not read.
void ShardHeader::ParseHeader(const json &header) {
}

MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
Status ShardHeader::ParseIndexFields(const json &index_fields) {
std::vector<std::pair<uint64_t, std::string>> parsed_index_fields;
for (auto &index_field : index_fields) {
auto schema_id = index_field["schema_id"].get<uint64_t>();
@@ -257,18 +217,15 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name);
parsed_index_fields.push_back(parsed_index_field);
}
if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) {
return FAILED;
}
return SUCCESS;
RETURN_IF_NOT_OK(AddIndexFields(parsed_index_fields));
return Status::OK();
}

MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
// set shard_index when load_dataset is false
if (shard_count_ > kMaxFileCount) {
MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount;
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(
shard_count_ <= kMaxFileCount,
"The number of mindrecord files is greater than max value: " + std::to_string(kMaxFileCount));
if (pages_.empty() && shard_count_ <= kMaxFileCount) {
pages_.resize(shard_count_);
}
@@ -295,44 +252,37 @@ MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_d
pages_[shard_index].push_back(std::move(parsed_page));
}
}
return SUCCESS;
return Status::OK();
}

MSRStatus ShardHeader::ParseStatistics(const json &statistics) {
Status ShardHeader::ParseStatistics(const json &statistics) {
for (auto &statistic : statistics) {
if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) {
MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump();
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(
statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(),
"Deserialize statistics failed, statistic: " + statistics.dump());
std::string statistic_description = statistic["desc"].get<std::string>();
json statistic_body = statistic["statistics"];
std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
if (!parsed_statistic) {
return FAILED;
}
RETURN_UNEXPECTED_IF_NULL(parsed_statistic);
AddStatistic(parsed_statistic);
}
return SUCCESS;
return Status::OK();
}

MSRStatus ShardHeader::ParseSchema(const json &schemas) {
Status ShardHeader::ParseSchema(const json &schemas) {
for (auto &schema : schemas) {
// change how we get schemaBody once design is finalized
if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() ||
schema.find("schema") == schema.end()) {
MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump();
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() &&
schema.find("schema") != schema.end(),
"Deserialize schema failed. schema: " + schema.dump());
std::string schema_description = schema["desc"].get<std::string>();
std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
json schema_body = schema["schema"];
std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body);
if (!parsed_schema) {
return FAILED;
}
RETURN_UNEXPECTED_IF_NULL(parsed_schema);
AddSchema(parsed_schema);
}
return SUCCESS;
return Status::OK();
}

void ShardHeader::ParseShardAddress(const json &address) {
@@ -340,7 +290,7 @@ void ShardHeader::ParseShardAddress(const json &address) {
}

std::vector<std::string> ShardHeader::SerializeHeader() {
std::vector<string> header;
std::vector<std::string> header;
auto index = SerializeIndexFields();
auto stats = SerializeStatistics();
auto schema = SerializeSchema();
@@ -406,45 +356,42 @@ std::string ShardHeader::SerializeSchema() {

// Serialize the file names of all shard addresses as a JSON array string.
// An address whose file name cannot be extracted contributes an empty string, so the array keeps
// one entry per shard.
std::string ShardHeader::SerializeShardAddress() {
  json j;
  for (const auto &addr : shard_addresses_) {
    // A fresh pointer per address: a failed lookup must neither dereference null nor silently
    // reuse the name produced for the previous address.
    std::shared_ptr<std::string> fn_ptr;
    if (GetFileName(addr, &fn_ptr).IsOk() && fn_ptr != nullptr) {
      j.emplace_back(*fn_ptr);
    } else {
      j.emplace_back("");
    }
  }
  return j.dump();
}

std::pair<std::shared_ptr<Page>, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) {
Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) {
RETURN_UNEXPECTED_IF_NULL(page_ptr);
if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
return std::make_pair(pages_[shard_id][page_id], SUCCESS);
} else {
return std::make_pair(nullptr, FAILED);
*page_ptr = pages_[shard_id][page_id];
return Status::OK();
}
page_ptr = nullptr;
RETURN_STATUS_UNEXPECTED("Failed to Get Page.");
}

MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
if (new_page == nullptr) {
return FAILED;
}
Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
int shard_id = new_page->GetShardID();
int page_id = new_page->GetPageID();
if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
pages_[shard_id][page_id] = new_page;
return SUCCESS;
} else {
return FAILED;
return Status::OK();
}
RETURN_STATUS_UNEXPECTED("Failed to Set Page.");
}

MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
if (new_page == nullptr) {
return FAILED;
}
Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
int shard_id = new_page->GetShardID();
int page_id = new_page->GetPageID();
if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
pages_[shard_id].push_back(new_page);
return SUCCESS;
} else {
return FAILED;
return Status::OK();
}
RETURN_STATUS_UNEXPECTED("Failed to Add Page.");
}

int64_t ShardHeader::GetLastPageId(const int &shard_id) {
@@ -468,20 +415,18 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag
return last_page_id;
}

const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(const int &group_id,
const int &shard_id) {
if (shard_id >= static_cast<int>(pages_.size())) {
MS_LOG(ERROR) << "Shard id is more than sum of shards.";
return {FAILED, nullptr};
}
Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) {
RETURN_UNEXPECTED_IF_NULL(page_ptr);
CHECK_FAIL_RETURN_UNEXPECTED(shard_id < static_cast<int>(pages_.size()), "Shard id is more than sum of shards.");
for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
auto page = pages_[shard_id][i - 1];
if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
return {SUCCESS, page};
*page_ptr = std::make_shared<Page>(*page);
return Status::OK();
}
}
MS_LOG(ERROR) << "Could not get page by group id " << group_id;
return {FAILED, nullptr};
page_ptr = nullptr;
RETURN_STATUS_UNEXPECTED("Failed to get page by group id: " + group_id);
}

int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
@@ -524,151 +469,88 @@ std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
return index;
}

MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
// check field name is or is not valid
if (schema.find(field) == schema.end()) {
MS_LOG(ERROR) << "Schema do not contain the field: " << field << ".";
return FAILED;
}

if (schema[field]["type"] == "bytes") {
MS_LOG(ERROR) << field << " is bytes type, can not be schema index field.";
return FAILED;
}

if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) {
MS_LOG(ERROR) << field << " array can not be schema index field.";
return FAILED;
}
return SUCCESS;
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), "Filed can not found in schema.");
CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes", "bytes can not be as index field.");
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(),
"array can not be as index field.");
return Status::OK();
}

MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
if (fields.empty()) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Schema is empty.");
// create index Object
std::shared_ptr<Index> index = InitIndexPtr();

if (fields.size() == kInt0) {
MS_LOG(ERROR) << "There are no index fields";
return FAILED;
}

if (GetSchemas().empty()) {
MS_LOG(ERROR) << "No schema is set";
return FAILED;
}

for (const auto &schemaPtr : schema_) {
auto result = GetSchemaByID(schemaPtr->GetSchemaID());
if (result.second != SUCCESS) {
MS_LOG(ERROR) << "Could not get schema by id.";
return FAILED;
}

if (result.first == nullptr) {
MS_LOG(ERROR) << "Could not get schema by id.";
return FAILED;
}

json schema = result.first->GetSchema().at("schema");

std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr));
json schema = schema_ptr->GetSchema().at("schema");
// checkout and add fields for each schema
std::set<std::string> field_set;
for (const auto &item : index->GetFields()) {
field_set.insert(item.second);
}
for (const auto &field : fields) {
if (field_set.find(field) != field_set.end()) {
MS_LOG(ERROR) << "Add same index field twice";
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice.");
// check field name is or is not valid
if (CheckIndexField(field, schema) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(CheckIndexField(field, schema));
field_set.insert(field);

// add field into index
index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
}
}

index_ = index;
return SUCCESS;
return Status::OK();
}

MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
// get all schema id
for (const auto &schema : schema_) {
auto bucket_it = bucket_count.find(schema->GetSchemaID());
if (bucket_it != bucket_count.end()) {
MS_LOG(ERROR) << "Schema duplication";
return FAILED;
} else {
bucket_count.insert(schema->GetSchemaID());
}
CHECK_FAIL_RETURN_UNEXPECTED(bucket_it == bucket_count.end(), "Schema duplication.");
bucket_count.insert(schema->GetSchemaID());
}
return SUCCESS;
return Status::OK();
}

MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
if (fields.empty()) {
return Status::OK();
}
// create index Object
std::shared_ptr<Index> index = InitIndexPtr();

if (fields.size() == kInt0) {
MS_LOG(ERROR) << "There are no index fields";
return FAILED;
}

// get all schema id
std::set<uint64_t> bucket_count;
if (GetAllSchemaID(bucket_count) != SUCCESS) {
return FAILED;
}

RETURN_IF_NOT_OK(GetAllSchemaID(bucket_count));
// check and add fields for each schema
std::set<std::pair<uint64_t, std::string>> field_set;
for (const auto &item : index->GetFields()) {
field_set.insert(item);
}
for (const auto &field : fields) {
if (field_set.find(field) != field_set.end()) {
MS_LOG(ERROR) << "Add same index field twice";
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice.");
uint64_t schema_id = field.first;
std::string field_name = field.second;

// check schemaId is or is not valid
if (bucket_count.find(schema_id) == bucket_count.end()) {
MS_LOG(ERROR) << "Illegal schema id: " << schema_id;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(), "Invalid schema id: " + schema_id);
// check field name is or is not valid
auto result = GetSchemaByID(schema_id);
if (result.second != SUCCESS) {
MS_LOG(ERROR) << "Could not get schema by id.";
return FAILED;
}
json schema = result.first->GetSchema().at("schema");
if (schema.find(field_name) == schema.end()) {
MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name;
return FAILED;
}

if (CheckIndexField(field_name, schema) == FAILED) {
return FAILED;
}

std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr));
json schema = schema_ptr->GetSchema().at("schema");
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(),
"Schema " + std::to_string(schema_id) + " do not contain the field: " + field_name);
RETURN_IF_NOT_OK(CheckIndexField(field_name, schema));
field_set.insert(field);

// add field into index
index.get()->AddIndexField(schema_id, field_name);
index->AddIndexField(schema_id, field_name);
}
index_ = index;
return SUCCESS;
return Status::OK();
}

std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
@@ -686,103 +568,71 @@ std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return

// Return the index object currently held by this header.
std::shared_ptr<Index> ShardHeader::GetIndex() {
  return index_;
}

std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) {
Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) {
RETURN_UNEXPECTED_IF_NULL(schema_ptr);
int64_t schemaSize = schema_.size();
if (schema_id < 0 || schema_id >= schemaSize) {
MS_LOG(ERROR) << "Illegal schema id";
return std::make_pair(nullptr, FAILED);
}
return std::make_pair(schema_.at(schema_id), SUCCESS);
CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schemaSize, "schema id is invalid.");
*schema_ptr = schema_.at(schema_id);
return Status::OK();
}

std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) {
Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) {
RETURN_UNEXPECTED_IF_NULL(statistics_ptr);
int64_t statistics_size = statistics_.size();
if (statistic_id < 0 || statistic_id >= statistics_size) {
return std::make_pair(nullptr, FAILED);
}
return std::make_pair(statistics_.at(statistic_id), SUCCESS);
CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size, "statistic id is invalid.");
*statistics_ptr = statistics_.at(statistic_id);
return Status::OK();
}

MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) {
Status ShardHeader::PagesToFile(const std::string dump_file_name) {
auto realpath = Common::GetRealPath(dump_file_name);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name);
// write header content to file, dump whatever is in the file before
std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out);
if (page_out_handle.fail()) {
MS_LOG(ERROR) << "Failed in opening page file";
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file.");
auto pages = SerializePage();
for (const auto &shard_pages : pages) {
page_out_handle << shard_pages << "\n";
}

page_out_handle.close();
return SUCCESS;
return Status::OK();
}

MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
Status ShardHeader::FileToPages(const std::string dump_file_name) {
for (auto &v : pages_) { // clean pages
v.clear();
}

auto realpath = Common::GetRealPath(dump_file_name);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name);
// attempt to open the file contains the page in json
std::ifstream page_in_handle(realpath.value());

if (!page_in_handle.good()) {
MS_LOG(INFO) << "No page file exists.";
return SUCCESS;
}

CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(), "No page file exists.");
std::string line;
while (std::getline(page_in_handle, line)) {
if (SUCCESS != ParsePage(json::parse(line), -1, true)) {
return FAILED;
}
RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true));
}

page_in_handle.close();
return SUCCESS;
return Status::OK();
}

MSRStatus ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id) {
if (header_ptr == nullptr) {
MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
return FAILED;
}
Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id) {
RETURN_UNEXPECTED_IF_NULL(header_ptr);
auto schema_ptr = Schema::Build("mindrecord", schema);
if (schema_ptr == nullptr) {
MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
return FAILED;
}
RETURN_UNEXPECTED_IF_NULL(schema_ptr);
schema_id = (*header_ptr)->AddSchema(schema_ptr);
// create index
std::vector<std::pair<uint64_t, std::string>> id_index_fields;
if (!index_fields.empty()) {
(void)std::transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
[schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) {
MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index.";
return FAILED;
}
(void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
[schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
RETURN_IF_NOT_OK((*header_ptr)->AddIndexFields(id_index_fields));
}

auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
blob_fields = build_schema_ptr->GetBlobFields();
return SUCCESS;
return Status::OK();
}
} // namespace mindrecord
} // namespace mindspore

+ 3
- 5
mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc View File

@@ -37,13 +37,11 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
}

MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) {
Status ShardPkSample::SufExecute(ShardTaskList &tasks) {
if (shuffle_ == true) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
}
RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
}
return SUCCESS;
return Status::OK();
}
} // namespace mindrecord
} // namespace mindspore

+ 10
- 17
mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc View File

@@ -80,7 +80,7 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
return 0;
}

MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
if (tasks.permutation_.empty()) {
ShardTaskList new_tasks;
int total_no = static_cast<int>(tasks.sample_ids_.size());
@@ -110,9 +110,7 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
ShardTaskList::TaskListSwap(tasks, new_tasks);
} else {
ShardTaskList new_tasks;
if (taking > static_cast<int>(tasks.sample_ids_.size())) {
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int>(tasks.sample_ids_.size()), "taking is out of range.");
int total_no = static_cast<int>(tasks.permutation_.size());
int cnt = 0;
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
@@ -122,10 +120,10 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
}
ShardTaskList::TaskListSwap(tasks, new_tasks);
}
return SUCCESS;
return Status::OK();
}

MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
Status ShardSample::Execute(ShardTaskList &tasks) {
if (offset_ != -1) {
int64_t old_v = 0;
int num_rows_ = static_cast<int>(tasks.sample_ids_.size());
@@ -146,10 +144,8 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
no_of_samples_ = std::min(no_of_samples_, total_no);
taking = no_of_samples_ - no_of_samples_ % no_of_categories;
} else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
if (indices_.size() > static_cast<size_t>(total_no)) {
MS_LOG(ERROR) << "parameter indices's size is greater than dataset size.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast<size_t>(total_no),
"Parameter indices's size is greater than dataset size.");
} else { // constructor TopPercent
if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
if (numerator_ == 1 && denominator_ > 1) { // sharding
@@ -159,20 +155,17 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
taking -= (taking % no_of_categories);
}
} else {
MS_LOG(ERROR) << "parameter numerator or denominator is illegal";
return FAILED;
RETURN_STATUS_UNEXPECTED("Parameter numerator or denominator is invalid.");
}
}
return UpdateTasks(tasks, taking);
}

MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) {
Status ShardSample::SufExecute(ShardTaskList &tasks) {
if (sampler_type_ == kSubsetRandomSampler) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
}
RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
}
return SUCCESS;
return Status::OK();
}
} // namespace mindrecord
} // namespace mindspore

+ 0
- 12
mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc View File

@@ -38,12 +38,6 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, const json &schema) {
return std::make_shared<Schema>(object_schema);
}

std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema) {
// validate check
json schema_json = nlohmann::detail::ToJsonImpl(schema);
return Build(std::move(desc), schema_json);
}

// Return a copy of the schema description string.
std::string Schema::GetDesc() const {
  return desc_;
}

json Schema::GetSchema() const {
@@ -54,12 +48,6 @@ json Schema::GetSchema() const {
return str_schema;
}

// Return the schema as a Python object: GetSchema() supplies the JSON, which is converted with
// nlohmann::detail::FromJsonImpl.
pybind11::object Schema::GetSchemaForPython() const {
  json as_json = GetSchema();
  pybind11::object as_python = nlohmann::detail::FromJsonImpl(as_json);
  return as_python;
}

// Store `id` as this schema's id.
void Schema::SetSchemaID(int64_t id) {
  schema_id_ = id;
}

// Return the id stored for this schema.
int64_t Schema::GetSchemaID() const {
  return schema_id_;
}


+ 4
- 5
mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc View File

@@ -38,7 +38,7 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c
return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
}

MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) {
Status ShardSequentialSample::Execute(ShardTaskList &tasks) {
int64_t taking;
int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size());
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
@@ -58,16 +58,15 @@ MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) {
ShardTaskList::TaskListSwap(tasks, new_tasks);
} else { // shuffled
ShardTaskList new_tasks;
if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int64_t>(tasks.permutation_.size()),
"Taking is out of task range.");
total_no = static_cast<int64_t>(tasks.permutation_.size());
for (size_t i = offset_; i < taking + offset_; ++i) {
new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
}
ShardTaskList::TaskListSwap(tasks, new_tasks);
}
return SUCCESS;
return Status::OK();
}

} // namespace mindrecord


+ 14
- 29
mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc View File

@@ -42,7 +42,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
}

MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories;
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
for (uint32_t i = 0; i < tasks.categories; i++) {
@@ -62,17 +62,14 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
}
ShardTaskList::TaskListSwap(tasks, new_tasks);

return SUCCESS;
return Status::OK();
}

MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
if (no_of_samples_ == 0) {
no_of_samples_ = static_cast<int>(tasks.Size());
}
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
auto shard_sample_cout = GetShardSampleCount();

// shuffle the files index
@@ -118,16 +115,14 @@ MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
new_tasks.AssignTask(tasks, tasks.permutation_[i]);
}
ShardTaskList::TaskListSwap(tasks, new_tasks);
return Status::OK();
}

MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
if (no_of_samples_ == 0) {
no_of_samples_ = static_cast<int>(tasks.Size());
}
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
// reconstruct the permutation in file
// -- before --
// file1: [0, 1, 2]
@@ -154,13 +149,12 @@ MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
new_tasks.AssignTask(tasks, tasks.permutation_[i]);
}
ShardTaskList::TaskListSwap(tasks, new_tasks);
return Status::OK();
}

MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
Status ShardShuffle::Execute(ShardTaskList &tasks) {
if (reshuffle_each_epoch_) shuffle_seed_++;
if (tasks.categories < 1) {
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(tasks.categories >= 1, "Task category is invalid.");
if (shuffle_type_ == kShuffleSample) { // shuffle each sample
if (tasks.permutation_.empty() == true) {
tasks.MakePerm();
@@ -169,10 +163,7 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
if (replacement_ == true) {
ShardTaskList new_tasks;
if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
for (uint32_t i = 0; i < no_of_samples_; ++i) {
new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
}
@@ -190,20 +181,14 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
ShardTaskList::TaskListSwap(tasks, new_tasks);
}
} else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) {
auto ret = ShuffleInfile(tasks);
if (ret != SUCCESS) {
return ret;
}
RETURN_IF_NOT_OK(ShuffleInfile(tasks));
} else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) {
auto ret = ShuffleFiles(tasks);
if (ret != SUCCESS) {
return ret;
}
RETURN_IF_NOT_OK(ShuffleFiles(tasks));
}
} else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
return this->CategoryShuffle(tasks);
}
return SUCCESS;
return Status::OK();
}
} // namespace mindrecord
} // namespace mindspore

+ 0
- 18
mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc View File

@@ -35,19 +35,6 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, const json &stat
return std::make_shared<Statistics>(object_statistics);
}

// Overload of Build that accepts the statistics as a pybind11 handle.
// Returns nullptr when the converted JSON is rejected by Validate(); otherwise returns a shared_ptr
// to a Statistics holding `desc`, the converted JSON, and an id of -1.
std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle statistics) {
// convert the handle to JSON first, then validate the converted form
json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
if (!Validate(statistics_json)) {
return nullptr;
}
Statistics object_statistics;
object_statistics.desc_ = std::move(desc);
object_statistics.statistics_ = statistics_json;
// -1 is the id assigned at build time
object_statistics.statistics_id_ = -1;
// make_shared copy-constructs the returned object from the local one
return std::make_shared<Statistics>(object_statistics);
}

// Returns a copy of the stored description string (desc_).
std::string Statistics::GetDesc() const { return desc_; }

json Statistics::GetStatistics() const {
@@ -57,11 +44,6 @@ json Statistics::GetStatistics() const {
return str_statistics;
}

// Returns the statistics as a pybind11 object: takes the JSON produced by GetStatistics() and
// converts it with nlohmann::detail::FromJsonImpl.
pybind11::object Statistics::GetStatisticsForPython() const {
json str_statistics = Statistics::GetStatistics();
return nlohmann::detail::FromJsonImpl(str_statistics);
}

// Stores `id` in statistics_id_; the value is not checked here.
void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }

// Returns the currently stored statistics id (statistics_id_).
int64_t Statistics::GetStatisticsID() const { return statistics_id_; }


+ 2
- 8
mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc View File

@@ -85,15 +85,9 @@ uint32_t ShardTaskList::SizeOfRows() const {
return nRows;
}

// Returns a mutable reference to the task stored at position `id` in task_list_.
// MS_ASSERT states the precondition id < task_list_.size().
// NOTE(review): whether MS_ASSERT is active in release builds is not visible here - confirm.
ShardTask &ShardTaskList::GetTaskByID(size_t id) {
MS_ASSERT(id < task_list_.size());
return task_list_[id];
}
// Returns a mutable reference to the task stored at position `id` in task_list_.
// The index is used as-is: this function performs no range check on `id`.
// NOTE(review): presumably every caller guarantees id < task_list_.size() - confirm.
ShardTask &ShardTaskList::GetTaskByID(size_t id) { return task_list_[id]; }

// Returns the value stored at position `id` in sample_ids_.
// MS_ASSERT states the precondition id < sample_ids_.size().
// NOTE(review): whether MS_ASSERT is active in release builds is not visible here - confirm.
int ShardTaskList::GetTaskSampleByID(size_t id) {
MS_ASSERT(id < sample_ids_.size());
return sample_ids_[id];
}
// Returns the value stored at position `id` in sample_ids_.
// The index is used as-is: this function performs no range check on `id`.
// NOTE(review): presumably every caller guarantees id < sample_ids_.size() - confirm.
int ShardTaskList::GetTaskSampleByID(size_t id) { return sample_ids_[id]; }

int ShardTaskList::GetRandomTaskID() {
std::mt19937 gen = mindspore::dataset::GetRandomDevice();


+ 1
- 1
mindspore/mindrecord/shardreader.py View File

@@ -70,7 +70,7 @@ class ShardReader:
Raises:
MRMLaunchError: If failed to launch worker threads.
"""
ret = self._reader.launch(False)
ret = self._reader.launch()
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to launch worker threads.")
raise MRMLaunchError


+ 4
- 26
mindspore/mindrecord/shardsegment.py View File

@@ -19,7 +19,6 @@ import mindspore._c_mindrecord as ms
from mindspore import log as logger
from .shardutils import populate_data, SUCCESS
from .shardheader import ShardHeader
from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError

__all__ = ['ShardSegment']

@@ -73,15 +72,8 @@ class ShardSegment:
Returns:
list[str], by which data could be grouped.

Raises:
MRMFetchCandidateFieldsError: If failed to get candidate category fields.
"""
ret, fields = self._segment.get_category_fields()
if ret != SUCCESS:
logger.error("Failed to get candidate category fields.")
raise MRMFetchCandidateFieldsError
return fields

return self._segment.get_category_fields()

def set_category_field(self, category_field):
"""Select one category field to use."""
@@ -94,14 +86,8 @@ class ShardSegment:
Returns:
str, description of group information.

Raises:
MRMReadCategoryInfoError: If failed to read category information.
"""
ret, category_info = self._segment.read_category_info()
if ret != SUCCESS:
logger.error("Failed to read category information.")
raise MRMReadCategoryInfoError
return category_info
return self._segment.read_category_info()

def read_at_page_by_id(self, category_id, page, num_row):
"""
@@ -116,13 +102,9 @@ class ShardSegment:
list[dict]

Raises:
MRMFetchDataError: If failed to read by category id.
MRMUnsupportedSchemaError: If schema is invalid.
"""
ret, data = self._segment.read_at_page_by_id(category_id, page, num_row)
if ret != SUCCESS:
logger.error("Failed to read by category id.")
raise MRMFetchDataError
data = self._segment.read_at_page_by_id(category_id, page, num_row)
return [populate_data(raw, blob, self._columns, self._header.blob_fields,
self._header.schema) for blob, raw in data]

@@ -139,12 +121,8 @@ class ShardSegment:
list[dict]

Raises:
MRMFetchDataError: If failed to read by category name.
MRMUnsupportedSchemaError: If schema is invalid.
"""
ret, data = self._segment.read_at_page_by_name(category_name, page, num_row)
if ret != SUCCESS:
logger.error("Failed to read by category name.")
raise MRMFetchDataError
data = self._segment.read_at_page_by_name(category_name, page, num_row)
return [populate_data(raw, blob, self._columns, self._header.blob_fields,
self._header.schema) for blob, raw in data]

+ 2
- 2
tests/ut/cpp/mindrecord/ut_common.cc View File

@@ -384,8 +384,8 @@ void ShardWriterImageNetOpenForAppend(string filename) {
{
MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
mindrecord::ShardWriter fw;
auto ret = fw.OpenForAppend(filename);
if (ret == FAILED) {
auto status = fw.OpenForAppend(filename);
if (status.IsError()) {
return;
}



+ 8
- 3
tests/ut/cpp/mindrecord/ut_shard.cc View File

@@ -121,14 +121,19 @@ TEST_F(TestShard, TestShardHeaderPart) {
re_statistics.push_back(*statistic);
}
ASSERT_EQ(re_statistics, validate_statistics);
ASSERT_EQ(header_data.GetStatisticByID(-1).second, FAILED);
ASSERT_EQ(header_data.GetStatisticByID(10).second, FAILED);
std::shared_ptr<Statistics> statistics_ptr;

auto status = header_data.GetStatisticByID(-1, &statistics_ptr);
EXPECT_FALSE(status.IsOk());
status = header_data.GetStatisticByID(10, &statistics_ptr);
EXPECT_FALSE(status.IsOk());

// test add index fields
std::vector<std::pair<uint64_t, std::string>> fields;
std::pair<uint64_t, std::string> pair1(0, "name");
fields.push_back(pair1);
ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS);
status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());
std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields();
ASSERT_EQ(resFields, fields);
}


+ 19
- 18
tests/ut/cpp/mindrecord/ut_shard_header_test.cc View File

@@ -79,36 +79,37 @@ TEST_F(TestShardHeader, AddIndexFields) {
std::pair<uint64_t, std::string> index_field2(schema_id1, "box");
fields.push_back(index_field1);
fields.push_back(index_field2);
MSRStatus res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, SUCCESS);
Status status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());

ASSERT_EQ(header_data.GetFields().size(), 2);

fields.clear();
std::pair<uint64_t, std::string> index_field3(schema_id1, "name");
fields.push_back(index_field3);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2);

fields.clear();
std::pair<uint64_t, std::string> index_field4(schema_id1, "names");
fields.push_back(index_field4);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2);

fields.clear();
std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name");
fields.push_back(index_field5);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2);

fields.clear();
std::pair<uint64_t, std::string> index_field6(schema_id1, "label");
fields.push_back(index_field6);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2);

std::string desc_new = "this is a test1";
@@ -129,26 +130,26 @@ TEST_F(TestShardHeader, AddIndexFields) {
single_fields.push_back("name");
single_fields.push_back("name");
single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 1);

single_fields.push_back("name");
single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 1);

single_fields.clear();
single_fields.push_back("names");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 1);

single_fields.clear();
single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, SUCCESS);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_TRUE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 2);
}
} // namespace mindrecord


+ 8
- 8
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc View File

@@ -167,8 +167,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label"};
ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({file_name}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch();

while (true) {
@@ -188,16 +188,16 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"file_namex"};
ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST);
auto status= dataset.Open({file_name}, true, 4, column_list);
EXPECT_FALSE(status.IsOk());
}

TEST_F(TestShardReader, TestShardVersion) {
MS_LOG(INFO) << FormatInfo("Test shard version");
std::string file_name = "./imagenet.shard01";
ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({file_name}, true, 4);
EXPECT_TRUE(status.IsOk());
dataset.Launch();

while (true) {
@@ -219,8 +219,8 @@ TEST_F(TestShardReader, TestShardReaderDir) {
auto column_list = std::vector<std::string>{"file_name"};

ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, FAILED);
auto status = dataset.Open({file_name}, true, 4, column_list);
EXPECT_FALSE(status.IsOk());
}

TEST_F(TestShardReader, TestShardReaderConsumer) {


+ 82
- 45
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc View File

@@ -61,35 +61,44 @@ TEST_F(TestShardSegment, TestShardSegment) {
ShardSegment dataset;
dataset.Open({file_name}, true, 4);

auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields;
}

ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
ASSERT_TRUE(dataset.SetCategoryField("laabel_0") == FAILED);
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());
status = dataset.SetCategoryField("laabel_0");
EXPECT_FALSE(status.IsOk());

MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;

auto ret = dataset.ReadAtPageByName("822", 0, 10);
auto images = ret.second;
MS_LOG(INFO) << "category field: 822, images count: " << images.size() << ", image[0] size: " << images[0].size();
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;

auto ret1 = dataset.ReadAtPageByName("823", 0, 10);
auto images2 = ret1.second;
MS_LOG(INFO) << "category field: 823, images count: " << images2.size();
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageByName("822", 0, 10, &pages_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr->size() << ", image[0] size: " << ((*pages_ptr)[0]).size();

auto ret2 = dataset.ReadAtPageById(1, 0, 10);
auto images3 = ret2.second;
MS_LOG(INFO) << "category id: 1, images count: " << images3.size() << ", image[0] size: " << images3[0].size();
auto pages_ptr_1 = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageByName("823", 0, 10, &pages_ptr_1);
MS_LOG(INFO) << "category field: 823, images count: " << pages_ptr_1->size();

auto ret3 = dataset.ReadAllAtPageByName("822", 0, 10);
auto images4 = ret3.second;
MS_LOG(INFO) << "category field: 822, images count: " << images4.size();
auto pages_ptr_2 = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(1, 0, 10, &pages_ptr_2);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_2->size() << ", image[0] size: " << ((*pages_ptr_2)[0]).size();

auto ret4 = dataset.ReadAllAtPageById(1, 0, 10);
auto images5 = ret4.second;
MS_LOG(INFO) << "category id: 1, images count: " << images5.size();
auto pages_ptr_3 = std::make_shared<PAGES>();
status = dataset.ReadAllAtPageByName("822", 0, 10, &pages_ptr_3);
MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr_3->size();

auto pages_ptr_4 = std::make_shared<PAGES>();
status = dataset.ReadAllAtPageById(1, 0, 10, &pages_ptr_4);
MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_4->size();
}

TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
@@ -99,21 +108,28 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
ShardSegment dataset;
dataset.Open({file_name}, true, 4);

auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields;
}

string category_name = "82Cus";
string category_field = "laabel_0";

ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
ASSERT_TRUE(dataset.SetCategoryField(category_field) == FAILED);
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());
status = dataset.SetCategoryField(category_field);
EXPECT_FALSE(status.IsOk());

MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;

auto ret = dataset.ReadAtPageByName(category_name, 0, 10);
EXPECT_TRUE(ret.first == FAILED);
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageByName(category_name, 0, 10, &pages_ptr);
EXPECT_FALSE(status.IsOk());
}

TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
@@ -123,19 +139,25 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
ShardSegment dataset;
dataset.Open({file_name}, true, 4);

auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields;
}

int64_t categoryId = 2251799813685247;
MS_LOG(INFO) << "Input category id: " << categoryId;

ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;

auto ret2 = dataset.ReadAtPageById(categoryId, 0, 10);
EXPECT_TRUE(ret2.first == FAILED);
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(categoryId, 0, 10, &pages_ptr);
EXPECT_FALSE(status.IsOk());
}

TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
@@ -145,19 +167,27 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
ShardSegment dataset;
dataset.Open({file_name}, true, 4);

auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields;
}

int64_t page_no = 2251799813685247;
MS_LOG(INFO) << "Input page no: " << page_no;

ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());


auto ret2 = dataset.ReadAtPageById(1, page_no, 10);
EXPECT_TRUE(ret2.first == FAILED);
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;

auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(1, page_no, 10, &pages_ptr);
EXPECT_FALSE(status.IsOk());
}

TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
@@ -167,19 +197,26 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
ShardSegment dataset;
dataset.Open({file_name}, true, 4);

auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields;
}

int64_t pageRows = 0;
MS_LOG(INFO) << "Input page rows: " << pageRows;

ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());

std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;

auto ret2 = dataset.ReadAtPageById(1, 0, pageRows);
EXPECT_TRUE(ret2.first == FAILED);
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(1, 0, pageRows, &pages_ptr);
EXPECT_FALSE(status.IsOk());
}

} // namespace mindrecord


+ 49
- 29
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc View File

@@ -60,8 +60,8 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
std::string filename = "./OneSample.shard01";

ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({filename}, true, 4);
EXPECT_TRUE(status.IsOk());
dataset.Launch();

while (true) {
@@ -675,8 +675,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) {
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));

// write rawdata
MSRStatus res = fw.WriteRawData(rawdatas, bin_data);
ASSERT_EQ(res, SUCCESS);
auto status = fw.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db));
@@ -716,7 +716,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
fields.push_back(index_field2);

// add index to shardHeader
ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS);
auto status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Init Index Fields Already.";

// load meta data
@@ -736,28 +737,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
}

mindrecord::ShardWriter fw_init;
ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS);
status = fw_init.Open(file_names);
EXPECT_TRUE(status.IsOk());

// set shardHeader
ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS);
status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
EXPECT_TRUE(status.IsOk());


// write raw data
ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS);
ASSERT_TRUE(fw_init.Commit() == SUCCESS);
status = fw_init.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
status = fw_init.Commit();
EXPECT_TRUE(status.IsOk());

// create the index file
std::string filename = "./imagenet.shard01";
mindrecord::ShardIndexGenerator sg{filename};
sg.Build();
ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS);
status = sg.WriteToDatabase();
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Done create index";

// read the mindrecord file
filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name", "data"};
ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
status = dataset.Open({filename}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch();

int count = 0;
@@ -822,28 +829,34 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
}

mindrecord::ShardWriter fw_init;
ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS);
auto status = fw_init.Open(file_names);
EXPECT_TRUE(status.IsOk());


// set shardHeader
ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS);
status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
EXPECT_TRUE(status.IsOk());

// write raw data
ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS);
ASSERT_TRUE(fw_init.Commit() == SUCCESS);
status = fw_init.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
status = fw_init.Commit();
EXPECT_TRUE(status.IsOk());

// create the index file
std::string filename = "./imagenet.shard01";
mindrecord::ShardIndexGenerator sg{filename};
sg.Build();
ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS);
status = sg.WriteToDatabase();
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Done create index";

// read the mindrecord file
filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name"};
ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
status = dataset.Open({filename}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch();

int count = 0;
@@ -896,7 +909,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
fields.push_back(index_field1);

// add index to shardHeader
ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS);
auto status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Init Index Fields Already.";

// load meta data
@@ -916,28 +930,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
}

mindrecord::ShardWriter fw_init;
ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS);
status = fw_init.Open(file_names);
EXPECT_TRUE(status.IsOk());


// set shardHeader
ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS);
status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
EXPECT_TRUE(status.IsOk());

// write raw data
ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS);
ASSERT_TRUE(fw_init.Commit() == SUCCESS);
status = fw_init.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
status = fw_init.Commit();
EXPECT_TRUE(status.IsOk());

// create the index file
std::string filename = "./imagenet.shard01";
mindrecord::ShardIndexGenerator sg{filename};
sg.Build();
ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS);
status = sg.WriteToDatabase();
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Done create index";

// read the mindrecord file
filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "data"};
ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
status = dataset.Open({filename}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch();

int count = 0;
@@ -1043,8 +1063,8 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {

filename = "./TenSampleFortyShard.shard01";
ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({filename}, true, 4);
EXPECT_TRUE(status.IsOk());
dataset.Launch();

int count = 0;


+ 5
- 5
tests/ut/python/dataset/test_minddataset_exception.py View File

@@ -95,7 +95,7 @@ def test_invalid_mindrecord():
f.write('just for test')
columns_list = ["data", "file_name", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid file content. path:"):
data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers)
num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -114,7 +114,7 @@ def test_minddataset_lack_db():
os.remove("{}.db".format(CV_FILE_NAME))
columns_list = ["data", "file_name", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file:"):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -133,7 +133,7 @@ def test_cv_minddataset_pk_sample_error_class_column():
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(5, None, True, 'no_exist_column')
with pytest.raises(Exception, match="MindRecordOp launch failed"):
with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -162,7 +162,7 @@ def test_cv_minddataset_reader_different_schema():
create_diff_schema_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers)
num_iter = 0
@@ -179,7 +179,7 @@ def test_cv_minddataset_reader_different_page_size():
create_diff_page_size_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers)
num_iter = 0


+ 4
- 5
tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py View File

@@ -19,7 +19,6 @@ import pytest
from mindspore import log as logger
from mindspore.mindrecord import Cifar100ToMR
from mindspore.mindrecord import FileReader
from mindspore.mindrecord import MRMOpenError
from mindspore.mindrecord import SUCCESS

CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
@@ -119,8 +118,8 @@ def test_cifar100_to_mindrecord_directory(fixture_file):
test transform cifar100 dataset to mindrecord
when destination path is directory.
"""
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="MindRecord file already existed, please delete file:"):
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR,
CIFAR100_DIR)
cifar100_transformer.transform()
@@ -130,8 +129,8 @@ def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
test transform cifar100 dataset to mindrecord
when destination path equals source path.
"""
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="MindRecord file already existed, please delete file:"):
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR,
CIFAR100_DIR + "/train")
cifar100_transformer.transform()

+ 5
- 5
tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py View File

@@ -19,7 +19,7 @@ import pytest
from mindspore import log as logger
from mindspore.mindrecord import Cifar10ToMR
from mindspore.mindrecord import FileReader
from mindspore.mindrecord import MRMOpenError, SUCCESS
from mindspore.mindrecord import SUCCESS

CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
MINDRECORD_FILE = "./cifar10.mindrecord"
@@ -146,8 +146,8 @@ def test_cifar10_to_mindrecord_directory(fixture_file):
test transform cifar10 dataset to mindrecord
when destination path is directory.
"""
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="MindRecord file already existed, please delete file:"):
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR)
cifar10_transformer.transform()

@@ -157,8 +157,8 @@ def test_cifar10_to_mindrecord_filename_equals_cifar10():
test transform cifar10 dataset to mindrecord
when destination path equals source path.
"""
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="MindRecord file already existed, please delete file:"):
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR,
CIFAR10_DIR + "/data_batch_0")
cifar10_transformer.transform()

+ 38
- 49
tests/ut/python/mindrecord/test_mindrecord_exception.py View File

@@ -21,8 +21,7 @@ from utils import get_data

from mindspore import log as logger
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \
MRMFetchDataError
from mindspore.mindrecord import ParamValueError, MRMGetMetaError

CV_FILE_NAME = "./imagenet.mindrecord"
NLP_FILE_NAME = "./aclImdb.mindrecord"
@@ -106,21 +105,19 @@ def create_cv_mindrecord(files_num):

def test_lack_partition_and_db():
"""test file reader when mindrecord file does not exist."""
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader('dummy.mindrecord')
reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)

def test_lack_db(fixture_cv_file):
"""test file reader when db file does not exist."""
create_cv_mindrecord(1)
os.remove("{}.db".format(CV_FILE_NAME))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME)
reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid database file:' in str(err.value)

def test_lack_some_partition_and_db(fixture_cv_file):
"""test file reader when some partition and db do not exist."""
@@ -129,11 +126,10 @@ def test_lack_some_partition_and_db(fixture_cv_file):
for x in range(FILES_NUM)]
os.remove("{}".format(paths[3]))
os.remove("{}.db".format(paths[3]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0")
reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)

def test_lack_some_partition_first(fixture_cv_file):
"""test file reader when first partition does not exist."""
@@ -141,11 +137,10 @@ def test_lack_some_partition_first(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
os.remove("{}".format(paths[0]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0")
reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)

def test_lack_some_partition_middle(fixture_cv_file):
"""test file reader when some partition does not exist."""
@@ -153,11 +148,10 @@ def test_lack_some_partition_middle(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
os.remove("{}".format(paths[1]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0")
reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)

def test_lack_some_partition_last(fixture_cv_file):
"""test file reader when last partition does not exist."""
@@ -165,11 +159,10 @@ def test_lack_some_partition_last(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
os.remove("{}".format(paths[3]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0")
reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)

def test_mindpage_lack_some_partition(fixture_cv_file):
"""test page reader when some partition does not exist."""
@@ -177,10 +170,9 @@ def test_mindpage_lack_some_partition(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
os.remove("{}".format(paths[0]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
MindPage(CV_FILE_NAME + "0")
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)

def test_lack_some_db(fixture_cv_file):
"""test file reader when some db does not exist."""
@@ -188,11 +180,10 @@ def test_lack_some_db(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
os.remove("{}.db".format(paths[3]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0")
reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid database file:' in str(err.value)


def test_invalid_mindrecord():
@@ -200,10 +191,9 @@ def test_invalid_mindrecord():
with open(CV_FILE_NAME, 'w') as f:
dummy = 's' * 100
f.write(dummy)
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
FileReader(CV_FILE_NAME)
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file content. path:' in str(err.value)
os.remove(CV_FILE_NAME)

def test_invalid_db(fixture_cv_file):
@@ -212,27 +202,26 @@ def test_invalid_db(fixture_cv_file):
os.remove("imagenet.mindrecord.db")
with open('imagenet.mindrecord.db', 'w') as f:
f.write('just for test')
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
FileReader('imagenet.mindrecord')
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Error in execute sql:' in str(err.value)

def test_overwrite_invalid_mindrecord(fixture_cv_file):
"""test file writer when overwrite invalid mindrecord file."""
with open(CV_FILE_NAME, 'w') as f:
f.write('just for test')
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
create_cv_mindrecord(1)
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
assert 'Unexpected error. MindRecord file already existed, please delete file:' \
in str(err.value)

def test_overwrite_invalid_db(fixture_cv_file):
"""test file writer when overwrite invalid db file."""
with open('imagenet.mindrecord.db', 'w') as f:
f.write('just for test')
with pytest.raises(MRMGenerateIndexError) as err:
with pytest.raises(RuntimeError) as err:
create_cv_mindrecord(1)
assert '[MRMGenerateIndexError]: Failed to generate index.' in str(err.value)
assert 'Unexpected error. Failed to write data to db.' in str(err.value)

def test_read_after_close(fixture_cv_file):
"""test file reader when close read."""
@@ -302,7 +291,7 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
with pytest.raises(ParamValueError):
reader.read_at_page_by_name("822", 0, "qwer")

with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"):
reader.read_at_page_by_id(99999, 0, 1)


@@ -320,10 +309,10 @@ def test_mindpage_filename_not_exist(fixture_cv_file):
info = reader.read_category_info()
logger.info("category info: {}".format(info))

with pytest.raises(MRMFetchDataError):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"):
reader.read_at_page_by_id(9999, 0, 1)

with pytest.raises(MRMFetchDataError):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid category name."):
reader.read_at_page_by_name("abc.jpg", 0, 1)

with pytest.raises(ParamValueError):
@@ -475,7 +464,7 @@ def test_write_with_invalid_data():
mindrecord_file_name = "test.mindrecord"

# field: file_name => filename
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -510,7 +499,7 @@ def test_write_with_invalid_data():
writer.commit()

# field: mask => masks
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -545,7 +534,7 @@ def test_write_with_invalid_data():
writer.commit()

# field: data => image
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -580,7 +569,7 @@ def test_write_with_invalid_data():
writer.commit()

# field: label => labels
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -615,7 +604,7 @@ def test_write_with_invalid_data():
writer.commit()

# field: score => scores
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -650,7 +639,7 @@ def test_write_with_invalid_data():
writer.commit()

# string type with int value
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -685,7 +674,7 @@ def test_write_with_invalid_data():
writer.commit()

# field with int64 type, but the real data is string
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -720,7 +709,7 @@ def test_write_with_invalid_data():
writer.commit()

# bytes field is string
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -755,7 +744,7 @@ def test_write_with_invalid_data():
writer.commit()

# field is not numpy type
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")

@@ -790,7 +779,7 @@ def test_write_with_invalid_data():
writer.commit()

# not enough field
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db")



Loading…
Cancel
Save