Browse Source

refactor mindrecord

tags/v1.5.0-rc1
liyong 4 years ago
parent
commit
c257cf36e2
51 changed files with 2020 additions and 2864 deletions
  1. +8
    -22
      mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc
  2. +21
    -38
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc
  3. +6
    -9
      mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc
  4. +3
    -5
      mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc
  5. +0
    -83
      mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc
  6. +171
    -40
      mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
  7. +32
    -33
      mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc
  8. +13
    -8
      mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
  9. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_category.h
  10. +19
    -20
      mindspore/ccsrc/minddata/mindrecord/include/shard_column.h
  11. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
  12. +41
    -51
      mindspore/ccsrc/minddata/mindrecord/include/shard_error.h
  13. +33
    -33
      mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
  14. +28
    -28
      mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h
  15. +16
    -21
      mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h
  16. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h
  17. +62
    -80
      mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
  18. +3
    -3
      mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
  19. +0
    -9
      mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h
  20. +19
    -20
      mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h
  21. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h
  22. +4
    -4
      mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h
  23. +0
    -9
      mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h
  24. +61
    -71
      mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h
  25. +204
    -329
      mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc
  26. +374
    -549
      mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
  27. +122
    -181
      mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc
  28. +272
    -496
      mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
  29. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc
  30. +74
    -100
      mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
  31. +5
    -11
      mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
  32. +166
    -316
      mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
  33. +3
    -5
      mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc
  34. +10
    -17
      mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
  35. +0
    -12
      mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc
  36. +4
    -5
      mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc
  37. +14
    -29
      mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
  38. +0
    -18
      mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc
  39. +2
    -8
      mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc
  40. +1
    -1
      mindspore/mindrecord/shardreader.py
  41. +4
    -26
      mindspore/mindrecord/shardsegment.py
  42. +2
    -2
      tests/ut/cpp/mindrecord/ut_common.cc
  43. +8
    -3
      tests/ut/cpp/mindrecord/ut_shard.cc
  44. +19
    -18
      tests/ut/cpp/mindrecord/ut_shard_header_test.cc
  45. +8
    -8
      tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
  46. +82
    -45
      tests/ut/cpp/mindrecord/ut_shard_segment_test.cc
  47. +49
    -29
      tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
  48. +5
    -5
      tests/ut/python/dataset/test_minddataset_exception.py
  49. +4
    -5
      tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py
  50. +5
    -5
      tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py
  51. +38
    -49
      tests/ut/python/mindrecord/test_mindrecord_exception.py

+ 8
- 22
mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc View File

@@ -256,9 +256,7 @@ Status SaveToDisk::Save() {
auto mr_header = std::make_shared<mindrecord::ShardHeader>(); auto mr_header = std::make_shared<mindrecord::ShardHeader>();
auto mr_writer = std::make_unique<mindrecord::ShardWriter>(); auto mr_writer = std::make_unique<mindrecord::ShardWriter>();
std::vector<std::string> blob_fields; std::vector<std::string> blob_fields;
if (mindrecord::SUCCESS != mindrecord::ShardWriter::Initialize(&mr_writer, file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter, please check above `ERROR` level message.");
}
RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names));


std::unordered_map<std::string, int32_t> column_name_id_map; std::unordered_map<std::string, int32_t> column_name_id_map;
for (auto el : tree_adapter_->GetColumnNameMap()) { for (auto el : tree_adapter_->GetColumnNameMap()) {
@@ -286,22 +284,16 @@ Status SaveToDisk::Save() {
std::vector<std::string> index_fields; std::vector<std::string> index_fields;
RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields)); RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields));
MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump(); MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump();
if (mindrecord::SUCCESS !=
mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
}
if (mindrecord::SUCCESS != mr_writer->SetShardHeader(mr_header)) {
RETURN_STATUS_UNEXPECTED("Error: failed to set header of ShardWriter.");
}
RETURN_IF_NOT_OK(
mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id));
RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header));
first_loop = false; first_loop = false;
} }
// construct data // construct data
if (!row.empty()) { // write data if (!row.empty()) { // write data
RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data)); RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data));
std::shared_ptr<std::vector<uint8_t>> output_bin_data; std::shared_ptr<std::vector<uint8_t>> output_bin_data;
if (mindrecord::SUCCESS != mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)) {
RETURN_STATUS_UNEXPECTED("Error: failed to merge blob data of ShardWriter.");
}
RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data));
std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data; std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data;
raw_data.insert( raw_data.insert(
std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data})); std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data}));
@@ -309,18 +301,12 @@ Status SaveToDisk::Save() {
if (output_bin_data != nullptr) { if (output_bin_data != nullptr) {
bin_data.emplace_back(*output_bin_data); bin_data.emplace_back(*output_bin_data);
} }
if (mindrecord::SUCCESS != mr_writer->WriteRawData(raw_data, bin_data)) {
RETURN_STATUS_UNEXPECTED("Error: failed to write raw data to ShardWriter.");
}
RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data));
} }
} while (!row.empty()); } while (!row.empty());


if (mindrecord::SUCCESS != mr_writer->Commit()) {
RETURN_STATUS_UNEXPECTED("Error: failed to commit ShardWriter.");
}
if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::Finalize(file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
}
RETURN_IF_NOT_OK(mr_writer->Commit());
RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names));
return Status::OK(); return Status::OK();
} }




+ 21
- 38
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc View File

@@ -23,6 +23,7 @@
#include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/global_context.h" #include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h"
#include "minddata/mindrecord/include/shard_column.h"
#include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/include/dataset/constants.h" #include "minddata/dataset/include/dataset/constants.h"
@@ -63,10 +64,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str


// Private helper method to encapsulate some common construction/reset tasks // Private helper method to encapsulate some common construction/reset tasks
Status MindRecordOp::Init() { Status MindRecordOp::Init() {
auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
num_padded_);

CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed, " + ErrnoToMessage(rc));
RETURN_IF_NOT_OK(shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_,
operators_, num_padded_));


data_schema_ = std::make_unique<DataSchema>(); data_schema_ = std::make_unique<DataSchema>();


@@ -206,7 +205,9 @@ Status MindRecordOp::GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, i
fetched_row->setPath(file_path); fetched_row->setPath(file_path);
fetched_row->setId(row_id); fetched_row->setId(row_id);
} }
if (tupled_buffer.empty()) return Status::OK();
if (tupled_buffer.empty()) {
return Status::OK();
}
if (task_type == mindrecord::TaskType::kCommonTask) { if (task_type == mindrecord::TaskType::kCommonTask) {
for (const auto &tupled_row : tupled_buffer) { for (const auto &tupled_row : tupled_buffer) {
std::vector<uint8_t> columns_blob = std::get<0>(tupled_row); std::vector<uint8_t> columns_blob = std::get<0>(tupled_row);
@@ -237,20 +238,15 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
// Get column data // Get column data
auto shard_column = shard_reader_->GetShardColumn(); auto shard_column = shard_reader_->GetShardColumn();
if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) {
auto rc =
shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape);
if (rc.first != MSRStatus::SUCCESS) {
RETURN_STATUS_UNEXPECTED("Invalid parameter, column_name: " + column_name + "does not exist in dataset.");
}
if (rc.second == mindrecord::ColumnInRaw) {
auto column_in_raw = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes);
if (column_in_raw == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve raw data from padding sample.");
}
} else if (rc.second == mindrecord::ColumnInBlob) {
if (sample_bytes_.find(column_name) == sample_bytes_.end()) {
RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve blob data from padding sample.");
}
mindrecord::ColumnCategory category;
RETURN_IF_NOT_OK(shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size,
&column_shape, &category));
if (category == mindrecord::ColumnInRaw) {
RETURN_IF_NOT_OK(shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes));
} else if (category == mindrecord::ColumnInBlob) {
CHECK_FAIL_RETURN_UNEXPECTED(sample_bytes_.find(column_name) != sample_bytes_.end(),
"Invalid data, failed to retrieve blob data from padding sample.");

std::string ss(sample_bytes_[column_name]); std::string ss(sample_bytes_[column_name]);
n_bytes = ss.size(); n_bytes = ss.size();
data_ptr = std::make_unique<unsigned char[]>(n_bytes); data_ptr = std::make_unique<unsigned char[]>(n_bytes);
@@ -262,12 +258,9 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
data = reinterpret_cast<const unsigned char *>(data_ptr.get()); data = reinterpret_cast<const unsigned char *>(data_ptr.get());
} }
} else { } else {
auto has_column =
shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes,
&column_data_type, &column_data_type_size, &column_shape);
if (has_column == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve data from mindrecord reader.");
}
RETURN_IF_NOT_OK(shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr,
&n_bytes, &column_data_type, &column_data_type_size,
&column_shape));
} }


std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
@@ -309,15 +302,10 @@ Status MindRecordOp::Reset() {
} }


Status MindRecordOp::LaunchThreadsAndInitOp() { Status MindRecordOp::LaunchThreadsAndInitOp() {
if (tree_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
}

RETURN_UNEXPECTED_IF_NULL(tree_);
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
if (shard_reader_->Launch(true) == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed.");
}
RETURN_IF_NOT_OK(shard_reader_->Launch(true));
// Launch main workers that load TensorRows by reading all images // Launch main workers that load TensorRows by reading all images
RETURN_IF_NOT_OK( RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id())); tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id()));
@@ -330,12 +318,7 @@ Status MindRecordOp::LaunchThreadsAndInitOp() {
Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset, Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded) { const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded) {
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded);
if (rc == MSRStatus::FAILED) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, MindRecordOp failed to count total rows. Check whether there are corresponding .db files "
"and the value of dataset_file parameter is given correctly.");
}
RETURN_IF_NOT_OK(shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded));
return Status::OK(); return Status::OK();
} }




+ 6
- 9
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc View File

@@ -38,9 +38,8 @@ Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::
uint64_t n_bytes = 0, col_type_size = 1; uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape; std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape));
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
@@ -57,9 +56,8 @@ Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, con
uint64_t n_bytes = 0, col_type_size = 1; uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape; std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape));
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor));
@@ -81,9 +79,8 @@ Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::v
uint64_t n_bytes = 0, col_type_size = 1; uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape; std::vector<int64_t> column_shape;
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
&col_type_size, &column_shape));


if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);




+ 3
- 5
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc View File

@@ -94,10 +94,9 @@ Status GraphLoader::InitAndLoad() {
TaskGroup vg; TaskGroup vg;


shard_reader_ = std::make_unique<ShardReader>(); shard_reader_ = std::make_unique<ShardReader>();
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
"Fail to open" + mr_path_);
RETURN_IF_NOT_OK(shard_reader_->Open({mr_path_}, true, num_workers_));
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
RETURN_IF_NOT_OK(shard_reader_->Launch(true));


graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
mindrecord::json schema = graph_impl_->data_schema_["schema"]; mindrecord::json schema = graph_impl_->data_schema_["schema"];
@@ -116,8 +115,7 @@ Status GraphLoader::InitAndLoad() {
if (graph_impl_->server_mode_) { if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
int64_t total_blob_size = 0; int64_t total_blob_size = 0;
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS,
"failed to get total blob size");
RETURN_IF_NOT_OK(shard_reader_->GetTotalBlobSize(&total_blob_size));
graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_); graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_);
RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory());
#endif #endif


+ 0
- 83
mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc View File

@@ -1,83 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "minddata/mindrecord/include/shard_error.h"

namespace mindspore {
namespace mindrecord {
// Lookup table from each MSRStatus code to its human-readable message.
// ErrnoToMessage() searches this table and falls back to a generic message
// for any status that has no entry here.
static const std::map<MSRStatus, std::string> kErrnoToMessage = {
{FAILED, "operator failed"},
{SUCCESS, "operator success"},
{OPEN_FILE_FAILED, "open file failed"},
{CLOSE_FILE_FAILED, "close file failed"},
{WRITE_METADATA_FAILED, "write metadata failed"},
{WRITE_RAWDATA_FAILED, "write rawdata failed"},
{GET_SCHEMA_FAILED, "get schema failed"},
{ILLEGAL_RAWDATA, "illegal raw data"},
{PYTHON_TO_JSON_FAILED, "pybind: python object to json failed"},
{DIR_CREATE_FAILED, "directory create failed"},
{OPEN_DIR_FAILED, "open directory failed"},
{INVALID_STATISTICS, "invalid statistics object"},
{OPEN_DATABASE_FAILED, "open database failed"},
{CLOSE_DATABASE_FAILED, "close database failed"},
{DATABASE_OPERATE_FAILED, "database operate failed"},
{BUILD_SCHEMA_FAILED, "build schema failed"},
{DIVISOR_IS_ILLEGAL, "divisor is illegal"},
{INVALID_FILE_PATH, "file path is invalid"},
{SECURE_FUNC_FAILED, "secure function failed"},
{ALLOCATE_MEM_FAILED, "allocate memory failed"},
{ILLEGAL_FIELD_NAME, "illegal field name"},
{ILLEGAL_FIELD_TYPE, "illegal field type"},
{SET_METADATA_FAILED, "set metadata failed"},
{ILLEGAL_SCHEMA_DEFINITION, "illegal schema definition"},
{ILLEGAL_COLUMN_LIST, "illegal column list"},
{SQL_ERROR, "sql error"},
{ILLEGAL_SHARD_COUNT, "illegal shard count"},
{ILLEGAL_SCHEMA_COUNT, "illegal schema count"},
{VERSION_ERROR, "data version is not matched"},
{ADD_SCHEMA_FAILED, "add schema failed"},
{ILLEGAL_Header_SIZE, "illegal header size"},
{ILLEGAL_Page_SIZE, "illegal page size"},
{ILLEGAL_SIZE_VALUE, "illegal size value"},
{INDEX_FIELD_ERROR, "add index fields failed"},
{GET_CANDIDATE_CATEGORYFIELDS_FAILED, "get candidate category fields failed"},
{GET_CATEGORY_INFO_FAILED, "get category information failed"},
{ILLEGAL_CATEGORY_ID, "illegal category id"},
{ILLEGAL_ROWNUMBER_OF_PAGE, "illegal row number of page"},
{ILLEGAL_SCHEMA_ID, "illegal schema id"},
{DESERIALIZE_SCHEMA_FAILED, "deserialize schema failed"},
{DESERIALIZE_STATISTICS_FAILED, "deserialize statistics failed"},
{ILLEGAL_DB_FILE, "illegal db file"},
{OVERWRITE_DB_FILE, "overwrite db file"},
{OVERWRITE_MINDRECORD_FILE, "overwrite mindrecord file"},
{ILLEGAL_MINDRECORD_FILE, "illegal mindrecord file"},
{PARSE_JSON_FAILED, "parse json failed"},
{ILLEGAL_PARAMETERS, "illegal parameters"},
{GET_PAGE_BY_GROUP_ID_FAILED, "get page by group id failed"},
{GET_SYSTEM_STATE_FAILED, "get system state failed"},
{IO_FAILED, "io operate failed"},
{MATCH_HEADER_FAILED, "match header failed"}};

// Returns the human-readable message registered for `status` in
// kErrnoToMessage, or "invalid error no" when the status has no entry.
std::string ErrnoToMessage(MSRStatus status) {
  // Reuse the iterator returned by find() rather than searching the map a
  // second time with at().
  const auto iter = kErrnoToMessage.find(status);
  if (iter == kErrnoToMessage.end()) {
    return "invalid error no";
  }
  return iter->second;
}
} // namespace mindrecord
} // namespace mindspore

+ 171
- 40
mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc View File

@@ -36,20 +36,42 @@ using mindspore::MsLogLevel::ERROR;


namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
// Evaluates the Status expression `s` once into a local and, if it reports an
// error, throws std::runtime_error carrying the status text (rc.ToString()).
// The do/while(false) wrapper makes the expansion behave as a single statement.
#define THROW_IF_ERROR(s) \
do { \
Status rc = std::move(s); \
if (rc.IsError()) throw std::runtime_error(rc.ToString()); \
} while (false)

void BindSchema(py::module *m) { void BindSchema(py::module *m) {
(void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local()) (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
.def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build)
.def_static("build",
[](const std::string &desc, const pybind11::handle &schema) {
json schema_json = nlohmann::detail::ToJsonImpl(schema);
return Schema::Build(std::move(desc), schema_json);
})
.def("get_desc", &Schema::GetDesc) .def("get_desc", &Schema::GetDesc)
.def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython)
.def("get_schema_content",
[](Schema &s) {
json schema_json = s.GetSchema();
return nlohmann::detail::FromJsonImpl(schema_json);
})
.def("get_blob_fields", &Schema::GetBlobFields) .def("get_blob_fields", &Schema::GetBlobFields)
.def("get_schema_id", &Schema::GetSchemaID); .def("get_schema_id", &Schema::GetSchemaID);
} }


void BindStatistics(const py::module *m) { void BindStatistics(const py::module *m) {
(void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local()) (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
.def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build)
.def_static("build",
[](const std::string desc, const pybind11::handle &statistics) {
json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
return Statistics::Build(std::move(desc), statistics_json);
})
.def("get_desc", &Statistics::GetDesc) .def("get_desc", &Statistics::GetDesc)
.def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython)
.def("get_statistics",
[](Statistics &s) {
json statistics_json = s.GetStatistics();
return nlohmann::detail::FromJsonImpl(statistics_json);
})
.def("get_statistics_id", &Statistics::GetStatisticsID); .def("get_statistics_id", &Statistics::GetStatisticsID);
} }


@@ -59,70 +81,179 @@ void BindShardHeader(const py::module *m) {
.def("add_schema", &ShardHeader::AddSchema) .def("add_schema", &ShardHeader::AddSchema)
.def("add_statistics", &ShardHeader::AddStatistic) .def("add_statistics", &ShardHeader::AddStatistic)
.def("add_index_fields", .def("add_index_fields",
(MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields)
[](ShardHeader &s, const std::vector<std::string> &fields) {
THROW_IF_ERROR(s.AddIndexFields(fields));
return SUCCESS;
})
.def("get_meta", &ShardHeader::GetSchemas) .def("get_meta", &ShardHeader::GetSchemas)
.def("get_statistics", &ShardHeader::GetStatistics) .def("get_statistics", &ShardHeader::GetStatistics)
.def("get_fields", &ShardHeader::GetFields) .def("get_fields", &ShardHeader::GetFields)
.def("get_schema_by_id", &ShardHeader::GetSchemaByID)
.def("get_statistic_by_id", &ShardHeader::GetStatisticByID);
.def("get_schema_by_id",
[](ShardHeader &s, int64_t schema_id) {
std::shared_ptr<Schema> schema_ptr;
THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr));
return schema_ptr;
})
.def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) {
std::shared_ptr<Statistics> statistics_ptr;
THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr));
return statistics_ptr;
});
} }


void BindShardWriter(py::module *m) { void BindShardWriter(py::module *m) {
(void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local()) (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local())
.def(py::init<>()) .def(py::init<>())
.def("open", &ShardWriter::Open)
.def("open_for_append", &ShardWriter::OpenForAppend)
.def("set_header_size", &ShardWriter::SetHeaderSize)
.def("set_page_size", &ShardWriter::SetPageSize)
.def("set_shard_header", &ShardWriter::SetShardHeader)
.def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
vector<vector<uint8_t>> &, bool, bool)) &
ShardWriter::WriteRawData)
.def("commit", &ShardWriter::Commit);
.def("open",
[](ShardWriter &s, const std::vector<std::string> &paths, bool append) {
THROW_IF_ERROR(s.Open(paths, append));
return SUCCESS;
})
.def("open_for_append",
[](ShardWriter &s, const std::string &path) {
THROW_IF_ERROR(s.OpenForAppend(path));
return SUCCESS;
})
.def("set_header_size",
[](ShardWriter &s, const uint64_t &header_size) {
THROW_IF_ERROR(s.SetHeaderSize(header_size));
return SUCCESS;
})
.def("set_page_size",
[](ShardWriter &s, const uint64_t &page_size) {
THROW_IF_ERROR(s.SetPageSize(page_size));
return SUCCESS;
})
.def("set_shard_header",
[](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) {
THROW_IF_ERROR(s.SetShardHeader(header_data));
return SUCCESS;
})
.def("write_raw_data",
[](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, bool parallel_writer) {
std::map<uint64_t, std::vector<json>> raw_data_json;
(void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
[](const std::pair<uint64_t, std::vector<py::handle>> &p) {
auto &py_raw_data = p.second;
std::vector<json> json_raw_data;
(void)std::transform(
py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
[](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
return std::make_pair(p.first, std::move(json_raw_data));
});
THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer));
return SUCCESS;
})
.def("commit", [](ShardWriter &s) {
THROW_IF_ERROR(s.Commit());
return SUCCESS;
});
} }


/// \brief Register the Python binding of ShardReader on the given module.
///
/// Calls whose result is passed to THROW_IF_ERROR (defined elsewhere in this file; presumably it raises
/// a C++ exception that pybind11 surfaces in Python - confirm against its definition) return SUCCESS to
/// Python when no error was reported.
/// \param[in] m module that receives the "ShardReader" class
void BindShardReader(const py::module *m) {
  (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
    .def(py::init<>())
    .def("open",
         [](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
            const std::vector<std::string> &selected_columns,
            const std::vector<std::shared_ptr<ShardOperator>> &operators) {
           THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
           return SUCCESS;
         })
    .def("launch",
         [](ShardReader &s) {
           THROW_IF_ERROR(s.Launch(false));
           return SUCCESS;
         })
    .def("get_header", &ShardReader::GetShardHeader)
    .def("get_blob_fields", &ShardReader::GetBlobFields)
    .def("get_next",
         [](ShardReader &s) {
           // Each row from GetNext() is a (raw blob, json) tuple; it is handed to Python as
           // (list of byte vectors produced by UnCompressBlob, Python object built from the json).
           auto data = s.GetNext();
           std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> res;
           std::transform(data.begin(), data.end(), std::back_inserter(res),
                          [&s](const std::tuple<std::vector<uint8_t>, json> &item) {
                            auto &j = std::get<1>(item);
                            pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
                            auto blob_data_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
                            // Report a failed decompression instead of silently returning an empty blob list.
                            THROW_IF_ERROR(s.UnCompressBlob(std::get<0>(item), &blob_data_ptr));
                            return std::make_tuple(*blob_data_ptr, std::move(obj));
                          });
           return res;
         })
    .def("close", &ShardReader::Close);
}


/// \brief Register the Python binding of ShardIndexGenerator on the given module.
///
/// Build() and WriteToDatabase() are wrapped in lambdas instead of being bound as member pointers:
/// their result is passed to THROW_IF_ERROR (defined elsewhere in this file) and Python receives
/// SUCCESS when no error was reported.
/// \param[in] m module that receives the "ShardIndexGenerator" class
void BindShardIndexGenerator(const py::module *m) {
  (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local())
    .def(py::init<const std::string &, bool>())
    .def("build",
         [](ShardIndexGenerator &s) {
           THROW_IF_ERROR(s.Build());
           return SUCCESS;
         })
    .def("write_to_db", [](ShardIndexGenerator &s) {
      THROW_IF_ERROR(s.WriteToDatabase());
      return SUCCESS;
    });
}


/// \brief Register the Python binding of ShardSegment on the given module.
///
/// Calls whose result is passed to THROW_IF_ERROR (defined elsewhere in this file) hand their payload,
/// or SUCCESS when there is none, back to Python only when no error was reported.
/// \param[in] m module that receives the "ShardSegment" class
void BindShardSegment(py::module *m) {
  (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
    .def(py::init<>())
    .def("open",
         [](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
            const std::vector<std::string> &selected_columns,
            const std::vector<std::shared_ptr<ShardOperator>> &operators) {
           THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators));
           return SUCCESS;
         })
    .def("get_category_fields",
         [](ShardSegment &s) {
           auto fields_ptr = std::make_shared<std::vector<std::string>>();
           THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr));
           return *fields_ptr;
         })
    .def("set_category_field",
         [](ShardSegment &s, const std::string &category_field) {
           THROW_IF_ERROR(s.SetCategoryField(category_field));
           return SUCCESS;
         })
    .def("read_category_info",
         [](ShardSegment &s) {
           // Start from an allocated string, as get_category_fields does, so the dereference below
           // stays valid even if the callee leaves the pointer untouched.
           auto category_ptr = std::make_shared<std::string>();
           THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr));
           return *category_ptr;
         })
    .def("read_at_page_by_id",
         [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
           auto pages_ptr = std::make_shared<PAGES>();
           THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr));
           // Re-pack every (blob, json) row as (blob, Python object).
           PAGES_LOAD pages_load;
           (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(pages_load),
                                [](const std::tuple<std::vector<uint8_t>, json> &item) {
                                  auto &j = std::get<1>(item);
                                  pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
                                  return std::make_tuple(std::get<0>(item), std::move(obj));
                                });
           return pages_load;
         })
    .def("read_at_page_by_name",
         [](ShardSegment &s, const std::string &category_name, int64_t page_no, int64_t n_rows_of_page) {
           auto pages_ptr = std::make_shared<PAGES>();
           THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr));
           // Re-pack every (blob, json) row as (blob, Python object).
           PAGES_LOAD pages_load;
           (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(pages_load),
                                [](const std::tuple<std::vector<uint8_t>, json> &item) {
                                  auto &j = std::get<1>(item);
                                  pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
                                  return std::make_tuple(std::get<0>(item), std::move(obj));
                                });
           return pages_load;
         })
    .def("get_header", &ShardSegment::GetShardHeader)
    .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); });
}


void BindGlobalParams(py::module *m) { void BindGlobalParams(py::module *m) {


+ 32
- 33
mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc View File

@@ -57,26 +57,24 @@ bool ValidateFieldName(const std::string &str) {
return true; return true;
} }


std::pair<MSRStatus, std::string> GetFileName(const std::string &path) {
Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr) {
RETURN_UNEXPECTED_IF_NULL(fn_ptr);
char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
char buf[PATH_MAX] = {0}; char buf[PATH_MAX] = {0};
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path);
} }
char tmp[PATH_MAX] = {0}; char tmp[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf));
} }
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully";
} }
#else #else
if (realpath(dirname(&(buf[0])), tmp) == nullptr) { if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf);
} }
if (realpath(common::SafeCStr(path), real_path) == nullptr) { if (realpath(common::SafeCStr(path), real_path) == nullptr) {
MS_LOG(DEBUG) << "Path: " << path << "check successfully"; MS_LOG(DEBUG) << "Path: " << path << "check successfully";
@@ -87,32 +85,32 @@ std::pair<MSRStatus, std::string> GetFileName(const std::string &path) {
size_t i = s.rfind(sep, s.length()); size_t i = s.rfind(sep, s.length());
if (i != std::string::npos) { if (i != std::string::npos) {
if (i + 1 < s.size()) { if (i + 1 < s.size()) {
return {SUCCESS, s.substr(i + 1)};
*fn_ptr = std::make_shared<std::string>(s.substr(i + 1));
return Status::OK();
} }
} }
return {SUCCESS, s};
*fn_ptr = std::make_shared<std::string>(s);
return Status::OK();
} }


std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) {
Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr) {
RETURN_UNEXPECTED_IF_NULL(pd_ptr);
char real_path[PATH_MAX] = {0}; char real_path[PATH_MAX] = {0};
char buf[PATH_MAX] = {0}; char buf[PATH_MAX] = {0};
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path);
} }
char tmp[PATH_MAX] = {0}; char tmp[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf));
} }
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully";
} }
#else #else
if (realpath(dirname(&(buf[0])), tmp) == nullptr) { if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
MS_LOG(ERROR) << "Invalid file path, path: " << buf;
return {FAILED, ""};
RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf);
} }
if (realpath(common::SafeCStr(path), real_path) == nullptr) { if (realpath(common::SafeCStr(path), real_path) == nullptr) {
MS_LOG(DEBUG) << "Path: " << path << "check successfully"; MS_LOG(DEBUG) << "Path: " << path << "check successfully";
@@ -120,9 +118,11 @@ std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) {
#endif #endif
std::string s = real_path; std::string s = real_path;
if (s.rfind('/') + 1 <= s.size()) { if (s.rfind('/') + 1 <= s.size()) {
return {SUCCESS, s.substr(0, s.rfind('/') + 1)};
*pd_ptr = std::make_shared<std::string>(s.substr(0, s.rfind('/') + 1));
return Status::OK();
} }
return {SUCCESS, "/"};
*pd_ptr = std::make_shared<std::string>("/");
return Status::OK();
} }


bool CheckIsValidUtf8(const std::string &str) { bool CheckIsValidUtf8(const std::string &str) {
@@ -163,15 +163,16 @@ bool IsLegalFile(const std::string &path) {
return false; return false;
} }


std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) {
Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size_ptr) {
RETURN_UNEXPECTED_IF_NULL(size_ptr);
#if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__)
return {SUCCESS, 100};
*size_ptr = std::make_shared<uint64_t>(100);
return Status::OK();
#else #else
uint64_t ll_count = 0; uint64_t ll_count = 0;
struct statfs disk_info; struct statfs disk_info;
if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) {
MS_LOG(ERROR) << "Get disk size error";
return {FAILED, 0};
RETURN_STATUS_UNEXPECTED("Get disk size error.");
} }


switch (disk_type) { switch (disk_type) {
@@ -187,8 +188,8 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis
ll_count = 0; ll_count = 0;
break; break;
} }
return {SUCCESS, ll_count};
*size_ptr = std::make_shared<uint64_t>(ll_count);
return Status::OK();
#endif #endif
} }


@@ -201,17 +202,15 @@ uint32_t GetMaxThreadNum() {
return thread_num; return thread_num;
} }


std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses) {
auto ret = GetParentDir(path);
if (SUCCESS != ret.first) {
return {FAILED, {}};
}
std::vector<std::string> abs_addresses;
Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds) {
RETURN_UNEXPECTED_IF_NULL(ds);
std::shared_ptr<std::string> parent_dir;
RETURN_IF_NOT_OK(GetParentDir(path, &parent_dir));
for (const auto &p : addresses) { for (const auto &p : addresses) {
std::string abs_path = ret.second + std::string(p);
abs_addresses.emplace_back(abs_path);
std::string abs_path = *parent_dir + std::string(p);
(*ds)->emplace_back(abs_path);
} }
return {SUCCESS, abs_addresses};
return Status::OK();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 13
- 8
mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h View File

@@ -33,6 +33,7 @@
#include <future> #include <future>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <memory>
#include <random> #include <random>
#include <set> #include <set>
#include <sstream> #include <sstream>
@@ -159,13 +160,15 @@ bool ValidateFieldName(const std::string &str);


/// \brief get the filename by the path /// \brief get the filename by the path
/// \param path file path /// \param path file path
/// \return
std::pair<MSRStatus, std::string> GetFileName(const std::string &s);
/// \param fn_ptr shared ptr of file name
/// \return Status
Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr);


/// \brief get parent dir /// \brief get parent dir
/// \param path file path /// \param path file path
/// \return parent path
std::pair<MSRStatus, std::string> GetParentDir(const std::string &path);
/// \param pd_ptr shared ptr of parent path
/// \return Status
Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr);


bool CheckIsValidUtf8(const std::string &str); bool CheckIsValidUtf8(const std::string &str);


@@ -179,8 +182,9 @@ enum DiskSizeType { kTotalSize = 0, kFreeSize };
/// \brief get the free space about the disk /// \brief get the free space about the disk
/// \param str_dir file path /// \param str_dir file path
/// \param disk_type: kTotalSize / kFreeSize /// \param disk_type: kTotalSize / kFreeSize
/// \return size in Megabytes
std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type);
/// \param size: shared ptr of size in Megabytes
/// \return Status
Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size);


/// \brief get the max hardware concurrency /// \brief get the max hardware concurrency
/// \return max concurrency /// \return max concurrency
@@ -189,8 +193,9 @@ uint32_t GetMaxThreadNum();
/// \brief get absolute path of all mindrecord files /// \brief get absolute path of all mindrecord files
/// \param path path to one of the mindrecord files /// \param path path to one of the mindrecord files
/// \param addresses relative path of all mindrecord files /// \param addresses relative path of all mindrecord files
/// \return vector of absolute path
std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses);
/// \param ds shared ptr of vector of absolute path
/// \return Status
Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds);
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore




+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_category.h View File

@@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato


bool GetReplacement() const { return replacement_; } bool GetReplacement() const { return replacement_; }


MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;


int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;




+ 19
- 20
mindspore/ccsrc/minddata/mindrecord/include/shard_column.h View File

@@ -65,11 +65,11 @@ class __attribute__((visibility("default"))) ShardColumn {
~ShardColumn() = default; ~ShardColumn() = default;


/// \brief get column value by column name /// \brief get column value by column name
MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
Status GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);


/// \brief compress blob /// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size); std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size);
@@ -90,19 +90,18 @@ class __attribute__((visibility("default"))) ShardColumn {
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }


/// \brief get column value from blob /// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes);
Status GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes);


/// \brief get column type /// \brief get column type
std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name,
ColumnDataType *column_data_type,
uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
Status GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type,
uint64_t *column_data_type_size, std::vector<int64_t> *column_shape,
ColumnCategory *column_category);


/// \brief get column value from json /// \brief get column value from json
MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
Status GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);


private: private:
/// \brief initialization /// \brief initialization
@@ -110,15 +109,15 @@ class __attribute__((visibility("default"))) ShardColumn {


/// \brief get float value from json /// \brief get float value from json
template <typename T> template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
Status GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);


/// \brief get integer value from json /// \brief get integer value from json
template <typename T> template <typename T>
MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);
Status GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);


/// \brief get column offset address and size from blob /// \brief get column offset address and size from blob
MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);
Status GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);


/// \brief check if column name is available /// \brief check if column name is available
ColumnCategory CheckColumnName(const std::string &column_name); ColumnCategory CheckColumnName(const std::string &column_name);
@@ -128,8 +127,8 @@ class __attribute__((visibility("default"))) ShardColumn {


/// \brief uncompress integer array column /// \brief uncompress integer array column
template <typename T> template <typename T>
static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);
static Status UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);


/// \brief convert big-endian bytes to unsigned int /// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array /// \param bytes_array bytes array


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h View File

@@ -39,7 +39,7 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha


~ShardDistributedSample() override{}; ~ShardDistributedSample() override{};


MSRStatus PreExecute(ShardTaskList &tasks) override;
Status PreExecute(ShardTaskList &tasks) override;


int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;




+ 41
- 51
mindspore/ccsrc/minddata/mindrecord/include/shard_error.h View File

@@ -19,65 +19,55 @@


#include <map> #include <map>
#include <string> #include <string>
#include "include/api/status.h"


namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
// Evaluates the Status expression `_s` once. If it reports an error, that Status is returned from the
// enclosing function; otherwise execution continues after the macro.
#define RETURN_IF_NOT_OK(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
return __rc; \
} \
} while (false)

// Like RETURN_IF_NOT_OK, but releases resources before the early return: on error, `_db` is closed with
// sqlite3_close when it is not null and `(_in).close()` is called, then the error Status is returned.
// `_in` can be any object exposing close() - presumably the input file stream in use; confirm at call sites.
#define RELEASE_AND_RETURN_IF_NOT_OK(_s, _db, _in) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
if ((_db) != nullptr) { \
sqlite3_close(_db); \
} \
(_in).close(); \
return __rc; \
} \
} while (false)

// Returns a kMDUnexpectedError Status carrying the message `_e` and the current line/file from the
// enclosing function when `_condition` evaluates to false; does nothing when it is true.
#define CHECK_FAIL_RETURN_UNEXPECTED(_condition, _e) \
do { \
if (!(_condition)) { \
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \
} \
} while (false)

// Returns an unexpected-error Status from the enclosing function when `_ptr` is null. The message embeds
// the spelled-out argument expression: "The pointer[<expr>] is null.".
// Expands to RETURN_STATUS_UNEXPECTED, so that macro must be defined before the first use of this one.
#define RETURN_UNEXPECTED_IF_NULL(_ptr) \
do { \
if ((_ptr) == nullptr) { \
std::string err_msg = "The pointer[" + std::string(#_ptr) + "] is null."; \
RETURN_STATUS_UNEXPECTED(err_msg); \
} \
} while (false)

// Unconditionally returns a kMDUnexpectedError Status carrying the message `_e` and the current
// line/file from the enclosing function.
#define RETURN_STATUS_UNEXPECTED(_e) \
do { \
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \
} while (false)

enum MSRStatus { enum MSRStatus {
SUCCESS = 0, SUCCESS = 0,
FAILED = 1, FAILED = 1,
OPEN_FILE_FAILED,
CLOSE_FILE_FAILED,
WRITE_METADATA_FAILED,
WRITE_RAWDATA_FAILED,
GET_SCHEMA_FAILED,
ILLEGAL_RAWDATA,
PYTHON_TO_JSON_FAILED,
DIR_CREATE_FAILED,
OPEN_DIR_FAILED,
INVALID_STATISTICS,
OPEN_DATABASE_FAILED,
CLOSE_DATABASE_FAILED,
DATABASE_OPERATE_FAILED,
BUILD_SCHEMA_FAILED,
DIVISOR_IS_ILLEGAL,
INVALID_FILE_PATH,
SECURE_FUNC_FAILED,
ALLOCATE_MEM_FAILED,
ILLEGAL_FIELD_NAME,
ILLEGAL_FIELD_TYPE,
SET_METADATA_FAILED,
ILLEGAL_SCHEMA_DEFINITION,
ILLEGAL_COLUMN_LIST,
SQL_ERROR,
ILLEGAL_SHARD_COUNT,
ILLEGAL_SCHEMA_COUNT,
VERSION_ERROR,
ADD_SCHEMA_FAILED,
ILLEGAL_Header_SIZE,
ILLEGAL_Page_SIZE,
ILLEGAL_SIZE_VALUE,
INDEX_FIELD_ERROR,
GET_CANDIDATE_CATEGORYFIELDS_FAILED,
GET_CATEGORY_INFO_FAILED,
ILLEGAL_CATEGORY_ID,
ILLEGAL_ROWNUMBER_OF_PAGE,
ILLEGAL_SCHEMA_ID,
DESERIALIZE_SCHEMA_FAILED,
DESERIALIZE_STATISTICS_FAILED,
ILLEGAL_DB_FILE,
OVERWRITE_DB_FILE,
OVERWRITE_MINDRECORD_FILE,
ILLEGAL_MINDRECORD_FILE,
PARSE_JSON_FAILED,
ILLEGAL_PARAMETERS,
GET_PAGE_BY_GROUP_ID_FAILED,
GET_SYSTEM_STATE_FAILED,
IO_FAILED,
MATCH_HEADER_FAILED
}; };


// convert error no to string message
std::string __attribute__((visibility("default"))) ErrnoToMessage(MSRStatus status);
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore




+ 33
- 33
mindspore/ccsrc/minddata/mindrecord/include/shard_header.h View File

@@ -37,9 +37,9 @@ class __attribute__((visibility("default"))) ShardHeader {


~ShardHeader() = default; ~ShardHeader() = default;


MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);


static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path);
static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr);
/// \brief add the schema and save it /// \brief add the schema and save it
/// \param[in] schema the schema needs to be added /// \param[in] schema the schema needs to be added
/// \return the last schema's id /// \return the last schema's id
@@ -53,9 +53,9 @@ class __attribute__((visibility("default"))) ShardHeader {
/// \brief create index and add fields which from schema for each schema /// \brief create index and add fields which from schema for each schema
/// \param[in] fields the index fields needs to be added /// \param[in] fields the index fields needs to be added
/// \return Status, OK if the fields were added successfully and an error otherwise /// \return Status, OK if the fields were added successfully and an error otherwise
MSRStatus AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields);
Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields);


MSRStatus AddIndexFields(const std::vector<std::string> &fields);
Status AddIndexFields(const std::vector<std::string> &fields);


/// \brief get the schema /// \brief get the schema
/// \return the schema /// \return the schema
@@ -79,9 +79,10 @@ class __attribute__((visibility("default"))) ShardHeader {
std::shared_ptr<Index> GetIndex(); std::shared_ptr<Index> GetIndex();


/// \brief get the schema by schemaid /// \brief get the schema by schemaid
/// \param[in] schemaId the id of schema needs to be got
/// \return the schema obtained by schemaId
std::pair<std::shared_ptr<Schema>, MSRStatus> GetSchemaByID(int64_t schema_id);
/// \param[in] schema_id the id of schema needs to be got
/// \param[out] schema_ptr the schema obtained by schema_id
/// \return Status
Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr);


/// \brief get the filepath to shard by shardID /// \brief get the filepath to shard by shardID
/// \param[in] shardID the id of shard which filepath needs to be obtained /// \param[in] shardID the id of shard which filepath needs to be obtained
@@ -89,25 +90,26 @@ class __attribute__((visibility("default"))) ShardHeader {
std::string GetShardAddressByID(int64_t shard_id); std::string GetShardAddressByID(int64_t shard_id);


/// \brief get the statistic by statistic id /// \brief get the statistic by statistic id
/// \param[in] statisticId the id of statistic needs to be get
/// \return the statistics obtained by statistic id
std::pair<std::shared_ptr<Statistics>, MSRStatus> GetStatisticByID(int64_t statistic_id);
/// \param[in] statistic_id the id of statistic needs to be get
/// \param[out] statistics_ptr the statistics obtained by statistic id
/// \return Status
Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr);


MSRStatus InitByFiles(const std::vector<std::string> &file_paths);
Status InitByFiles(const std::vector<std::string> &file_paths);


void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } void SetIndex(Index index) { index_ = std::make_shared<Index>(index); }


std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id);
Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr);


MSRStatus SetPage(const std::shared_ptr<Page> &new_page);
Status SetPage(const std::shared_ptr<Page> &new_page);


MSRStatus AddPage(const std::shared_ptr<Page> &new_page);
Status AddPage(const std::shared_ptr<Page> &new_page);


int64_t GetLastPageId(const int &shard_id); int64_t GetLastPageId(const int &shard_id);


int GetLastPageIdByType(const int &shard_id, const std::string &page_type); int GetLastPageIdByType(const int &shard_id, const std::string &page_type);


const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id);
Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr);


std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } std::vector<std::string> GetShardAddresses() const { return shard_addresses_; }


@@ -129,43 +131,41 @@ class __attribute__((visibility("default"))) ShardHeader {


std::vector<std::string> SerializeHeader(); std::vector<std::string> SerializeHeader();


MSRStatus PagesToFile(const std::string dump_file_name);
Status PagesToFile(const std::string dump_file_name);


MSRStatus FileToPages(const std::string dump_file_name);
Status FileToPages(const std::string dump_file_name);


static MSRStatus Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id);
static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id);


private: private:
MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
Status InitializeHeader(const std::vector<json> &headers, bool load_dataset);


/// \brief get the headers from all the shard data /// \brief get the headers from all the shard data
/// \param[in] real_addresses the real paths of the shard data /// \param[in] real_addresses the real paths of the shard data
/// \param[out] headers the headers read from the shard data /// \param[out] headers the headers read from the shard data
/// \return Status /// \return Status
MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);
Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);


MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);


/// \brief check the binary file status /// \brief check the binary file status
static MSRStatus CheckFileStatus(const std::string &path);
static Status CheckFileStatus(const std::string &path);


static std::pair<MSRStatus, json> ValidateHeader(const std::string &path);

void ParseHeader(const json &header);
static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr);


void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses); void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses);


MSRStatus ParseIndexFields(const json &index_fields);
Status ParseIndexFields(const json &index_fields);


MSRStatus CheckIndexField(const std::string &field, const json &schema);
Status CheckIndexField(const std::string &field, const json &schema);


MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset);
Status ParsePage(const json &page, int shard_index, bool load_dataset);


MSRStatus ParseStatistics(const json &statistics);
Status ParseStatistics(const json &statistics);


MSRStatus ParseSchema(const json &schema);
Status ParseSchema(const json &schema);


void ParseShardAddress(const json &address); void ParseShardAddress(const json &address);


@@ -181,7 +181,7 @@ class __attribute__((visibility("default"))) ShardHeader {


std::shared_ptr<Index> InitIndexPtr(); std::shared_ptr<Index> InitIndexPtr();


MSRStatus GetAllSchemaID(std::set<uint64_t> &bucket_count);
Status GetAllSchemaID(std::set<uint64_t> &bucket_count);


uint32_t shard_count_; uint32_t shard_count_;
uint64_t header_size_; uint64_t header_size_;


+ 28
- 28
mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h View File

@@ -30,23 +30,24 @@


namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
using INDEX_FIELDS = std::pair<MSRStatus, std::vector<std::tuple<std::string, std::string, std::string>>>;
using ROW_DATA = std::pair<MSRStatus, std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>>;
using INDEX_FIELDS = std::vector<std::tuple<std::string, std::string, std::string>>;
using ROW_DATA = std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>;
class __attribute__((visibility("default"))) ShardIndexGenerator { class __attribute__((visibility("default"))) ShardIndexGenerator {
public: public:
explicit ShardIndexGenerator(const std::string &file_path, bool append = false); explicit ShardIndexGenerator(const std::string &file_path, bool append = false);


MSRStatus Build();
Status Build();


static std::pair<MSRStatus, std::string> GenerateFieldName(const std::pair<uint64_t, std::string> &field);
static Status GenerateFieldName(const std::pair<uint64_t, std::string> &field, std::shared_ptr<std::string> *fn_ptr);


~ShardIndexGenerator() {} ~ShardIndexGenerator() {}


/// \brief fetch value in json by field name /// \brief fetch value in json by field name
/// \param[in] field /// \param[in] field
/// \param[in] input /// \param[in] input
/// \return pair<MSRStatus, value>
std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input);
/// \param[out] value
/// \return Status
Status GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value);


/// \brief fetch field type in schema n by field path /// \brief fetch field type in schema n by field path
/// \param[in] field_path /// \param[in] field_path
@@ -55,55 +56,54 @@ class __attribute__((visibility("default"))) ShardIndexGenerator {
static std::string TakeFieldType(const std::string &field_path, json schema); static std::string TakeFieldType(const std::string &field_path, json schema);


/// \brief create databases for indexes /// \brief create databases for indexes
MSRStatus WriteToDatabase();
Status WriteToDatabase();


static MSRStatus Finalize(const std::vector<std::string> file_names);
static Status Finalize(const std::vector<std::string> file_names);


private: private:
static int Callback(void *not_used, int argc, char **argv, char **az_col_name); static int Callback(void *not_used, int argc, char **argv, char **az_col_name);


static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = "");
static Status ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = "");


static std::string ConvertJsonToSQL(const std::string &json); static std::string ConvertJsonToSQL(const std::string &json);


std::pair<MSRStatus, sqlite3 *> CreateDatabase(int shard_no);
Status CreateDatabase(int shard_no, sqlite3 **db);


std::pair<MSRStatus, std::vector<json>> GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in);
Status GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
std::shared_ptr<std::vector<json>> *detail_ptr);


static std::pair<MSRStatus, std::string> GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields);
static Status GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
std::shared_ptr<std::string> *sql_ptr);


std::pair<MSRStatus, sqlite3 *> CheckDatabase(const std::string &shard_address);
Status CheckDatabase(const std::string &shard_address, sqlite3 **db);


/// ///
/// \param shard_no /// \param shard_no
/// \param blob_id_to_page_id /// \param blob_id_to_page_id
/// \param raw_page_id /// \param raw_page_id
/// \param in /// \param in
/// \return field name, db type, field value
ROW_DATA GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
std::fstream &in);
/// \return Status
Status GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, std::fstream &in,
std::shared_ptr<ROW_DATA> *row_data_ptr);
/// ///
/// \param db /// \param db
/// \param sql /// \param sql
/// \param data /// \param data
/// \return /// \return
MSRStatus BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data);
Status BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data);


INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail);
Status GenerateIndexFields(const std::vector<json> &schema_detail, std::shared_ptr<INDEX_FIELDS> *index_fields_ptr);


MSRStatus ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
Status ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id);


MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name);
Status CreateShardNameTable(sqlite3 *db, const std::string &shard_name);


MSRStatus AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
std::fstream &in);
Status AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, std::fstream &in);


void AddIndexFieldByRawData(const std::vector<json> &schema_detail,
std::vector<std::tuple<std::string, std::string, std::string>> &row_data);
Status AddIndexFieldByRawData(const std::vector<json> &schema_detail,
std::vector<std::tuple<std::string, std::string, std::string>> &row_data);


void DatabaseWriter(); // worker thread void DatabaseWriter(); // worker thread




+ 16
- 21
mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h View File

@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */

#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_


@@ -28,33 +27,29 @@ class __attribute__((visibility("default"))) ShardOperator {
public: public:
virtual ~ShardOperator() = default; virtual ~ShardOperator() = default;


MSRStatus operator()(ShardTaskList &tasks) {
if (SUCCESS != this->PreExecute(tasks)) {
return FAILED;
}
if (SUCCESS != this->Execute(tasks)) {
return FAILED;
}
if (SUCCESS != this->SufExecute(tasks)) {
return FAILED;
}
return SUCCESS;
Status operator()(ShardTaskList &tasks) {
RETURN_IF_NOT_OK(this->PreExecute(tasks));
RETURN_IF_NOT_OK(this->Execute(tasks));
RETURN_IF_NOT_OK(this->SufExecute(tasks));
return Status::OK();
} }


virtual bool HasChildOp() { return child_op_ != nullptr; } virtual bool HasChildOp() { return child_op_ != nullptr; }


virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) {
if (child_op != nullptr) child_op_ = child_op;
return SUCCESS;
virtual Status SetChildOp(std::shared_ptr<ShardOperator> child_op) {
if (child_op != nullptr) {
child_op_ = child_op;
}
return Status::OK();
} }


virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; }


virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; }
virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); }


virtual MSRStatus Execute(ShardTaskList &tasks) = 0;
virtual Status Execute(ShardTaskList &tasks) = 0;


virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; }
virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); }


virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }


@@ -72,9 +67,9 @@ class __attribute__((visibility("default"))) ShardOperator {
std::shared_ptr<ShardOperator> child_op_ = nullptr; std::shared_ptr<ShardOperator> child_op_ = nullptr;


// indicate shard_id : inc_count // indicate shard_id : inc_count
// 0 : 15 - shard0 has 15 samples
// 1 : 41 - shard1 has 26 samples
// 2 : 58 - shard2 has 17 samples
// // 0 : 15 - shard0 has 15 samples
// // 1 : 41 - shard1 has 26 samples
// // 2 : 58 - shard2 has 17 samples
std::vector<uint32_t> shard_sample_count_; std::vector<uint32_t> shard_sample_count_;


dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal; dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal;


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h View File

@@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor


~ShardPkSample() override{}; ~ShardPkSample() override{};


MSRStatus SufExecute(ShardTaskList &tasks) override;
Status SufExecute(ShardTaskList &tasks) override;


int64_t GetNumSamples() const { return num_samples_; } int64_t GetNumSamples() const { return num_samples_; }




+ 62
- 80
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h View File

@@ -59,12 +59,9 @@


namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
using ROW_GROUPS =
std::tuple<MSRStatus, std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>;
using ROW_GROUP_BRIEF =
std::tuple<MSRStatus, std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>;
using TASK_RETURN_CONTENT =
std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>;
using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>;
using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>;
using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>;
const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode


class API_PUBLIC ShardReader { class API_PUBLIC ShardReader {
@@ -82,21 +79,10 @@ class API_PUBLIC ShardReader {
/// \param[in] num_padded the number of padded samples /// \param[in] num_padded the number of padded samples
/// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization /// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
bool lazy_load = false);

/// \brief open files and initialize reader, python API
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
/// \param[in] load_dataset load dataset from single file or not
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \return MSRStatus the status of MSRStatus
MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
bool lazy_load = false);


/// \brief close reader /// \brief close reader
/// \return null /// \return null
@@ -104,16 +90,16 @@ class API_PUBLIC ShardReader {


/// \brief read the file, get schema meta,statistics and index, single-thread mode /// \brief read the file, get schema meta,statistics and index, single-thread mode
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus Open();
Status Open();


/// \brief read the file, get schema meta,statistics and index, multiple-thread mode /// \brief read the file, get schema meta,statistics and index, multiple-thread mode
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus Open(int n_consumer);
Status Open(int n_consumer);


/// \brief launch threads to get batches /// \brief launch threads to get batches
/// \param[in] is_simple_reader trigger threads if false; do nothing if true /// \param[in] is_simple_reader trigger threads if false; do nothing if true
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus Launch(bool is_simple_reader = false);
Status Launch(bool is_simple_reader = false);


/// \brief aim to get the meta data /// \brief aim to get the meta data
/// \return the metadata /// \return the metadata
@@ -133,8 +119,8 @@ class API_PUBLIC ShardReader {
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object /// \param[in] op smart pointer refer to ShardCategory or ShardSample object
/// \param[out] count # of rows /// \param[out] count # of rows
/// \return MSRStatus the status of MSRStatus /// \return MSRStatus the status of MSRStatus
MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded);
Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded);


/// \brief shuffle task with incremental seed /// \brief shuffle task with incremental seed
/// \return void /// \return void
@@ -162,8 +148,8 @@ class API_PUBLIC ShardReader {
/// 3. Offset address of row group in file /// 3. Offset address of row group in file
/// 4. The list of image offset in page [startOffset, endOffset) /// 4. The list of image offset in page [startOffset, endOffset)
/// 5. The list of columns data /// 5. The list of columns data
ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id,
const std::vector<std::string> &columns = std::vector<std::string>());
Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);


/// \brief Read 1 row group data, excluding images, following an index field criteria /// \brief Read 1 row group data, excluding images, following an index field criteria
/// \param[in] groupID row group ID /// \param[in] groupID row group ID
@@ -176,8 +162,9 @@ class API_PUBLIC ShardReader {
/// 3. Offset address of row group in file /// 3. Offset address of row group in file
/// 4. The list of image offset in page [startOffset, endOffset) /// 4. The list of image offset in page [startOffset, endOffset)
/// 5. The list of columns data /// 5. The list of columns data
ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns = std::vector<std::string>());
Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns,
std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr);


/// \brief return a batch, given that one is ready /// \brief return a batch, given that one is ready
/// \return a batch of images and image data /// \return a batch of images and image data
@@ -185,13 +172,7 @@ class API_PUBLIC ShardReader {


/// \brief return a row by id /// \brief return a row by id
/// \return a batch of images and image data /// \return a batch of images and image data
std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>> GetNextById(const int64_t &task_id,
const int32_t &consumer_id);

/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy();

TASK_CONTENT GetNextById(const int64_t &task_id, const int32_t &consumer_id);
/// \brief get blob field list /// \brief get blob field list
/// \return blob field list /// \return blob field list
std::pair<ShardType, std::vector<std::string>> GetBlobFields(); std::pair<ShardType, std::vector<std::string>> GetBlobFields();
@@ -205,83 +186,86 @@ class API_PUBLIC ShardReader {
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }


/// \brief get all classes /// \brief get all classes
MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);

/// \brief get the size of blob data
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);


/// \brief get a read-only ptr to the sampled ids for this epoch /// \brief get a read-only ptr to the sampled ids for this epoch
const std::vector<int> *GetSampleIds(); const std::vector<int> *GetSampleIds();


/// \brief get the size of blob data
Status GetTotalBlobSize(int64_t *total_blob_size);

/// \brief extract uncompressed data based on column list
Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr);

protected: protected:
/// \brief sqlite call back function /// \brief sqlite call back function
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);


private: private:
/// \brief wrap up labels to json format /// \brief wrap up labels to json format
MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id,
const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);


/// \brief read all rows for specified columns /// \brief read all rows for specified columns
ROW_GROUPS ReadAllRowGroup(const std::vector<std::string> &columns);
Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr);


/// \brief read row meta by shard_id and sample_id /// \brief read row meta by shard_id and sample_id
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
const uint32_t &sample_id);
Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
const uint32_t &sample_id, std::shared_ptr<ROW_GROUPS> *row_group_ptr);


/// \brief read all rows in one shard /// \brief read all rows in one shard
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
Status ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);


/// \brief initialize reader /// \brief initialize reader
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
Status Init(const std::vector<std::string> &file_paths, bool load_dataset);


/// \brief validate column list /// \brief validate column list
MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
Status CheckColumnList(const std::vector<std::string> &selected_columns);


/// \brief populate one row by task list in row-reader mode /// \brief populate one row by task list in row-reader mode
MSRStatus ConsumerByRow(int consumer_id);
void ConsumerByRow(int consumer_id);


/// \brief get offset address of images within page /// \brief get offset address of images within page
std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id, std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
const std::pair<std::string, std::string> &criteria = {"", ""}); const std::pair<std::string, std::string> &criteria = {"", ""});


/// \brief get page id by category /// \brief get page id by category
std::pair<MSRStatus, std::vector<uint64_t>> GetPagesByCategory(int shard_id,
const std::pair<std::string, std::string> &criteria);
Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria,
std::shared_ptr<std::vector<uint64_t>> *pages_ptr);
/// \brief execute sqlite query with prepare statement /// \brief execute sqlite query with prepare statement
MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
Status QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
/// \brief verify the validity of dataset /// \brief verify the validity of dataset
MSRStatus VerifyDataset(sqlite3 **db, const string &file);
Status VerifyDataset(sqlite3 **db, const string &file);


/// \brief get column values /// \brief get column values
std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria = {"", ""});
Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria, std::shared_ptr<std::vector<json>> *labels_ptr);


/// \brief get column values from raw data page /// \brief get column values from raw data page
std::pair<MSRStatus, std::vector<json>> GetLabelsFromPage(int group_id, int shard_id,
const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria = {"",
""});
Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns,
const std::pair<std::string, std::string> &criteria,
std::shared_ptr<std::vector<json>> *labels_ptr);


/// \brief create category-applied task list /// \brief create category-applied task list
MSRStatus CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);
Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);


/// \brief create task list in row-reader mode /// \brief create task list in row-reader mode
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);


/// \brief create task list in row-reader mode and lazy mode /// \brief create task list in row-reader mode and lazy mode
MSRStatus CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);


/// \brief create task list /// \brief create task list
MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);


/// \brief check if all specified columns are in index table /// \brief check if all specified columns are in index table
void CheckIfColumnInIndex(const std::vector<std::string> &columns); void CheckIfColumnInIndex(const std::vector<std::string> &columns);
@@ -290,11 +274,12 @@ class API_PUBLIC ShardReader {
void FileStreamsOperator(); void FileStreamsOperator();


/// \brief read one row by one task /// \brief read one row by one task
TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id);
Status ConsumerOneTask(int task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt);


/// \brief get labels from binary file /// \brief get labels from binary file
std::pair<MSRStatus, std::vector<json>> GetLabelsFromBinaryFile(
int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets);
Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns,
const std::vector<std::vector<std::string>> &label_offsets,
std::shared_ptr<std::vector<json>> *labels_ptr);


/// \brief get classes in one shard /// \brief get classes in one shard
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
@@ -304,11 +289,8 @@ class API_PUBLIC ShardReader {
int64_t GetNumClasses(const std::string &category_field); int64_t GetNumClasses(const std::string &category_field);


/// \brief get meta of header /// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path,
std::shared_ptr<json> meta_data_ptr);

/// \brief extract uncompressed data based on column list
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
std::shared_ptr<std::vector<std::string>> *addresses_ptr);


protected: protected:
uint64_t header_size_; // header size uint64_t header_size_; // header size


+ 3
- 3
mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h View File

@@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator


~ShardSample() override{}; ~ShardSample() override{};


MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;


MSRStatus UpdateTasks(ShardTaskList &tasks, int taking);
Status UpdateTasks(ShardTaskList &tasks, int taking);


MSRStatus SufExecute(ShardTaskList &tasks) override;
Status SufExecute(ShardTaskList &tasks) override;


int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;




+ 0
- 9
mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h View File

@@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Schema {
/// \param[in] schema the schema's json /// \param[in] schema the schema's json
static std::shared_ptr<Schema> Build(std::string desc, const json &schema); static std::shared_ptr<Schema> Build(std::string desc, const json &schema);


/// \brief obtain the json schema and its description for python
/// \param[in] desc the description of the schema
/// \param[in] schema the schema's json
static std::shared_ptr<Schema> Build(std::string desc, pybind11::handle schema);

/// \brief compare two schema to judge if they are equal /// \brief compare two schema to judge if they are equal
/// \param b another schema to be judged /// \param b another schema to be judged
/// \return true if they are equal,false if not /// \return true if they are equal,false if not
@@ -57,10 +52,6 @@ class __attribute__((visibility("default"))) Schema {
/// \return the json format of the schema and its description /// \return the json format of the schema and its description
json GetSchema() const; json GetSchema() const;


/// \brief get the schema and its description for python method
/// \return the python object of the schema and its description
pybind11::object GetSchemaForPython() const;

/// set the schema id /// set the schema id
/// \param[in] id the id need to be set /// \param[in] id the id need to be set
void SetSchemaID(int64_t id); void SetSchemaID(int64_t id);


+ 19
- 20
mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h View File

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
#define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_


#include <memory>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
@@ -25,6 +26,10 @@


namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
using CATEGORY_INFO = std::vector<std::tuple<int, std::string, int>>;
using PAGES = std::vector<std::tuple<std::vector<uint8_t>, json>>;
using PAGES_LOAD = std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>;

class __attribute__((visibility("default"))) ShardSegment : public ShardReader { class __attribute__((visibility("default"))) ShardSegment : public ShardReader {
public: public:
ShardSegment(); ShardSegment();
@@ -33,12 +38,12 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader {


/// \brief Get candidate category fields /// \brief Get candidate category fields
/// \return a list of fields names which are the candidates of category /// \return a list of fields names which are the candidates of category
std::pair<MSRStatus, vector<std::string>> GetCategoryFields();
Status GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr);


/// \brief Set category field /// \brief Set category field
/// \param[in] category_field category name /// \param[in] category_field category name
/// \return true if category name exists /// \return true if category name exists
MSRStatus SetCategoryField(std::string category_field);
Status SetCategoryField(std::string category_field);


/// \brief Thread-safe implementation of ReadCategoryInfo /// \brief Thread-safe implementation of ReadCategoryInfo
/// \return statistics data in json format with 2 field: "key" and "categories". /// \return statistics data in json format with 2 field: "key" and "categories".
@@ -50,47 +55,41 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader {
/// { "key": "label", /// { "key": "label",
/// "categories": [ { "count": 3, "id": 0, "name": "sport", }, /// "categories": [ { "count": 3, "id": 0, "name": "sport", },
/// { "count": 3, "id": 1, "name": "finance", } ] } /// { "count": 3, "id": 1, "name": "finance", } ] }
std::pair<MSRStatus, std::string> ReadCategoryInfo();
Status ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr);


/// \brief Thread-safe implementation of ReadAtPageById /// \brief Thread-safe implementation of ReadAtPageById
/// \param[in] category_id category ID /// \param[in] category_id category ID
/// \param[in] page_no page number /// \param[in] page_no page number
/// \param[in] n_rows_of_page rows number in one page /// \param[in] n_rows_of_page rows number in one page
/// \return images array, image is a vector of uint8_t /// \return images array, image is a vector of uint8_t
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageById(int64_t category_id, int64_t page_no,
int64_t n_rows_of_page);
Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr);


/// \brief Thread-safe implementation of ReadAtPageByName /// \brief Thread-safe implementation of ReadAtPageByName
/// \param[in] category_name category Name /// \param[in] category_name category Name
/// \param[in] page_no page number /// \param[in] page_no page number
/// \param[in] n_rows_of_page rows number in one page /// \param[in] n_rows_of_page rows number in one page
/// \return images array, image is a vector of uint8_t /// \return images array, image is a vector of uint8_t
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageByName(std::string category_name, int64_t page_no,
int64_t n_rows_of_page);

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageById(int64_t category_id,
int64_t page_no,
int64_t n_rows_of_page);

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageByName(
std::string category_name, int64_t page_no, int64_t n_rows_of_page);
Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr);


std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByIdPy(
int64_t category_id, int64_t page_no, int64_t n_rows_of_page);
Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr);


std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy(
std::string category_name, int64_t page_no, int64_t n_rows_of_page);
Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr);


std::pair<ShardType, std::vector<std::string>> GetBlobFields(); std::pair<ShardType, std::vector<std::string>> GetBlobFields();


private: private:
std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo();
Status WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr);


std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec); std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec);


std::string CleanUp(std::string fieldName); std::string CleanUp(std::string fieldName);


std::pair<MSRStatus, std::vector<uint8_t>> PackImages(int group_id, int shard_id, std::vector<uint64_t> offset);
Status PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
std::shared_ptr<std::vector<uint8_t>> *images_ptr);


std::vector<std::string> candidate_category_fields_; std::vector<std::string> candidate_category_fields_;
std::string current_category_field_; std::string current_category_field_;


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h View File

@@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar


~ShardSequentialSample() override{}; ~ShardSequentialSample() override{};


MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;


int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;




+ 4
- 4
mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h View File

@@ -31,19 +31,19 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator


~ShardShuffle() override{}; ~ShardShuffle() override{};


MSRStatus Execute(ShardTaskList &tasks) override;
Status Execute(ShardTaskList &tasks) override;


int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;


private: private:
// Private helper function // Private helper function
MSRStatus CategoryShuffle(ShardTaskList &tasks);
Status CategoryShuffle(ShardTaskList &tasks);


// Keep the file sequence the same but shuffle the data within each file // Keep the file sequence the same but shuffle the data within each file
MSRStatus ShuffleInfile(ShardTaskList &tasks);
Status ShuffleInfile(ShardTaskList &tasks);


// Shuffle the file sequence but keep the order of data within each file // Shuffle the file sequence but keep the order of data within each file
MSRStatus ShuffleFiles(ShardTaskList &tasks);
Status ShuffleFiles(ShardTaskList &tasks);


uint32_t shuffle_seed_; uint32_t shuffle_seed_;
int64_t no_of_samples_; int64_t no_of_samples_;


+ 0
- 9
mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h View File

@@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Statistics {
/// \param[in] statistics the statistic needs to be saved /// \param[in] statistics the statistic needs to be saved
static std::shared_ptr<Statistics> Build(std::string desc, const json &statistics); static std::shared_ptr<Statistics> Build(std::string desc, const json &statistics);


/// \brief save the statistic from python and its description
/// \param[in] desc the statistic's description
/// \param[in] statistics the statistic needs to be saved
static std::shared_ptr<Statistics> Build(std::string desc, pybind11::handle statistics);

~Statistics() = default; ~Statistics() = default;


/// \brief compare two statistics to judge if they are equal /// \brief compare two statistics to judge if they are equal
@@ -59,10 +54,6 @@ class __attribute__((visibility("default"))) Statistics {
/// \return json format of the statistic /// \return json format of the statistic
json GetStatistics() const; json GetStatistics() const;


/// \brief get the statistic for python
/// \return the python object of statistics
pybind11::object GetStatisticsForPython() const;

/// \brief decode the bson statistics to json /// \brief decode the bson statistics to json
/// \param[in] encodedStatistics the bson type of statistics /// \param[in] encodedStatistics the bson type of statistics
/// \return json type of statistic /// \return json type of statistic


+ 61
- 71
mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h View File

@@ -55,69 +55,60 @@ class __attribute__((visibility("default"))) ShardWriter {
/// \brief Open file at the beginning /// \brief Open file at the beginning
/// \param[in] paths the file names list /// \param[in] paths the file names list
/// \param[in] append new data at the end of file if true, otherwise overwrite file /// \param[in] append new data at the end of file if true, otherwise overwrite file
/// \return MSRStatus the status of MSRStatus
MSRStatus Open(const std::vector<std::string> &paths, bool append = false);
/// \return Status
Status Open(const std::vector<std::string> &paths, bool append = false);


/// \brief Open file at the ending /// \brief Open file at the ending
/// \param[in] paths the file names list /// \param[in] path the file name
/// \return MSRStatus the status of MSRStatus /// \return Status
MSRStatus OpenForAppend(const std::string &path);
Status OpenForAppend(const std::string &path);


/// \brief Write header to disk /// \brief Write header to disk
/// \return MSRStatus the status of MSRStatus /// \return Status
MSRStatus Commit();
Status Commit();


/// \brief Set file size /// \brief Set header size
/// \param[in] header_size the size of header, only (1<<N) is accepted /// \param[in] header_size the size of header, only (1<<N) is accepted
/// \return MSRStatus the status of MSRStatus /// \return Status
MSRStatus SetHeaderSize(const uint64_t &header_size);
Status SetHeaderSize(const uint64_t &header_size);


/// \brief Set page size /// \brief Set page size
/// \param[in] page_size the size of page, only (1<<N) is accepted /// \param[in] page_size the size of page, only (1<<N) is accepted
/// \return MSRStatus the status of MSRStatus /// \return Status
MSRStatus SetPageSize(const uint64_t &page_size);
Status SetPageSize(const uint64_t &page_size);


/// \brief Set shard header /// \brief Set shard header
/// \param[in] header_data the info of header /// \param[in] header_data the info of header
/// WARNING, only called when file is empty /// WARNING, only called when file is empty
/// \return MSRStatus the status of MSRStatus /// \return Status
MSRStatus SetShardHeader(std::shared_ptr<ShardHeader> header_data);
Status SetShardHeader(std::shared_ptr<ShardHeader> header_data);


/// \brief write raw data by group size /// \brief write raw data by group size
/// \param[in] raw_data the vector of raw json data, vector format /// \param[in] raw_data the vector of raw json data, vector format
/// \param[in] blob_data the vector of image data /// \param[in] blob_data the vector of image data
/// \param[in] sign validate data or not /// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully /// \return Status indicating whether the write succeeded
MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);

/// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format
/// \param[in] blob_data the vector of image data
/// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);
Status WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign = true, bool parallel_writer = false);


/// \brief write raw data by group size for call from python /// \brief write raw data by group size for call from python
/// \param[in] raw_data the vector of raw json data, python-handle format /// \param[in] raw_data the vector of raw json data, python-handle format
/// \param[in] blob_data the vector of blob json data, python-handle format /// \param[in] blob_data the vector of blob json data, python-handle format
/// \param[in] sign validate data or not /// \param[in] sign validate data or not
/// \return MSRStatus the status of MSRStatus to judge if write successfully /// \return Status indicating whether the write succeeded
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);
Status WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
bool parallel_writer = false);


MSRStatus MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output);
Status MergeBlobData(const std::vector<string> &blob_fields,
const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data,
std::shared_ptr<std::vector<uint8_t>> *output);


static MSRStatus Initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
const std::vector<std::string> &file_names);
static Status Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, const std::vector<std::string> &file_names);


private: private:
/// \brief write shard header data to disk /// \brief write shard header data to disk
MSRStatus WriteShardHeader();
Status WriteShardHeader();


/// \brief erase error data /// \brief erase error data
void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data); void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data);
@@ -130,108 +121,107 @@ class __attribute__((visibility("default"))) ShardWriter {
std::map<int, std::string> &err_raw_data); std::map<int, std::string> &err_raw_data);


/// \brief write shard header data to disk /// \brief validate raw data and blob data
std::tuple<MSRStatus, int, int> ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &blob_data, bool sign);
Status ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data,
bool sign, std::shared_ptr<std::pair<int, int>> *count_ptr);


/// \brief fill data array in multiple thread run /// \brief fill data array in multiple thread run
void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data, void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &bin_data); std::vector<std::vector<uint8_t>> &bin_data);


/// \brief serialized raw data /// \brief serialized raw data
MSRStatus SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count);
Status SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &bin_data,
uint32_t row_count);


/// \brief write all data parallel /// \brief write all data parallel
MSRStatus ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);


/// \brief write data shard by shard /// \brief write data shard by shard
MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::vector<uint8_t>> &bin_raw_data);


/// \brief break image data up into multiple row groups /// \brief break image data up into multiple row groups
MSRStatus CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
const std::shared_ptr<Page> &last_blob_page);
Status CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
const std::shared_ptr<Page> &last_blob_page);


/// \brief append partial blob data to previous page /// \brief append partial blob data to previous page
MSRStatus AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_blob_page);

/// \brief write new blob data page to disk
MSRStatus NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
Status AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::pair<int, int>> &rows_in_group, const std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_blob_page); const std::shared_ptr<Page> &last_blob_page);


/// \brief write new blob data page to disk
Status NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
const std::vector<std::pair<int, int>> &rows_in_group,
const std::shared_ptr<Page> &last_blob_page);

/// \brief shift last row group to next raw page for new appending /// \brief shift last row group to next raw page for new appending
MSRStatus ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page);
Status ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page);


/// \brief write raw data page to disk /// \brief write raw data page to disk
MSRStatus WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data);


/// \brief generate empty raw data page /// \brief generate empty raw data page
void EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
Status EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);


/// \brief append a row group at the end of raw page /// \brief append a row group at the end of raw page
MSRStatus AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
const int &chunk_id, int &last_row_groupId, std::shared_ptr<Page> last_raw_page,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
int &last_row_groupId, std::shared_ptr<Page> last_raw_page,
const std::vector<std::vector<uint8_t>> &bin_raw_data);


/// \brief write blob chunk to disk /// \brief write blob chunk to disk
MSRStatus FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
const std::pair<int, int> &blob_row);
Status FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
const std::pair<int, int> &blob_row);


/// \brief write raw chunk to disk /// \brief write raw chunk to disk
MSRStatus FlushRawChunk(const std::shared_ptr<std::fstream> &out,
const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status FlushRawChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::pair<int, int>> &rows_in_group,
const int &chunk_id, const std::vector<std::vector<uint8_t>> &bin_raw_data);


/// \brief break up into tasks by shard /// \brief break up into tasks by shard
std::vector<std::pair<int, int>> BreakIntoShards(); std::vector<std::pair<int, int>> BreakIntoShards();


/// \brief calculate raw data size row by row /// \brief calculate raw data size row by row
MSRStatus SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);
Status SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);


/// \brief calculate blob data size row by row /// \brief calculate blob data size row by row
MSRStatus SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);
Status SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);


/// \brief populate last raw page pointer /// \brief populate last raw page pointer
void SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
Status SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);


/// \brief populate last blob page pointer /// \brief populate last blob page pointer
void SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);
Status SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);


/// \brief check the data by schema /// \brief check the data by schema
MSRStatus CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);
Status CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);


/// \brief check the data and type /// \brief check the data and type
MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
std::map<int, std::string> &err_raw_data);
Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
std::map<int, std::string> &err_raw_data);


/// \brief Lock writer and save pages info /// \brief Lock writer and save pages info
int LockWriter(bool parallel_writer = false);
Status LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr);


/// \brief Unlock writer and save pages info /// \brief Unlock writer and save pages info
MSRStatus UnlockWriter(int fd, bool parallel_writer = false);
Status UnlockWriter(int fd, bool parallel_writer = false);


/// \brief Check raw data before writing /// \brief Check raw data before writing
MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, int *schema_count, int *row_count);
Status WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
bool sign, int *schema_count, int *row_count);


/// \brief Get full path from file name /// \brief Get full path from file name
MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);
Status GetFullPathFromFileName(const std::vector<std::string> &paths);


/// \brief Open files /// \brief Open files
MSRStatus OpenDataFiles(bool append);
Status OpenDataFiles(bool append);


/// \brief Remove lock file /// \brief Remove lock file
MSRStatus RemoveLockFile();
Status RemoveLockFile();


/// \brief Remove lock file /// \brief Init lock file
MSRStatus InitLockFile();
Status InitLockFile();


private: private:
const std::string kLockFileSuffix = "_Locker"; const std::string kLockFileSuffix = "_Locker";


+ 204
- 329
mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc View File

@@ -37,70 +37,48 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
task_(0), task_(0),
write_success_(true) {} write_success_(true) {}


MSRStatus ShardIndexGenerator::Build() {
auto ret = ShardHeader::BuildSingleHeader(file_path_);
if (ret.first != SUCCESS) {
return FAILED;
}
auto json_header = ret.second;

auto ret2 = GetDatasetFiles(file_path_, json_header["shard_addresses"]);
if (SUCCESS != ret2.first) {
return FAILED;
}
Status ShardIndexGenerator::Build() {
std::shared_ptr<json> header_ptr;
RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path_, &header_ptr));
auto ds = std::make_shared<std::vector<std::string>>();
RETURN_IF_NOT_OK(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds));
ShardHeader header = ShardHeader(); ShardHeader header = ShardHeader();
auto addresses = ret2.second;
if (header.BuildDataset(addresses) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(header.BuildDataset(*ds));
shard_header_ = header; shard_header_ = header;
MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; MS_LOG(INFO) << "Init header from mindrecord file for index successfully.";
return SUCCESS;
return Status::OK();
} }


std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
if (field.empty()) {
MS_LOG(ERROR) << "The input field is None.";
return {FAILED, ""};
}

if (input.empty()) {
MS_LOG(ERROR) << "The input json is None.";
return {FAILED, ""};
}
Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value) {
RETURN_UNEXPECTED_IF_NULL(value);
CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty.");
CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty.");


// parameter input does not contain the field // parameter input does not contain the field
if (input.find(field) == input.end()) {
MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(),
"The field " + field + " is not found in json " + input.dump());


// schema does not contain the field // schema does not contain the field
auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
if (schema.find(field) == schema.end()) {
MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
"The field " + field + " is not found in schema " + schema.dump());


// field should be scalar type // field should be scalar type
if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable";
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(
kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(),
"The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");


if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
auto schema_field_options = schema[field]; auto schema_field_options = schema[field];
if (schema_field_options.find("shape") == schema_field_options.end()) {
return {SUCCESS, input[field].dump()};
} else {
// field with shape option
MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable";
return {FAILED, ""};
}
CHECK_FAIL_RETURN_UNEXPECTED(
schema_field_options.find("shape") == schema_field_options.end(),
"The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable.");
*value = std::make_shared<std::string>(input[field].dump());
} else {
// the field type is string in here
*value = std::make_shared<std::string>(input[field].get<std::string>());
} }

// the field type is string in here
return {SUCCESS, input[field].get<std::string>()};
return Status::OK();
} }


std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
@@ -150,24 +128,28 @@ int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **
return 0; return 0;
} }


MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) {
char *z_err_msg = nullptr; char *z_err_msg = nullptr;
int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Sql error: " << z_err_msg;
std::ostringstream oss;
oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg;
MS_LOG(DEBUG) << oss.str();
sqlite3_free(z_err_msg); sqlite3_free(z_err_msg);
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED(oss.str());
} else { } else {
if (!success_msg.empty()) { if (!success_msg.empty()) {
MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg;
MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg;
} }
sqlite3_free(z_err_msg); sqlite3_free(z_err_msg);
return SUCCESS;
return Status::OK();
} }
} }


std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName(
const std::pair<uint64_t, std::string> &field) {
Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field,
std::shared_ptr<std::string> *fn_ptr) {
RETURN_UNEXPECTED_IF_NULL(fn_ptr);
// Replaces dots and dashes with underscores for SQL use // Replaces dots and dashes with underscores for SQL use
std::string field_name = field.second; std::string field_name = field.second;
// white list to avoid sql injection // white list to avoid sql injection
@@ -176,95 +158,71 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName(
auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) {
return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9');
}); });
if (pos != field_name.end()) {
MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name;
return {FAILED, ""};
}
return {SUCCESS, field_name + "_" + std::to_string(field.first)};
CHECK_FAIL_RETURN_UNEXPECTED(
pos == field_name.end(),
"Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name);
*fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first));
return Status::OK();
} }


std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CheckDatabase(const std::string &shard_address) {
Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) {
auto realpath = Common::GetRealPath(shard_address); auto realpath = Common::GetRealPath(shard_address);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << shard_address;
return {FAILED, nullptr};
}

sqlite3 *db = nullptr;
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
std::ifstream fin(realpath.value()); std::ifstream fin(realpath.value());
if (!append_ && fin.good()) { if (!append_ && fin.good()) {
MS_LOG(ERROR) << "Invalid file, DB file already exist: " << shard_address;
fin.close(); fin.close();
return {FAILED, nullptr};
RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address);
} }
fin.close(); fin.close();
int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
if (rc) {
MS_LOG(ERROR) << "Invalid file, failed to open database: " << shard_address << ", error" << sqlite3_errmsg(db);
return {FAILED, nullptr};
} else {
MS_LOG(DEBUG) << "Opened database successfully";
return {SUCCESS, db};
if (sqlite3_open_v2(common::SafeCStr(shard_address), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" +
std::string(sqlite3_errmsg(*db)));
} }
MS_LOG(DEBUG) << "Opened database successfully";
return Status::OK();
} }


MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) {
// create shard_name table // create shard_name table
std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;";
if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) {
return FAILED;
}

RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "drop table successfully."));
sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);";
if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) {
return FAILED;
}

RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "create table successfully."));
sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);"; sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);";
sqlite3_stmt *stmt = nullptr; sqlite3_stmt *stmt = nullptr;
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
if (stmt != nullptr) { if (stmt != nullptr) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
} }
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
} }


int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME"); int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME");
if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) { if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << shard_name;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(shard_name));
} }


if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
return FAILED;
RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
} }

(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
return SUCCESS;
return Status::OK();
} }


// CreateDatabase, shown as a side-by-side diff: each removed line (std::pair<MSRStatus, sqlite3 *> version) is
// followed by its Status-based replacement, and unchanged lines are printed twice on one row.
// Post-change flow: fetch the shard address for `shard_no`, resolve its file name with GetFileName, append ".db"
// and open that file through CheckDatabase (the handle is handed back through `db`), drop and recreate the
// INDEXES table with an INC_<n> column plus a generated column for every entry of fields_, then call
// CreateShardNameTable with the resolved file name.
// The "@@" row further down is a hunk header; the lines between the two hunks are not shown.
std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) {
Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) {
std::string shard_address = shard_header_.GetShardAddressByID(shard_no); std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
if (shard_address.empty()) {
MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no;
return {FAILED, nullptr};
}

string shard_name = GetFileName(shard_address).second;
// NOTE(review): `"..." + shard_no` adds an int to a string literal, i.e. pointer arithmetic on the literal; it
// does not append the number. std::to_string(shard_no) is presumably what was meant.
CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + shard_no);
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr));
shard_address += ".db"; shard_address += ".db";
auto ret1 = CheckDatabase(shard_address);
if (ret1.first != SUCCESS) {
return {FAILED, nullptr};
}
sqlite3 *db = ret1.second;
// NOTE(review): `db` is dereferenced below without a null check on the out-parameter itself.
RETURN_IF_NOT_OK(CheckDatabase(shard_address, db));
std::string sql = "DROP TABLE IF EXISTS INDEXES;"; std::string sql = "DROP TABLE IF EXISTS INDEXES;";
if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) {
return {FAILED, nullptr};
}
RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "drop table successfully."));
sql = sql =
"CREATE TABLE INDEXES(" "CREATE TABLE INDEXES("
" ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL"
@@ -273,95 +231,79 @@ std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no
", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL";


// One "INC_<n> INT, <generated field name> <sql type>" column pair is appended per index field.
int field_no = 0; int field_no = 0;
std::shared_ptr<std::string> field_ptr;
for (const auto &field : fields_) { for (const auto &field : fields_) {
uint64_t schema_id = field.first; uint64_t schema_id = field.first;
auto result = shard_header_.GetSchemaByID(schema_id);
if (result.second != SUCCESS) {
return {FAILED, nullptr};
}
json json_schema = (result.first->GetSchema())["schema"];
std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(schema_id, &schema_ptr));
json json_schema = (schema_ptr->GetSchema())["schema"];
std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema));
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, nullptr};
}
sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type;
RETURN_IF_NOT_OK(GenerateFieldName(field, &field_ptr));
sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type;
} }
// The primary key is ROW_ID followed by every INC_<n> column.
sql += ", PRIMARY KEY(ROW_ID"; sql += ", PRIMARY KEY(ROW_ID";
for (uint64_t i = 0; i < fields_.size(); ++i) { for (uint64_t i = 0; i < fields_.size(); ++i) {
sql += ",INC_" + std::to_string(i); sql += ",INC_" + std::to_string(i);
} }
sql += "));"; sql += "));";
if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) {
return {FAILED, nullptr};
}

if (CreateShardNameTable(db, shard_name) != SUCCESS) {
return {FAILED, nullptr};
}
return {SUCCESS, db};
RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "create table successfully."));
RETURN_IF_NOT_OK(CreateShardNameTable(*db, *fn_ptr));
return Status::OK();
} }


std::pair<MSRStatus, std::vector<json>> ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens,
std::fstream &in) {
std::vector<json> schema_details;
Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in,
std::shared_ptr<std::vector<json>> *detail_ptr) {
RETURN_UNEXPECTED_IF_NULL(detail_ptr);
if (schema_count_ <= kMaxSchemaCount) { if (schema_count_ <= kMaxSchemaCount) {
for (int sc = 0; sc < schema_count_; ++sc) { for (int sc = 0; sc < schema_count_; ++sc) {
std::vector<char> schema_detail(schema_lens[sc]); std::vector<char> schema_detail(schema_lens[sc]);

auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); auto &io_read = in.read(&schema_detail[0], schema_lens[sc]);
if (!io_read.good() || io_read.fail() || io_read.bad()) { if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
in.close(); in.close();
return {FAILED, {}};
RETURN_STATUS_UNEXPECTED("Failed to read file.");
} }
schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end())));
auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()));
(*detail_ptr)->emplace_back(j);
} }
} }

return {SUCCESS, schema_details};
return Status::OK();
} }


std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL(
const std::vector<std::pair<uint64_t, std::string>> &fields) {
Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields,
std::shared_ptr<std::string> *sql_ptr) {
std::string sql = std::string sql =
"INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END,"
"PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END";


int field_no = 0; int field_no = 0;
for (const auto &field : fields) { for (const auto &field : fields) {
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, ""};
}
sql += ",INC_" + std::to_string(field_no++) + "," + ret.second;
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr;
} }
sql += sql +=
") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB,"
":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END";
field_no = 0; field_no = 0;
for (const auto &field : fields) { for (const auto &field : fields) {
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, ""};
}
sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second;
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr;
} }
sql += " )"; sql += " )";
return {SUCCESS, sql};

*sql_ptr = std::make_shared<std::string>(sql);
return Status::OK();
} }


// BindParameterExecuteSQL, shown as a side-by-side diff: removed MSRStatus lines are followed by their Status
// replacements, and unchanged lines are printed twice on one row.
// Post-change flow: prepare `sql` once, then for every row in `data` bind each field by its type tag
// ("INTEGER" -> sqlite3_bind_int64 via std::stoll, "NUMERIC" -> sqlite3_bind_double via std::stold,
// "NULL" -> sqlite3_bind_null, anything else -> sqlite3_bind_text), step the statement and reset it; the
// statement is finalized before returning.
// The "@@" row is a hunk header: the lines that define `index`, `field_type` and `field_value` are not shown.
// NOTE(review): the replacement code calls sqlite3_close(db) on prepare/bind failures but not when sqlite3_step
// fails. ExecuteTransaction passes the same `db` to RELEASE_AND_RETURN_IF_NOT_OK after a failure here; if that
// macro also closes the handle (its definition is not visible) the connection is closed twice -- confirm.
MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) {
sqlite3_stmt *stmt = nullptr; sqlite3_stmt *stmt = nullptr;
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
if (stmt != nullptr) { if (stmt != nullptr) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
} }
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
} }
for (auto &row : data) { for (auto &row : data) {
for (auto &field : row) { for (auto &field : row) {
@@ -373,45 +315,47 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
if (field_type == "INTEGER") { if (field_type == "INTEGER") {
if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stoll(field_value);
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(field_value));
} }
} else if (field_type == "NUMERIC") { } else if (field_type == "NUMERIC") {
if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stold(field_value);
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(field_value));
} }
} else if (field_type == "NULL") { } else if (field_type == "NULL") {
if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL";
return FAILED;

sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: NULL");
} }
} else { } else {
if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value;
return FAILED;
sqlite3_close(db);
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
", field value: " + std::string(field_value));
} }
} }
} }
// One step per row; the statement is reset afterwards so it can be re-bound for the next row.
if (sqlite3_step(stmt) != SQLITE_DONE) { if (sqlite3_step(stmt) != SQLITE_DONE) {
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
return FAILED;
RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
} }
(void)sqlite3_reset(stmt); (void)sqlite3_reset(stmt);
} }
(void)sqlite3_finalize(stmt); (void)sqlite3_finalize(stmt);
return SUCCESS;
return Status::OK();
} }


// AddBlobPageInfo, shown as a side-by-side diff: removed MSRStatus lines are followed by their Status
// replacements, and unchanged lines are printed twice on one row.
// Post-change flow: append the blob page id to `row_data`, seek `in` to
// page_size_ * page id + header_size_ + cur_blob_page_offset, read a kInt64Len-byte size value, advance
// `cur_blob_page_offset` by kInt64Len plus that size and append the new offset as :PAGE_OFFSET_BLOB_END.
// `in` is closed on a seek or read failure.
// The "@@" row is a hunk header: the lines between the two hunks are not shown.
MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page,
uint64_t &cur_blob_page_offset, std::fstream &in) {
Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset,
std::fstream &in) {
row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));


// blob data start // blob data start
@@ -419,89 +363,71 @@ MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::strin
auto &io_seekg_blob = auto &io_seekg_blob =
in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
MS_LOG(ERROR) << "File seekg failed";
in.close(); in.close();
return FAILED;
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
} }

// Length prefix of the blob record; it is read raw from the file into a uint64_t.
uint64_t image_size = 0; uint64_t image_size = 0;

auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len); auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len);
if (!io_read.good() || io_read.fail() || io_read.bad()) { if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed"; MS_LOG(ERROR) << "File read failed";
in.close(); in.close();
return FAILED;
RETURN_STATUS_UNEXPECTED("Failed to read file.");
} }


cur_blob_page_offset += (kInt64Len + image_size); cur_blob_page_offset += (kInt64Len + image_size);
row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset));


return SUCCESS;
return Status::OK();
} }


void ShardIndexGenerator::AddIndexFieldByRawData(
Status ShardIndexGenerator::AddIndexFieldByRawData(
const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) { const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) {
auto result = GenerateIndexFields(schema_detail);
if (result.first == SUCCESS) {
int index = 0;
for (const auto &field : result.second) {
// assume simple field: string , number etc.
row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
}
}
auto index_fields_ptr = std::make_shared<INDEX_FIELDS>();
RETURN_IF_NOT_OK(GenerateIndexFields(schema_detail, &index_fields_ptr));
int index = 0;
for (const auto &field : *index_fields_ptr) {
// assume simple field: string , number etc.
row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0");
row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field));
}
return Status::OK();
} }


// GenerateRowData, shown as a side-by-side diff: removed lines (ROW_DATA return value) are followed by their
// Status-based replacements, and unchanged lines are printed twice on one row.
// Post-change flow: load the raw page `raw_page_id` of shard `shard_no`; for each of its row groups look up
// the matching blob page through `blob_id_to_page_id`; for every row id of that blob page assemble the bind
// parameters (row id, row group id, raw page id, raw offsets, blob page info, index fields) and push them
// into `*row_data_ptr`. `in` is closed when a seek or read on it fails.
// The "@@" rows are hunk headers: the lines between hunks are not shown.
ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id,
int raw_page_id, std::fstream &in) {
std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data;

Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id,
std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) {
RETURN_UNEXPECTED_IF_NULL(row_data_ptr);
// current raw data page // current raw data page
auto ret1 = shard_header_.GetPage(shard_no, raw_page_id);
if (ret1.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_raw_page = ret1.first;

std::shared_ptr<Page> page_ptr;
RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr));
// related blob page // related blob page
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
vector<pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds();


// pair: row_group id, offset in raw data page // pair: row_group id, offset in raw data page
// NOTE(review): the loop variable is pair<int, int> while the list holds pair<int, uint64_t>, so the offset is
// narrowed to int here and cast back to uint64_t below -- confirm offsets always fit in an int.
for (pair<int, int> blob_ids : row_group_list) { for (pair<int, int> blob_ids : row_group_list) {
// get blob data page according to row_group id // get blob data page according to row_group id
auto iter = blob_id_to_page_id.find(blob_ids.first); auto iter = blob_id_to_page_id.find(blob_ids.first);
if (iter == blob_id_to_page_id.end()) {
MS_LOG(ERROR) << "Convert blob id failed";
return {FAILED, {}};
}
auto ret2 = shard_header_.GetPage(shard_no, iter->second);
if (ret2.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_blob_page = ret2.first;

CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id.");
std::shared_ptr<Page> blob_page_ptr;
RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr));
// offset in current raw data page // offset in current raw data page
auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second); auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
uint64_t cur_blob_page_offset = 0; uint64_t cur_blob_page_offset = 0;
for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) {
for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) {
std::vector<std::tuple<std::string, std::string, std::string>> row_data; std::vector<std::tuple<std::string, std::string, std::string>> row_data;
row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID()));
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID()));
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID()));
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID()));


// raw data start // raw data start
row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));


// calculate raw data end // calculate raw data end
auto &io_seekg = auto &io_seekg =
in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
return {FAILED, {}};
in.close();
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
} }

std::vector<uint64_t> schema_lens; std::vector<uint64_t> schema_lens;
if (schema_count_ <= kMaxSchemaCount) { if (schema_count_ <= kMaxSchemaCount) {
for (int sc = 0; sc < schema_count_; sc++) { for (int sc = 0; sc < schema_count_; sc++) {
@@ -509,8 +435,8 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,


auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len); auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len);
if (!io_read.good() || io_read.fail() || io_read.bad()) { if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
return {FAILED, {}};
in.close();
RETURN_STATUS_UNEXPECTED("Failed to read file.");
} }


cur_raw_page_offset += (kInt64Len + schema_size); cur_raw_page_offset += (kInt64Len + schema_size);
@@ -520,122 +446,79 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset));


// Getting schema for getting data for fields // Getting schema for getting data for fields
auto st_schema_detail = GetSchemaDetails(schema_lens, in);
if (st_schema_detail.first != SUCCESS) {
return {FAILED, {}};
}

auto detail_ptr = std::make_shared<std::vector<json>>();
RETURN_IF_NOT_OK(GetSchemaDetails(schema_lens, in, &detail_ptr));
// start blob page info // start blob page info
if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) {
return {FAILED, {}};
}
RETURN_IF_NOT_OK(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in));


// start index field // start index field
AddIndexFieldByRawData(st_schema_detail.second, row_data);
full_data.push_back(std::move(row_data));
// NOTE(review): AddIndexFieldByRawData returns Status in the replacement code, but the result is discarded on
// the next replacement line, so a failure there still pushes the (incomplete) row. Wrapping the call in
// RETURN_IF_NOT_OK would propagate it.
AddIndexFieldByRawData(*detail_ptr, row_data);
(*row_data_ptr)->push_back(std::move(row_data));
} }
} }
return {SUCCESS, full_data};
return Status::OK();
} }


INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) {
std::vector<std::tuple<std::string, std::string, std::string>> fields;
Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail,
std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) {
RETURN_UNEXPECTED_IF_NULL(index_fields_ptr);
// index fields // index fields
std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields(); std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
for (const auto &field : index_fields) { for (const auto &field : index_fields) {
if (field.first >= schema_detail.size()) {
return {FAILED, {}};
}
auto field_value = GetValueByField(field.second, schema_detail[field.first]);
if (field_value.first != SUCCESS) {
MS_LOG(ERROR) << "Get value from json by field name failed";
return {FAILED, {}};
}

auto result = shard_header_.GetSchemaByID(field.first);
if (result.second != SUCCESS) {
return {FAILED, {}};
}

std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"]));
auto ret = GenerateFieldName(field);
if (ret.first != SUCCESS) {
return {FAILED, {}};
}

fields.emplace_back(ret.second, field_type, field_value.second);
}
return {SUCCESS, std::move(fields)};
CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range.");
std::shared_ptr<std::string> field_val_ptr;
RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr));
std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(field.first, &schema_ptr));
std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"]));
std::shared_ptr<std::string> fn_ptr;
RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr));
(*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr);
}
return Status::OK();
} }


MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) {
Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) {
// Add index data to database // Add index data to database
std::string shard_address = shard_header_.GetShardAddressByID(shard_no); std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
if (shard_address.empty()) {
MS_LOG(ERROR) << "Invalid data, shard address is null";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty.");


auto realpath = Common::GetRealPath(shard_address); auto realpath = Common::GetRealPath(shard_address);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << shard_address;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
std::fstream in; std::fstream in;
in.open(realpath.value(), std::ios::in | std::ios::binary); in.open(realpath.value(), std::ios::in | std::ios::binary);
if (!in.good()) { if (!in.good()) {
MS_LOG(ERROR) << "Invalid file, failed to open file: " << shard_address;
in.close(); in.close();
return FAILED;
RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address);
} }
(void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
(void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
for (int raw_page_id : raw_page_ids) { for (int raw_page_id : raw_page_ids) {
auto sql = GenerateRawSQL(fields_);
if (sql.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw SQL failed";
in.close();
sqlite3_close(db.second);
return FAILED;
}
auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
if (data.first != SUCCESS) {
MS_LOG(ERROR) << "Generate raw data failed";
in.close();
sqlite3_close(db.second);
return FAILED;
}
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
MS_LOG(ERROR) << "Execute SQL failed";
in.close();
sqlite3_close(db.second);
return FAILED;
}
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
}
(void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr);
std::shared_ptr<std::string> sql_ptr;
RELEASE_AND_RETURN_IF_NOT_OK(GenerateRawSQL(fields_, &sql_ptr), db, in);
auto row_data_ptr = std::make_shared<ROW_DATA>();
RELEASE_AND_RETURN_IF_NOT_OK(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db, in);
RELEASE_AND_RETURN_IF_NOT_OK(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in);
MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db.";
}
(void)sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr);
in.close(); in.close();


// Close database // Close database
if (sqlite3_close(db.second) != SQLITE_OK) {
MS_LOG(ERROR) << "Close database failed";
return FAILED;
}
db.second = nullptr;
return SUCCESS;
sqlite3_close(db);
db = nullptr;
return Status::OK();
} }


// WriteToDatabase, shown as a side-by-side diff: removed MSRStatus lines are followed by their Status
// replacements, and unchanged lines are printed twice on one row.
// Post-change flow: cache fields, page size, header size and schema count from shard_header_, reject headers
// with more than kMaxShardCount shards, reset task_ and write_success_, join the worker threads and report an
// error when write_success_ ended up false.
// The "@@" row is a hunk header: the lines that create and start the threads are not shown.
MSRStatus ShardIndexGenerator::WriteToDatabase() {
Status ShardIndexGenerator::WriteToDatabase() {
fields_ = shard_header_.GetFields(); fields_ = shard_header_.GetFields();
page_size_ = shard_header_.GetPageSize(); page_size_ = shard_header_.GetPageSize();
header_size_ = shard_header_.GetHeaderSize(); header_size_ = shard_header_.GetHeaderSize();
schema_count_ = shard_header_.GetSchemaCount(); schema_count_ = shard_header_.GetSchemaCount();
if (shard_header_.GetShardCount() > kMaxShardCount) {
MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount;
return FAILED;
}
// NOTE(review): the condition compares against kMaxShardCount, but the message prints kMaxSchemaCount (in both
// the removed and the replacement lines) -- the reported limit presumably should be kMaxShardCount.
CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount,
"num shards: " + std::to_string(shard_header_.GetShardCount()) +
" exceeds max count:" + std::to_string(kMaxSchemaCount));
task_ = 0; // set two atomic vars to initial value task_ = 0; // set two atomic vars to initial value
write_success_ = true; write_success_ = true;


@@ -653,40 +536,41 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
// NOTE(review): the loop bound is threads.capacity(), not threads.size(); confirm the hidden set-up code makes
// the two equal, otherwise elements past size() are accessed.
for (size_t t = 0; t < threads.capacity(); t++) { for (size_t t = 0; t < threads.capacity(); t++) {
threads[t].join(); threads[t].join();
} }
return write_success_ ? SUCCESS : FAILED;
CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db.");
return Status::OK();
} }


// DatabaseWriter, shown as a side-by-side diff: removed lines are followed by their Status-based replacements,
// and unchanged lines are printed twice on one row.
// Post-change flow: repeatedly take the next shard number from task_; for each shard create its index database,
// sort the shard's pages into raw page ids and a blob-id -> page-id map, and hand both to ExecuteTransaction.
// Any failure sets write_success_ to false and ends this worker.
// The "@@" row is a hunk header: the lines between the two hunks are not shown.
void ShardIndexGenerator::DatabaseWriter() { void ShardIndexGenerator::DatabaseWriter() {
int shard_no = task_++; int shard_no = task_++;
while (shard_no < shard_header_.GetShardCount()) { while (shard_no < shard_header_.GetShardCount()) {
auto db = CreateDatabase(shard_no);
if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
// NOTE(review): the removed condition also stopped when write_success_ was already false (another worker
// failed); the replacement no longer checks that, so workers keep processing shards after a failure elsewhere.
sqlite3 *db = nullptr;
if (CreateDatabase(shard_no, &db).IsError()) {
MS_LOG(ERROR) << "Failed to create Generate database.";
write_success_ = false; write_success_ = false;
return; return;
} }

MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";

// Pre-processing page information // Pre-processing page information
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;


std::map<int, int> blob_id_to_page_id; std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids; std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) { for (uint64_t i = 0; i < total_pages; ++i) {
auto ret = shard_header_.GetPage(shard_no, i);
if (ret.second != SUCCESS) {
std::shared_ptr<Page> page_ptr;
// NOTE(review): `db` was opened above and is not closed on this early return.
if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) {
MS_LOG(ERROR) << "Failed to get page.";
write_success_ = false; write_success_ = false;
return; return;
} }
std::shared_ptr<Page> cur_page = ret.first;
if (cur_page->GetPageType() == "RAW_DATA") {
if (page_ptr->GetPageType() == "RAW_DATA") {
raw_page_ids.push_back(i); raw_page_ids.push_back(i);
} else if (cur_page->GetPageType() == "BLOB_DATA") {
blob_id_to_page_id[cur_page->GetPageTypeID()] = i;
} else if (page_ptr->GetPageType() == "BLOB_DATA") {
blob_id_to_page_id[page_ptr->GetPageTypeID()] = i;
} }
} }


if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) {
MS_LOG(ERROR) << "Failed to execute transaction.";
write_success_ = false; write_success_ = false;
return; return;
} }
@@ -694,21 +578,12 @@ void ShardIndexGenerator::DatabaseWriter() {
shard_no = task_++; shard_no = task_++;
} }
} }
MSRStatus ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
if (file_names.empty()) {
MS_LOG(ERROR) << "Mindrecord files is empty.";
return FAILED;
}
Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty.");
ShardIndexGenerator sg{file_names[0]}; ShardIndexGenerator sg{file_names[0]};
if (SUCCESS != sg.Build()) {
MS_LOG(ERROR) << "Failed to build index generator.";
return FAILED;
}
if (SUCCESS != sg.WriteToDatabase()) {
MS_LOG(ERROR) << "Failed to write to database.";
return FAILED;
}
return SUCCESS;
RETURN_IF_NOT_OK(sg.Build());
RETURN_IF_NOT_OK(sg.WriteToDatabase());
return Status::OK();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 374
- 549
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
File diff suppressed because it is too large
View File


+ 122
- 181
mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc View File

@@ -30,9 +30,13 @@ namespace mindspore {
namespace mindrecord { namespace mindrecord {
// Default constructor; calls SetAllInIndex(false).
ShardSegment::ShardSegment() {
  SetAllInIndex(false);
}


// GetCategoryFields, shown as a side-by-side diff: removed lines (std::pair<MSRStatus, vector<std::string>>
// version) are followed by their Status-based replacements, and unchanged lines are printed twice on one row.
// Post-change flow: when candidate_category_fields_ is already filled, return a copy of it through
// `fields_ptr`; otherwise run "PRAGMA table_info(INDEXES);" on database_paths_[0], collect column names from
// the result into candidate_category_fields_ and return a copy of that list.
// The "@@" rows are hunk headers: the lines between hunks are not shown.
std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) {
RETURN_UNEXPECTED_IF_NULL(fields_ptr);
// Skip if already populated // Skip if already populated
if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_};
if (!candidate_category_fields_.empty()) {
*fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
return Status::OK();
}


std::string sql = "PRAGMA table_info(INDEXES);"; std::string sql = "PRAGMA table_info(INDEXES);";
std::vector<std::vector<std::string>> field_names; std::vector<std::vector<std::string>> field_names;
@@ -40,11 +44,12 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
char *errmsg = nullptr; char *errmsg = nullptr;
int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
// The message is built before errmsg is freed; the failing connection is closed and its slot cleared.
std::ostringstream oss;
oss << "Error in select statement, sql: " << sql + ", error: " << errmsg;
sqlite3_free(errmsg); sqlite3_free(errmsg);
sqlite3_close(database_paths_[0]); sqlite3_close(database_paths_[0]);
database_paths_[0] = nullptr; database_paths_[0] = nullptr;
return {FAILED, vector<std::string>{}};
RETURN_STATUS_UNEXPECTED(oss.str());
} else { } else {
MS_LOG(INFO) << "Get " << static_cast<int>(field_names.size()) << " records from index."; MS_LOG(INFO) << "Get " << static_cast<int>(field_names.size()) << " records from index.";
} }
@@ -55,53 +60,46 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
sqlite3_free(errmsg); sqlite3_free(errmsg);
sqlite3_close(database_paths_[0]); sqlite3_close(database_paths_[0]);
database_paths_[0] = nullptr; database_paths_[0] = nullptr;
return {FAILED, vector<std::string>{}};
RETURN_STATUS_UNEXPECTED("idx is out of range.");
} }
// Column 1 of each selected table_info row is stored; idx advances by two rows per candidate.
candidate_category_fields_.push_back(field_names[idx][1]); candidate_category_fields_.push_back(field_names[idx][1]);
idx += 2; idx += 2;
} }
sqlite3_free(errmsg); sqlite3_free(errmsg);
return {SUCCESS, candidate_category_fields_};
*fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_);
return Status::OK();
} }


MSRStatus ShardSegment::SetCategoryField(std::string category_field) {
if (GetCategoryFields().first != SUCCESS) {
MS_LOG(ERROR) << "Get candidate category field failed";
return FAILED;
}
Status ShardSegment::SetCategoryField(std::string category_field) {
std::shared_ptr<vector<std::string>> fields_ptr;
RETURN_IF_NOT_OK(GetCategoryFields(&fields_ptr));
category_field = category_field + "_0"; category_field = category_field + "_0";
if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_),
[category_field](std::string x) { return x == category_field; })) { [category_field](std::string x) { return x == category_field; })) {
current_category_field_ = category_field; current_category_field_ = category_field;
return SUCCESS;
return Status::OK();
} }
MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field.";
return FAILED;
RETURN_STATUS_UNEXPECTED("Field " + category_field + " is not a candidate category field.");
} }


std::pair<MSRStatus, std::string> ShardSegment::ReadCategoryInfo() {
Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) {
RETURN_UNEXPECTED_IF_NULL(category_ptr);
MS_LOG(INFO) << "Read category begin"; MS_LOG(INFO) << "Read category begin";
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info failed";
return {FAILED, ""};
}
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
// Convert category info to json string // Convert category info to json string
auto category_json_string = ToJsonForCategory(ret.second);
*category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr));


MS_LOG(INFO) << "Read category end"; MS_LOG(INFO) << "Read category end";


return {SUCCESS, category_json_string};
return Status::OK();
} }


std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegment::WrapCategoryInfo() {
Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) {
RETURN_UNEXPECTED_IF_NULL(category_info_ptr);
std::map<std::string, int> counter; std::map<std::string, int> counter;

if (!ValidateFieldName(current_category_field_)) {
MS_LOG(ERROR) << "category field error from index, it is: " << current_category_field_;
return {FAILED, std::vector<std::tuple<int, std::string, int>>()};
}

CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_),
"Category field error from index, it is: " + current_category_field_);
std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ +
") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";";


@@ -109,13 +107,13 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen
std::vector<std::vector<std::string>> field_count; std::vector<std::vector<std::string>> field_count;


char *errmsg = nullptr; char *errmsg = nullptr;
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg);
if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) {
std::ostringstream oss;
oss << "Error in select statement, sql: " << sql + ", error: " << errmsg;
sqlite3_free(errmsg); sqlite3_free(errmsg);
sqlite3_close(db); sqlite3_close(db);
db = nullptr; db = nullptr;
return {FAILED, std::vector<std::tuple<int, std::string, int>>()};
RETURN_STATUS_UNEXPECTED(oss.str());
} else { } else {
MS_LOG(INFO) << "Get " << static_cast<int>(field_count.size()) << " records from index."; MS_LOG(INFO) << "Get " << static_cast<int>(field_count.size()) << " records from index.";
} }
@@ -127,14 +125,14 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen
} }


int idx = 0; int idx = 0;
std::vector<std::tuple<int, std::string, int>> category_vec(counter.size());
(void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple<std::string, int> item) {
return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item));
});
return {SUCCESS, std::move(category_vec)};
(*category_info_ptr)->resize(counter.size());
(void)std::transform(
counter.begin(), counter.end(), (*category_info_ptr)->begin(),
[&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); });
return Status::OK();
} }


std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec) {
std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) {
std::vector<json> category_json_vec; std::vector<json> category_json_vec;
for (auto q : tri_vec) { for (auto q : tri_vec) {
json j; json j;
@@ -152,27 +150,20 @@ std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, st
return category_info.dump(); return category_info.dump();
} }


std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageById(int64_t category_id,
int64_t page_no,
int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
if (category_id >= static_cast<int>(ret.second.size()) || category_id < 0) {
MS_LOG(ERROR) << "Illegal category id, id: " << category_id;
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
int total_rows_in_category = std::get<2>(ret.second[category_id]);
Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) {
RETURN_UNEXPECTED_IF_NULL(page_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0,
"Invalid category id, id: " + std::to_string(category_id));
int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
// Quit if category not found or page number is out of range // Quit if category not found or page number is out of range
if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 ||
page_no * n_rows_of_page >= total_rows_in_category) {
MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page;
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
page_no * n_rows_of_page < total_rows_in_category,
"Invalid page no / page size, page no: " + std::to_string(page_no) +
", page size: " + std::to_string(n_rows_of_page));


std::vector<std::vector<uint8_t>> page;
auto row_group_summary = ReadRowGroupSummary(); auto row_group_summary = ReadRowGroupSummary();


uint64_t i_start = page_no * n_rows_of_page; uint64_t i_start = page_no * n_rows_of_page;
@@ -183,12 +174,12 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage


auto shard_id = std::get<0>(rg); auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg); auto group_id = std::get<1>(rg);
auto details = ReadRowGroupCriteria(
group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id])));
if (SUCCESS != std::get<0>(details)) {
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
auto offsets = std::get<4>(details);
std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
RETURN_IF_NOT_OK(ReadRowGroupCriteria(
group_id, shard_id,
std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
&row_group_brief_ptr));
auto offsets = std::get<3>(*row_group_brief_ptr);
uint64_t number_of_rows = offsets.size(); uint64_t number_of_rows = offsets.size();
if (idx + number_of_rows < i_start) { if (idx + number_of_rows < i_start) {
idx += number_of_rows; idx += number_of_rows;
@@ -197,131 +188,116 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage


for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) {
if (idx >= i_start && idx < i_end) { if (idx >= i_start && idx < i_end) {
auto ret1 = PackImages(group_id, shard_id, offsets[i]);
if (SUCCESS != ret1.first) {
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
page.push_back(std::move(ret1.second));
auto images_ptr = std::make_shared<std::vector<uint8_t>>();
RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
(*page_ptr)->push_back(std::move(*images_ptr));
} }
} }
} }


return {SUCCESS, std::move(page)};
return Status::OK();
} }


std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id, int shard_id,
std::vector<uint64_t> offset) {
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
if (SUCCESS != ret.first) {
return {FAILED, std::vector<uint8_t>()};
}
const std::shared_ptr<Page> &blob_page = ret.second;

Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset,
std::shared_ptr<std::vector<uint8_t>> *images_ptr) {
RETURN_UNEXPECTED_IF_NULL(images_ptr);
std::shared_ptr<Page> page_ptr;
RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
// Pack image list // Pack image list
std::vector<uint8_t> images(offset[1] - offset[0]);
auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0];
(*images_ptr)->resize(offset[1] - offset[0]);

auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0];
auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
file_streams_random_[0][shard_id]->close(); file_streams_random_[0][shard_id]->close();
return {FAILED, {}};
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
} }


auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&images[0]), offset[1] - offset[0]);
auto &io_read =
file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&((*(*images_ptr))[0])), offset[1] - offset[0]);
if (!io_read.good() || io_read.fail() || io_read.bad()) { if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
file_streams_random_[0][shard_id]->close(); file_streams_random_[0][shard_id]->close();
return {FAILED, {}};
RETURN_STATUS_UNEXPECTED("Failed to read file.");
} }

return {SUCCESS, std::move(images)};
return Status::OK();
} }


std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageByName(std::string category_name,
int64_t page_no,
int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::vector<uint8_t>>{}};
}
for (const auto &categories : ret.second) {
Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) {
RETURN_UNEXPECTED_IF_NULL(pages_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
for (const auto &categories : *category_info_ptr) {
if (std::get<1>(categories) == category_name) { if (std::get<1>(categories) == category_name) {
auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page);
return result;
RETURN_IF_NOT_OK(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr));
return Status::OK();
} }
} }


return {FAILED, std::vector<std::vector<uint8_t>>()};
RETURN_STATUS_UNEXPECTED("Category name can not match.");
} }


std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageById(
int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS || category_id >= static_cast<int>(ret.second.size())) {
MS_LOG(ERROR) << "Illegal category id, id: " << category_id;
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
int total_rows_in_category = std::get<2>(ret.second[category_id]);
// Quit if category not found or page number is out of range
if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) {
MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page;
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr) {
RETURN_UNEXPECTED_IF_NULL(pages_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
// Reject negative ids as well as ids past the end (ReadAtPageById performs the same two-sided check);
// category_id is used below as an index into the category list, so a negative value must not get through.
CHECK_FAIL_RETURN_UNEXPECTED(category_id >= 0 && category_id < static_cast<int64_t>(category_info_ptr->size()),
"Invalid category id: " + std::to_string(category_id));


std::vector<std::tuple<std::vector<uint8_t>, json>> page;
int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
// Quit if category not found or page number is out of range
CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
page_no * n_rows_of_page < total_rows_in_category,
"Invalid page no / page size / total size, page no: " + std::to_string(page_no) +
", page size of page: " + std::to_string(n_rows_of_page) +
", total size: " + std::to_string(total_rows_in_category));
auto row_group_summary = ReadRowGroupSummary(); auto row_group_summary = ReadRowGroupSummary();


int i_start = page_no * n_rows_of_page; int i_start = page_no * n_rows_of_page;
int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page); int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page);
int idx = 0; int idx = 0;
for (const auto &rg : row_group_summary) { for (const auto &rg : row_group_summary) {
if (idx >= i_end) break;
if (idx >= i_end) {
break;
}


auto shard_id = std::get<0>(rg); auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg); auto group_id = std::get<1>(rg);
auto details = ReadRowGroupCriteria(
group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id])));
if (SUCCESS != std::get<0>(details)) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
auto offsets = std::get<4>(details);
auto labels = std::get<5>(details);
std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr;
RETURN_IF_NOT_OK(ReadRowGroupCriteria(
group_id, shard_id,
std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""},
&row_group_brief_ptr));
auto offsets = std::get<3>(*row_group_brief_ptr);
auto labels = std::get<4>(*row_group_brief_ptr);


int number_of_rows = offsets.size(); int number_of_rows = offsets.size();
if (idx + number_of_rows < i_start) { if (idx + number_of_rows < i_start) {
idx += number_of_rows; idx += number_of_rows;
continue; continue;
} }

if (number_of_rows > static_cast<int>(labels.size())) {
MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows;
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
// Every offset needs a matching label. The count is converted with std::to_string because adding an int
// directly to a string literal is pointer arithmetic, not concatenation.
CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast<int>(labels.size()),
"Invalid row number of page: " + std::to_string(number_of_rows));
for (int i = 0; i < number_of_rows; ++i, ++idx) { for (int i = 0; i < number_of_rows; ++i, ++idx) {
if (idx >= i_start && idx < i_end) { if (idx >= i_start && idx < i_end) {
auto ret1 = PackImages(group_id, shard_id, offsets[i]);
if (SUCCESS != ret1.first) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}
page.emplace_back(std::move(ret1.second), std::move(labels[i]));
auto images_ptr = std::make_shared<std::vector<uint8_t>>();
RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr));
(*pages_ptr)->emplace_back(std::move(*images_ptr), std::move(labels[i]));
} }
} }
} }
return {SUCCESS, std::move(page)};
return Status::OK();
} }


std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageByName(
std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
auto ret = WrapCategoryInfo();
if (ret.first != SUCCESS) {
MS_LOG(ERROR) << "Get category info";
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}

Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page,
std::shared_ptr<PAGES> *pages_ptr) {
RETURN_UNEXPECTED_IF_NULL(pages_ptr);
auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
// category_name to category_id // category_name to category_id
int64_t category_id = -1; int64_t category_id = -1;
for (const auto &categories : ret.second) {
for (const auto &categories : *category_info_ptr) {
std::string categories_name = std::get<1>(categories); std::string categories_name = std::get<1>(categories);


if (categories_name == category_name) { if (categories_name == category_name) {
@@ -329,45 +305,8 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
break; break;
} }
} }

if (category_id == -1) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
}

return ReadAllAtPageById(category_id, page_no, n_rows_of_page);
}

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy(
int64_t category_id, int64_t page_no, int64_t n_rows_of_page) {
auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page);
if (res.first != SUCCESS) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}};
}

vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data;
std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
});
return {SUCCESS, std::move(json_data)};
}

std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByNamePy(
std::string category_name, int64_t page_no, int64_t n_rows_of_page) {
auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page);
if (res.first != SUCCESS) {
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}};
}
vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data;
std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
});
return {SUCCESS, std::move(json_data)};
CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "Invalid category name.");
return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr);
} }


std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
@@ -382,7 +321,9 @@ std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
} }


std::string ShardSegment::CleanUp(std::string field_name) { std::string ShardSegment::CleanUp(std::string field_name) {
while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back();
while (field_name.back() >= '0' && field_name.back() <= '9') {
field_name.pop_back();
}
field_name.pop_back(); field_name.pop_back();
return field_name; return field_name;
} }


+ 272
- 496
mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc
File diff suppressed because it is too large
View File


+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc View File

@@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
num_categories_(num_categories), num_categories_(num_categories),
replacement_(replacement) {} replacement_(replacement) {}


MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
Status ShardCategory::Execute(ShardTaskList &tasks) { return Status::OK(); }


int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size; if (dataset_size == 0) return dataset_size;


+ 74
- 100
mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc View File

@@ -72,36 +72,36 @@ void ShardColumn::Init(const json &schema_json, bool compress_integer) {
num_blob_column_ = blob_column_.size(); num_blob_column_ = blob_column_.size();
} }


std::pair<MSRStatus, ColumnCategory> ShardColumn::GetColumnTypeByName(const std::string &column_name,
ColumnDataType *column_data_type,
uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type,
uint64_t *column_data_type_size, std::vector<int64_t> *column_shape,
ColumnCategory *column_category) {
RETURN_UNEXPECTED_IF_NULL(column_data_type);
RETURN_UNEXPECTED_IF_NULL(column_data_type_size);
RETURN_UNEXPECTED_IF_NULL(column_shape);
RETURN_UNEXPECTED_IF_NULL(column_category);
// Skip if column not found // Skip if column not found
auto column_category = CheckColumnName(column_name);
if (column_category == ColumnNotFound) {
return {FAILED, ColumnNotFound};
}
*column_category = CheckColumnName(column_name);
CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid column category.");


// Get data type and size // Get data type and size
auto column_id = column_name_id_[column_name]; auto column_id = column_name_id_[column_name];
*column_data_type = column_data_type_[column_id]; *column_data_type = column_data_type_[column_id];
*column_data_type_size = ColumnDataTypeSize[*column_data_type]; *column_data_type_size = ColumnDataTypeSize[*column_data_type];
*column_shape = column_shape_[column_id]; *column_shape = column_shape_[column_id];

return {SUCCESS, column_category};
return Status::OK();
} }


MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
Status ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape) {
RETURN_UNEXPECTED_IF_NULL(column_data_type);
RETURN_UNEXPECTED_IF_NULL(column_data_type_size);
RETURN_UNEXPECTED_IF_NULL(column_shape);
// Skip if column not found // Skip if column not found
auto column_category = CheckColumnName(column_name); auto column_category = CheckColumnName(column_name);
if (column_category == ColumnNotFound) {
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid column category.");
// Get data type and size // Get data type and size
auto column_id = column_name_id_[column_name]; auto column_id = column_name_id_[column_name];
*column_data_type = column_data_type_[column_id]; *column_data_type = column_data_type_[column_id];
@@ -110,37 +110,31 @@ MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, cons


// Retrieve value from json // Retrieve value from json
if (column_category == ColumnInRaw) { if (column_category == ColumnInRaw) {
if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << ".";
return FAILED;
}
RETURN_IF_NOT_OK(GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes));
*data = reinterpret_cast<const unsigned char *>(data_ptr->get()); *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
return SUCCESS;
return Status::OK();
} }


// Retrieve value from blob // Retrieve value from blob
if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << ".";
return FAILED;
}
RETURN_IF_NOT_OK(GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes));
if (*data == nullptr) { if (*data == nullptr) {
*data = reinterpret_cast<const unsigned char *>(data_ptr->get()); *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
} }
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
RETURN_UNEXPECTED_IF_NULL(n_bytes);
RETURN_UNEXPECTED_IF_NULL(data_ptr);
auto column_id = column_name_id_[column_name]; auto column_id = column_name_id_[column_name];
auto column_data_type = column_data_type_[column_id]; auto column_data_type = column_data_type_[column_id];


// Initialize num bytes // Initialize num bytes
*n_bytes = ColumnDataTypeSize[column_data_type]; *n_bytes = ColumnDataTypeSize[column_data_type];
auto json_column_value = columns_json[column_name]; auto json_column_value = columns_json[column_name];
if (!json_column_value.is_string() && !json_column_value.is_number()) {
MS_LOG(ERROR) << "Conversion failed (" << json_column_value << ").";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(json_column_value.is_string() || json_column_value.is_number(),
"Conversion to string or number failed (" + json_column_value.dump() + ").");
switch (column_data_type) { switch (column_data_type) {
case ColumnFloat32: { case ColumnFloat32: {
return GetFloat<float>(data_ptr, json_column_value, false); return GetFloat<float>(data_ptr, json_column_value, false);
@@ -171,12 +165,13 @@ MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const j
break; break;
} }
} }
return SUCCESS;
return Status::OK();
} }


template <typename T> template <typename T>
MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
bool use_double) {
Status ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
bool use_double) {
RETURN_UNEXPECTED_IF_NULL(data_ptr);
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
if (json_column_value.is_number()) { if (json_column_value.is_number()) {
array_data[0] = json_column_value; array_data[0] = json_column_value;
@@ -189,8 +184,7 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons
array_data[0] = json_column_value.get<float>(); array_data[0] = json_column_value.get<float>();
} }
} catch (json::exception &e) { } catch (json::exception &e) {
MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
return FAILED;
RETURN_STATUS_UNEXPECTED("Conversion to float failed (" + json_column_value.dump() + ").");
} }
} }


@@ -199,54 +193,43 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons
for (uint32_t i = 0; i < sizeof(T); i++) { for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i); (*data_ptr)[i] = *(data + i);
} }

return SUCCESS;
return Status::OK();
} }


template <typename T> template <typename T>
MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
Status ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
RETURN_UNEXPECTED_IF_NULL(data_ptr);
std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
int64_t temp_value; int64_t temp_value;
bool less_than_zero = false; bool less_than_zero = false;


if (json_column_value.is_number_integer()) { if (json_column_value.is_number_integer()) {
const json json_zero = 0; const json json_zero = 0;
if (json_column_value < json_zero) less_than_zero = true;
if (json_column_value < json_zero) {
less_than_zero = true;
}
temp_value = json_column_value; temp_value = json_column_value;
} else if (json_column_value.is_string()) { } else if (json_column_value.is_string()) {
std::string string_value = json_column_value; std::string string_value = json_column_value;

if (!string_value.empty() && string_value[0] == '-') {
try {
try {
if (!string_value.empty() && string_value[0] == '-') {
temp_value = std::stoll(string_value); temp_value = std::stoll(string_value);
less_than_zero = true; less_than_zero = true;
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
}
} else {
try {
} else {
temp_value = static_cast<int64_t>(std::stoull(string_value)); temp_value = static_cast<int64_t>(std::stoull(string_value));
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
return FAILED;
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Conversion to int failed, out of range.";
return FAILED;
} }
} catch (std::invalid_argument &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what()));
} catch (std::out_of_range &e) {
RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what()));
} }
} else { } else {
MS_LOG(ERROR) << "Conversion to int failed.";
return FAILED;
RETURN_STATUS_UNEXPECTED("Conversion to int failed.");
} }


if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) || if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) { (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
MS_LOG(ERROR) << "Conversion to int failed. Out of range";
return FAILED;
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
} }
array_data[0] = static_cast<T>(temp_value); array_data[0] = static_cast<T>(temp_value);


@@ -255,33 +238,26 @@ MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const
for (uint32_t i = 0; i < sizeof(T); i++) { for (uint32_t i = 0; i < sizeof(T); i++) {
(*data_ptr)[i] = *(data + i); (*data_ptr)[i] = *(data + i);
} }

return SUCCESS;
return Status::OK();
} }


MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes) {
Status ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes) {
RETURN_UNEXPECTED_IF_NULL(data);
uint64_t offset_address = 0; uint64_t offset_address = 0;
auto column_id = column_name_id_[column_name]; auto column_id = column_name_id_[column_name];
if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) {
return FAILED;
}

RETURN_IF_NOT_OK(GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address));
auto column_data_type = column_data_type_[column_id]; auto column_data_type = column_data_type_[column_id];
if (has_compress_blob_ && column_data_type == ColumnInt32) { if (has_compress_blob_ && column_data_type == ColumnInt32) {
if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address));
} else if (has_compress_blob_ && column_data_type == ColumnInt64) { } else if (has_compress_blob_ && column_data_type == ColumnInt64) {
if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address));
} else { } else {
*data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address])); *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
} }


return SUCCESS;
return Status::OK();
} }


ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
@@ -296,7 +272,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) { std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) {
// Skip if no compress columns // Skip if no compress columns
*compression_size = 0; *compression_size = 0;
if (!CheckCompressBlob()) return blob;
if (!CheckCompressBlob()) {
return blob;
}


std::vector<uint8_t> dst_blob; std::vector<uint8_t> dst_blob;
uint64_t i_src = 0; uint64_t i_src = 0;
@@ -380,12 +358,14 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
return dst_bytes; return dst_bytes;
} }


MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx) {
Status ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx) {
RETURN_UNEXPECTED_IF_NULL(num_bytes);
RETURN_UNEXPECTED_IF_NULL(shift_idx);
if (num_blob_column_ == 1) { if (num_blob_column_ == 1) {
*num_bytes = columns_blob.size(); *num_bytes = columns_blob.size();
*shift_idx = 0; *shift_idx = 0;
return SUCCESS;
return Status::OK();
} }
auto blob_id = blob_column_id_[column_name_[column_id]]; auto blob_id = blob_column_id_[column_name_[column_id]];


@@ -396,13 +376,14 @@ MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const


(*shift_idx) += kInt64Len; (*shift_idx) += kInt64Len;


return SUCCESS;
return Status::OK();
} }


template <typename T> template <typename T>
MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes,
uint64_t shift_idx) {
Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx) {
RETURN_UNEXPECTED_IF_NULL(data_ptr);
RETURN_UNEXPECTED_IF_NULL(num_bytes);
auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
*num_bytes = sizeof(T) * num_elements; *num_bytes = sizeof(T) * num_elements;


@@ -421,19 +402,12 @@ MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<


auto data = reinterpret_cast<const unsigned char *>(array_data.get()); auto data = reinterpret_cast<const unsigned char *>(array_data.get());
*data_ptr = std::make_unique<unsigned char[]>(*num_bytes); *data_ptr = std::make_unique<unsigned char[]>(*num_bytes);

// field is none. for example: numpy is null // field is none. for example: numpy is null
if (*num_bytes == 0) { if (*num_bytes == 0) {
return SUCCESS;
return Status::OK();
} }

int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes);
if (ret_code != 0) {
MS_LOG(ERROR) << "Failed to copy data!";
return FAILED;
}

return SUCCESS;
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data!");
return Status::OK();
} }


uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,


+ 5
- 11
mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc View File

@@ -55,15 +55,11 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
return 0; return 0;
} }


MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
Status ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
auto total_no = tasks.Size(); auto total_no = tasks.Size();
if (no_of_padded_samples_ > 0 && first_epoch_) { if (no_of_padded_samples_ > 0 && first_epoch_) {
if (total_no % denominator_ != 0) {
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. "
<< "task size: " << total_no << ", number padded: " << no_of_padded_samples_
<< ", denominator: " << denominator_;
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(total_no % denominator_ == 0,
"Dataset size plus number of padded samples is not divisible by number of shards.");
} }
if (first_epoch_) { if (first_epoch_) {
first_epoch_ = false; first_epoch_ = false;
@@ -74,11 +70,9 @@ MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
if (shuffle_ == true) { if (shuffle_ == true) {
shuffle_op_->SetShardSampleCount(GetShardSampleCount()); shuffle_op_->SetShardSampleCount(GetShardSampleCount());
shuffle_op_->UpdateShuffleMode(GetShuffleMode()); shuffle_op_->UpdateShuffleMode(GetShuffleMode());
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
}
RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
} }
return SUCCESS;
return Status::OK();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 166
- 316
mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc View File

@@ -38,104 +38,74 @@ ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), co
index_ = std::make_shared<Index>(); index_ = std::make_shared<Index>();
} }


MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
shard_count_ = headers.size(); shard_count_ = headers.size();
int shard_index = 0; int shard_index = 0;
bool first = true; bool first = true;
for (const auto &header : headers) { for (const auto &header : headers) {
if (first) { if (first) {
first = false; first = false;
if (ParseSchema(header["schema"]) != SUCCESS) {
return FAILED;
}
if (ParseIndexFields(header["index_fields"]) != SUCCESS) {
return FAILED;
}
if (ParseStatistics(header["statistics"]) != SUCCESS) {
return FAILED;
}
RETURN_IF_NOT_OK(ParseSchema(header["schema"]));
RETURN_IF_NOT_OK(ParseIndexFields(header["index_fields"]));
RETURN_IF_NOT_OK(ParseStatistics(header["statistics"]));
ParseShardAddress(header["shard_addresses"]); ParseShardAddress(header["shard_addresses"]);
header_size_ = header["header_size"].get<uint64_t>(); header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>(); page_size_ = header["page_size"].get<uint64_t>();
compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
} }
if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
return FAILED;
}
RETURN_IF_NOT_OK(ParsePage(header["page"], shard_index, load_dataset));
shard_index++; shard_index++;
} }
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardHeader::CheckFileStatus(const std::string &path) {
Status ShardHeader::CheckFileStatus(const std::string &path) {
auto realpath = Common::GetRealPath(path); auto realpath = Common::GetRealPath(path);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << path;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path: " + path);
std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary); std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
if (!fin) {
MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path;
return FAILED;
}
if (fin.fail()) {
MS_LOG(ERROR) << "Failed to open file. path: " << path;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file. path: " + path);
// fetch file size // fetch file size
auto &io_seekg = fin.seekg(0, std::ios::end); auto &io_seekg = fin.seekg(0, std::ios::end);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
fin.close(); fin.close();
MS_LOG(ERROR) << "File seekg failed. path: " << path;
return FAILED;
RETURN_STATUS_UNEXPECTED("File seekg failed. path: " + path);
} }


size_t file_size = fin.tellg(); size_t file_size = fin.tellg();
if (file_size < kMinFileSize) { if (file_size < kMinFileSize) {
fin.close(); fin.close();
MS_LOG(ERROR) << "Invalid file. path: " << path;
return FAILED;
RETURN_STATUS_UNEXPECTED("Invalid file. path: " + path);
} }
fin.close(); fin.close();
return SUCCESS;
return Status::OK();
} }


std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path) {
if (CheckFileStatus(path) != SUCCESS) {
return {FAILED, {}};
}

Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) {
RETURN_UNEXPECTED_IF_NULL(header_ptr);
RETURN_IF_NOT_OK(CheckFileStatus(path));
// read header size // read header size
json json_header; json json_header;
std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary);
if (!fin.is_open()) {
MS_LOG(ERROR) << "File seekg failed. path: " << path;
return {FAILED, json_header};
}
CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file. path: " + path);


uint64_t header_size = 0; uint64_t header_size = 0;
auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len); auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len);
if (!io_read.good() || io_read.fail() || io_read.bad()) { if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
fin.close(); fin.close();
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("File read failed");
} }


if (header_size > kMaxHeaderSize) { if (header_size > kMaxHeaderSize) {
fin.close(); fin.close();
MS_LOG(ERROR) << "Invalid file content. path: " << path;
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("Invalid file content. path: " + path);
} }


// read header content // read header content
std::vector<uint8_t> header_content(header_size); std::vector<uint8_t> header_content(header_size);
auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size); auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size);
if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) {
MS_LOG(ERROR) << "File read failed. path: " << path;
fin.close(); fin.close();
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("File read failed. path: " + path);
} }


fin.close(); fin.close();
@@ -144,34 +114,35 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
try { try {
json_header = json::parse(raw_header_content); json_header = json::parse(raw_header_content);
} catch (json::parse_error &e) { } catch (json::parse_error &e) {
MS_LOG(ERROR) << "Json parse error: " << e.what();
return {FAILED, json_header};
RETURN_STATUS_UNEXPECTED("Json parse error: " + std::string(e.what()));
} }
return {SUCCESS, json_header};
*header_ptr = std::make_shared<json>(json_header);
return Status::OK();
} }


std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) {
auto ret = ValidateHeader(file_path);
if (SUCCESS != ret.first) {
return {FAILED, json()};
}
json raw_header = ret.second;
Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) {
RETURN_UNEXPECTED_IF_NULL(header_ptr);
std::shared_ptr<json> raw_header;
RETURN_IF_NOT_OK(ValidateHeader(file_path, &raw_header));
uint64_t compression_size = uint64_t compression_size =
raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0;
json header = {{"shard_addresses", raw_header["shard_addresses"]},
{"header_size", raw_header["header_size"]},
{"page_size", raw_header["page_size"]},
raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0;
json header = {{"shard_addresses", (*raw_header)["shard_addresses"]},
{"header_size", (*raw_header)["header_size"]},
{"page_size", (*raw_header)["page_size"]},
{"compression_size", compression_size}, {"compression_size", compression_size},
{"index_fields", raw_header["index_fields"]},
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
{"schema", raw_header["schema"][0]["schema"]},
{"version", raw_header["version"]}};
return {SUCCESS, header};
{"index_fields", (*raw_header)["index_fields"]},
{"blob_fields", (*raw_header)["schema"][0]["blob_fields"]},
{"schema", (*raw_header)["schema"][0]["schema"]},
{"version", (*raw_header)["version"]}};
*header_ptr = std::make_shared<json>(header);
return Status::OK();
} }


MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
uint32_t thread_num = std::thread::hardware_concurrency(); uint32_t thread_num = std::thread::hardware_concurrency();
if (thread_num == 0) thread_num = kThreadNumber;
if (thread_num == 0) {
thread_num = kThreadNumber;
}
uint32_t work_thread_num = 0; uint32_t work_thread_num = 0;
uint32_t shard_count = file_paths.size(); uint32_t shard_count = file_paths.size();
int group_num = ceil(shard_count * 1.0 / thread_num); int group_num = ceil(shard_count * 1.0 / thread_num);
@@ -194,12 +165,10 @@ MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths,
} }
if (thread_status) { if (thread_status) {
thread_status = false; thread_status = false;
return FAILED;
RETURN_STATUS_UNEXPECTED("Error occurred in GetHeadersOneTask thread.");
} }
if (SUCCESS != InitializeHeader(headers, load_dataset)) {
return FAILED;
}
return SUCCESS;
RETURN_IF_NOT_OK(InitializeHeader(headers, load_dataset));
return Status::OK();
} }


void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers, void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers,
@@ -208,48 +177,39 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
return; return;
} }
for (int x = start; x < end; ++x) { for (int x = start; x < end; ++x) {
auto ret = ValidateHeader(realAddresses[x]);
if (SUCCESS != ret.first) {
std::shared_ptr<json> header;
if (ValidateHeader(realAddresses[x], &header).IsError()) {
thread_status = true; thread_status = true;
return; return;
} }
json header;
header = ret.second;
header["shard_addresses"] = realAddresses;
if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump()
(*header)["shard_addresses"] = realAddresses;
if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) ==
kSupportedVersion.end()) {
MS_LOG(ERROR) << "Version wrong, file version is: " << (*header)["version"].dump()
<< ", lib version is: " << kVersion; << ", lib version is: " << kVersion;
thread_status = true; thread_status = true;
return; return;
} }
headers[x] = header;
headers[x] = *header;
} }
} }


MSRStatus ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
std::vector<std::string> file_names(file_paths.size()); std::vector<std::string> file_names(file_paths.size());
std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string {
if (GetFileName(fp).first == SUCCESS) {
return GetFileName(fp).second;
}
std::shared_ptr<std::string> fn;
return GetFileName(fp, &fn).IsOk() ? *fn : "";
}); });


shard_addresses_ = std::move(file_names); shard_addresses_ = std::move(file_names);
shard_count_ = file_paths.size(); shard_count_ = file_paths.size();
if (shard_count_ == 0) {
return FAILED;
}
if (shard_count_ <= kMaxShardCount) {
pages_.resize(shard_count_);
} else {
return FAILED;
}
return SUCCESS;
CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount),
"shard count is invalid. shard count: " + std::to_string(shard_count_));
pages_.resize(shard_count_);
return Status::OK();
} }


void ShardHeader::ParseHeader(const json &header) {}

MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
Status ShardHeader::ParseIndexFields(const json &index_fields) {
std::vector<std::pair<uint64_t, std::string>> parsed_index_fields; std::vector<std::pair<uint64_t, std::string>> parsed_index_fields;
for (auto &index_field : index_fields) { for (auto &index_field : index_fields) {
auto schema_id = index_field["schema_id"].get<uint64_t>(); auto schema_id = index_field["schema_id"].get<uint64_t>();
@@ -257,18 +217,15 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name); std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name);
parsed_index_fields.push_back(parsed_index_field); parsed_index_fields.push_back(parsed_index_field);
} }
if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) {
return FAILED;
}
return SUCCESS;
RETURN_IF_NOT_OK(AddIndexFields(parsed_index_fields));
return Status::OK();
} }


MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
// set shard_index when load_dataset is false // set shard_index when load_dataset is false
if (shard_count_ > kMaxFileCount) {
MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount;
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(
shard_count_ <= kMaxFileCount,
"The number of mindrecord files is greater than max value: " + std::to_string(kMaxFileCount));
if (pages_.empty() && shard_count_ <= kMaxFileCount) { if (pages_.empty() && shard_count_ <= kMaxFileCount) {
pages_.resize(shard_count_); pages_.resize(shard_count_);
} }
@@ -295,44 +252,37 @@ MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_d
pages_[shard_index].push_back(std::move(parsed_page)); pages_[shard_index].push_back(std::move(parsed_page));
} }
} }
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardHeader::ParseStatistics(const json &statistics) {
Status ShardHeader::ParseStatistics(const json &statistics) {
for (auto &statistic : statistics) { for (auto &statistic : statistics) {
if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) {
MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump();
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(
statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(),
"Deserialize statistics failed, statistic: " + statistics.dump());
std::string statistic_description = statistic["desc"].get<std::string>(); std::string statistic_description = statistic["desc"].get<std::string>();
json statistic_body = statistic["statistics"]; json statistic_body = statistic["statistics"];
std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body); std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
if (!parsed_statistic) {
return FAILED;
}
RETURN_UNEXPECTED_IF_NULL(parsed_statistic);
AddStatistic(parsed_statistic); AddStatistic(parsed_statistic);
} }
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardHeader::ParseSchema(const json &schemas) {
Status ShardHeader::ParseSchema(const json &schemas) {
for (auto &schema : schemas) { for (auto &schema : schemas) {
// change how we get schemaBody once design is finalized // change how we get schemaBody once design is finalized
if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() ||
schema.find("schema") == schema.end()) {
MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump();
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() &&
schema.find("schema") != schema.end(),
"Deserialize schema failed. schema: " + schema.dump());
std::string schema_description = schema["desc"].get<std::string>(); std::string schema_description = schema["desc"].get<std::string>();
std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>(); std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
json schema_body = schema["schema"]; json schema_body = schema["schema"];
std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body); std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body);
if (!parsed_schema) {
return FAILED;
}
RETURN_UNEXPECTED_IF_NULL(parsed_schema);
AddSchema(parsed_schema); AddSchema(parsed_schema);
} }
return SUCCESS;
return Status::OK();
} }


void ShardHeader::ParseShardAddress(const json &address) { void ShardHeader::ParseShardAddress(const json &address) {
@@ -340,7 +290,7 @@ void ShardHeader::ParseShardAddress(const json &address) {
} }


std::vector<std::string> ShardHeader::SerializeHeader() { std::vector<std::string> ShardHeader::SerializeHeader() {
std::vector<string> header;
std::vector<std::string> header;
auto index = SerializeIndexFields(); auto index = SerializeIndexFields();
auto stats = SerializeStatistics(); auto stats = SerializeStatistics();
auto schema = SerializeSchema(); auto schema = SerializeSchema();
@@ -406,45 +356,42 @@ std::string ShardHeader::SerializeSchema() {


std::string ShardHeader::SerializeShardAddress() { std::string ShardHeader::SerializeShardAddress() {
json j; json j;
(void)std::transform(shard_addresses_.begin(), shard_addresses_.end(), std::back_inserter(j),
[](const std::string &addr) { return GetFileName(addr).second; });
std::shared_ptr<std::string> fn_ptr;
for (const auto &addr : shard_addresses_) {
(void)GetFileName(addr, &fn_ptr);
j.emplace_back(*fn_ptr);
}
return j.dump(); return j.dump();
} }


std::pair<std::shared_ptr<Page>, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) {
Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) {
RETURN_UNEXPECTED_IF_NULL(page_ptr);
if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
return std::make_pair(pages_[shard_id][page_id], SUCCESS);
} else {
return std::make_pair(nullptr, FAILED);
*page_ptr = pages_[shard_id][page_id];
return Status::OK();
} }
page_ptr = nullptr;
RETURN_STATUS_UNEXPECTED("Failed to Get Page.");
} }


MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
if (new_page == nullptr) {
return FAILED;
}
Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
int shard_id = new_page->GetShardID(); int shard_id = new_page->GetShardID();
int page_id = new_page->GetPageID(); int page_id = new_page->GetPageID();
if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
pages_[shard_id][page_id] = new_page; pages_[shard_id][page_id] = new_page;
return SUCCESS;
} else {
return FAILED;
return Status::OK();
} }
RETURN_STATUS_UNEXPECTED("Failed to Set Page.");
} }


MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
if (new_page == nullptr) {
return FAILED;
}
Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
int shard_id = new_page->GetShardID(); int shard_id = new_page->GetShardID();
int page_id = new_page->GetPageID(); int page_id = new_page->GetPageID();
if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) { if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
pages_[shard_id].push_back(new_page); pages_[shard_id].push_back(new_page);
return SUCCESS;
} else {
return FAILED;
return Status::OK();
} }
RETURN_STATUS_UNEXPECTED("Failed to Add Page.");
} }


int64_t ShardHeader::GetLastPageId(const int &shard_id) { int64_t ShardHeader::GetLastPageId(const int &shard_id) {
@@ -468,20 +415,18 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag
return last_page_id; return last_page_id;
} }


const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(const int &group_id,
const int &shard_id) {
if (shard_id >= static_cast<int>(pages_.size())) {
MS_LOG(ERROR) << "Shard id is more than sum of shards.";
return {FAILED, nullptr};
}
Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) {
RETURN_UNEXPECTED_IF_NULL(page_ptr);
CHECK_FAIL_RETURN_UNEXPECTED(shard_id < static_cast<int>(pages_.size()), "Shard id is more than sum of shards.");
for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
auto page = pages_[shard_id][i - 1]; auto page = pages_[shard_id][i - 1];
if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
return {SUCCESS, page};
*page_ptr = std::make_shared<Page>(*page);
return Status::OK();
} }
} }
MS_LOG(ERROR) << "Could not get page by group id " << group_id;
return {FAILED, nullptr};
page_ptr = nullptr;
RETURN_STATUS_UNEXPECTED("Failed to get page by group id: " + group_id);
} }


int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
@@ -524,151 +469,88 @@ std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
return index; return index;
} }


MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
// check field name is or is not valid // check field name is or is not valid
if (schema.find(field) == schema.end()) {
MS_LOG(ERROR) << "Schema do not contain the field: " << field << ".";
return FAILED;
}

if (schema[field]["type"] == "bytes") {
MS_LOG(ERROR) << field << " is bytes type, can not be schema index field.";
return FAILED;
}

if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) {
MS_LOG(ERROR) << field << " array can not be schema index field.";
return FAILED;
}
return SUCCESS;
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), "Filed can not found in schema.");
CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes", "bytes can not be as index field.");
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(),
"array can not be as index field.");
return Status::OK();
} }


MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
if (fields.empty()) {
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Schema is empty.");
// create index Object // create index Object
std::shared_ptr<Index> index = InitIndexPtr(); std::shared_ptr<Index> index = InitIndexPtr();

if (fields.size() == kInt0) {
MS_LOG(ERROR) << "There are no index fields";
return FAILED;
}

if (GetSchemas().empty()) {
MS_LOG(ERROR) << "No schema is set";
return FAILED;
}

for (const auto &schemaPtr : schema_) { for (const auto &schemaPtr : schema_) {
auto result = GetSchemaByID(schemaPtr->GetSchemaID());
if (result.second != SUCCESS) {
MS_LOG(ERROR) << "Could not get schema by id.";
return FAILED;
}

if (result.first == nullptr) {
MS_LOG(ERROR) << "Could not get schema by id.";
return FAILED;
}

json schema = result.first->GetSchema().at("schema");

std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr));
json schema = schema_ptr->GetSchema().at("schema");
// checkout and add fields for each schema // checkout and add fields for each schema
std::set<std::string> field_set; std::set<std::string> field_set;
for (const auto &item : index->GetFields()) { for (const auto &item : index->GetFields()) {
field_set.insert(item.second); field_set.insert(item.second);
} }
for (const auto &field : fields) { for (const auto &field : fields) {
if (field_set.find(field) != field_set.end()) {
MS_LOG(ERROR) << "Add same index field twice";
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice.");
// check field name is or is not valid // check field name is or is not valid
if (CheckIndexField(field, schema) == FAILED) {
return FAILED;
}
RETURN_IF_NOT_OK(CheckIndexField(field, schema));
field_set.insert(field); field_set.insert(field);

// add field into index // add field into index
index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
} }
} }

index_ = index; index_ = index;
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
// get all schema id // get all schema id
for (const auto &schema : schema_) { for (const auto &schema : schema_) {
auto bucket_it = bucket_count.find(schema->GetSchemaID()); auto bucket_it = bucket_count.find(schema->GetSchemaID());
if (bucket_it != bucket_count.end()) {
MS_LOG(ERROR) << "Schema duplication";
return FAILED;
} else {
bucket_count.insert(schema->GetSchemaID());
}
CHECK_FAIL_RETURN_UNEXPECTED(bucket_it == bucket_count.end(), "Schema duplication.");
bucket_count.insert(schema->GetSchemaID());
} }
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
if (fields.empty()) {
return Status::OK();
}
// create index Object // create index Object
std::shared_ptr<Index> index = InitIndexPtr(); std::shared_ptr<Index> index = InitIndexPtr();

if (fields.size() == kInt0) {
MS_LOG(ERROR) << "There are no index fields";
return FAILED;
}

// get all schema id // get all schema id
std::set<uint64_t> bucket_count; std::set<uint64_t> bucket_count;
if (GetAllSchemaID(bucket_count) != SUCCESS) {
return FAILED;
}

RETURN_IF_NOT_OK(GetAllSchemaID(bucket_count));
// check and add fields for each schema // check and add fields for each schema
std::set<std::pair<uint64_t, std::string>> field_set; std::set<std::pair<uint64_t, std::string>> field_set;
for (const auto &item : index->GetFields()) { for (const auto &item : index->GetFields()) {
field_set.insert(item); field_set.insert(item);
} }
for (const auto &field : fields) { for (const auto &field : fields) {
if (field_set.find(field) != field_set.end()) {
MS_LOG(ERROR) << "Add same index field twice";
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice.");
uint64_t schema_id = field.first; uint64_t schema_id = field.first;
std::string field_name = field.second; std::string field_name = field.second;


// check schemaId is or is not valid // check schemaId is or is not valid
if (bucket_count.find(schema_id) == bucket_count.end()) {
MS_LOG(ERROR) << "Illegal schema id: " << schema_id;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(), "Invalid schema id: " + schema_id);
// check field name is or is not valid // check field name is or is not valid
auto result = GetSchemaByID(schema_id);
if (result.second != SUCCESS) {
MS_LOG(ERROR) << "Could not get schema by id.";
return FAILED;
}
json schema = result.first->GetSchema().at("schema");
if (schema.find(field_name) == schema.end()) {
MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name;
return FAILED;
}

if (CheckIndexField(field_name, schema) == FAILED) {
return FAILED;
}

std::shared_ptr<Schema> schema_ptr;
RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr));
json schema = schema_ptr->GetSchema().at("schema");
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(),
"Schema " + std::to_string(schema_id) + " do not contain the field: " + field_name);
RETURN_IF_NOT_OK(CheckIndexField(field_name, schema));
field_set.insert(field); field_set.insert(field);

// add field into index // add field into index
index.get()->AddIndexField(schema_id, field_name);
index->AddIndexField(schema_id, field_name);
} }
index_ = index; index_ = index;
return SUCCESS;
return Status::OK();
} }


std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
@@ -686,103 +568,71 @@ std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return


std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; } std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }


std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) {
Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) {
RETURN_UNEXPECTED_IF_NULL(schema_ptr);
int64_t schemaSize = schema_.size(); int64_t schemaSize = schema_.size();
if (schema_id < 0 || schema_id >= schemaSize) {
MS_LOG(ERROR) << "Illegal schema id";
return std::make_pair(nullptr, FAILED);
}
return std::make_pair(schema_.at(schema_id), SUCCESS);
CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schemaSize, "schema id is invalid.");
*schema_ptr = schema_.at(schema_id);
return Status::OK();
} }


std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) {
Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) {
RETURN_UNEXPECTED_IF_NULL(statistics_ptr);
int64_t statistics_size = statistics_.size(); int64_t statistics_size = statistics_.size();
if (statistic_id < 0 || statistic_id >= statistics_size) {
return std::make_pair(nullptr, FAILED);
}
return std::make_pair(statistics_.at(statistic_id), SUCCESS);
CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size, "statistic id is invalid.");
*statistics_ptr = statistics_.at(statistic_id);
return Status::OK();
} }


MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) {
Status ShardHeader::PagesToFile(const std::string dump_file_name) {
auto realpath = Common::GetRealPath(dump_file_name); auto realpath = Common::GetRealPath(dump_file_name);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name);
// write header content to file, dump whatever is in the file before // write header content to file, dump whatever is in the file before
std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out); std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out);
if (page_out_handle.fail()) {
MS_LOG(ERROR) << "Failed in opening page file";
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file.");
auto pages = SerializePage(); auto pages = SerializePage();
for (const auto &shard_pages : pages) { for (const auto &shard_pages : pages) {
page_out_handle << shard_pages << "\n"; page_out_handle << shard_pages << "\n";
} }

page_out_handle.close(); page_out_handle.close();
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
Status ShardHeader::FileToPages(const std::string dump_file_name) {
for (auto &v : pages_) { // clean pages for (auto &v : pages_) { // clean pages
v.clear(); v.clear();
} }

auto realpath = Common::GetRealPath(dump_file_name); auto realpath = Common::GetRealPath(dump_file_name);
if (!realpath.has_value()) {
MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name;
return FAILED;
}

CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name);
// attempt to open the file contains the page in json // attempt to open the file contains the page in json
std::ifstream page_in_handle(realpath.value()); std::ifstream page_in_handle(realpath.value());

if (!page_in_handle.good()) {
MS_LOG(INFO) << "No page file exists.";
return SUCCESS;
}

CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(), "No page file exists.");
std::string line; std::string line;
while (std::getline(page_in_handle, line)) { while (std::getline(page_in_handle, line)) {
if (SUCCESS != ParsePage(json::parse(line), -1, true)) {
return FAILED;
}
RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true));
} }

page_in_handle.close(); page_in_handle.close();
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id) {
if (header_ptr == nullptr) {
MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
return FAILED;
}
Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
uint64_t &schema_id) {
RETURN_UNEXPECTED_IF_NULL(header_ptr);
auto schema_ptr = Schema::Build("mindrecord", schema); auto schema_ptr = Schema::Build("mindrecord", schema);
if (schema_ptr == nullptr) {
MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
return FAILED;
}
RETURN_UNEXPECTED_IF_NULL(schema_ptr);
schema_id = (*header_ptr)->AddSchema(schema_ptr); schema_id = (*header_ptr)->AddSchema(schema_ptr);
// create index // create index
std::vector<std::pair<uint64_t, std::string>> id_index_fields; std::vector<std::pair<uint64_t, std::string>> id_index_fields;
if (!index_fields.empty()) { if (!index_fields.empty()) {
(void)std::transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
[schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) {
MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index.";
return FAILED;
}
(void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields),
[schema_id](const std::string &el) { return std::make_pair(schema_id, el); });
RETURN_IF_NOT_OK((*header_ptr)->AddIndexFields(id_index_fields));
} }


auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; auto build_schema_ptr = (*header_ptr)->GetSchemas()[0];
blob_fields = build_schema_ptr->GetBlobFields(); blob_fields = build_schema_ptr->GetBlobFields();
return SUCCESS;
return Status::OK();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 3
- 5
mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc View File

@@ -37,13 +37,11 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
} }


MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) {
Status ShardPkSample::SufExecute(ShardTaskList &tasks) {
if (shuffle_ == true) { if (shuffle_ == true) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
}
RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
} }
return SUCCESS;
return Status::OK();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 10
- 17
mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc View File

@@ -80,7 +80,7 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
return 0; return 0;
} }


MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
if (tasks.permutation_.empty()) { if (tasks.permutation_.empty()) {
ShardTaskList new_tasks; ShardTaskList new_tasks;
int total_no = static_cast<int>(tasks.sample_ids_.size()); int total_no = static_cast<int>(tasks.sample_ids_.size());
@@ -110,9 +110,7 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);
} else { } else {
ShardTaskList new_tasks; ShardTaskList new_tasks;
if (taking > static_cast<int>(tasks.sample_ids_.size())) {
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int>(tasks.sample_ids_.size()), "taking is out of range.");
int total_no = static_cast<int>(tasks.permutation_.size()); int total_no = static_cast<int>(tasks.permutation_.size());
int cnt = 0; int cnt = 0;
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
@@ -122,10 +120,10 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
} }
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);
} }
return SUCCESS;
return Status::OK();
} }


MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
Status ShardSample::Execute(ShardTaskList &tasks) {
if (offset_ != -1) { if (offset_ != -1) {
int64_t old_v = 0; int64_t old_v = 0;
int num_rows_ = static_cast<int>(tasks.sample_ids_.size()); int num_rows_ = static_cast<int>(tasks.sample_ids_.size());
@@ -146,10 +144,8 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
no_of_samples_ = std::min(no_of_samples_, total_no); no_of_samples_ = std::min(no_of_samples_, total_no);
taking = no_of_samples_ - no_of_samples_ % no_of_categories; taking = no_of_samples_ - no_of_samples_ % no_of_categories;
} else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) { } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
if (indices_.size() > static_cast<size_t>(total_no)) {
MS_LOG(ERROR) << "parameter indices's size is greater than dataset size.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast<size_t>(total_no),
"Parameter indices's size is greater than dataset size.");
} else { // constructor TopPercent } else { // constructor TopPercent
if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
if (numerator_ == 1 && denominator_ > 1) { // sharding if (numerator_ == 1 && denominator_ > 1) { // sharding
@@ -159,20 +155,17 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) {
taking -= (taking % no_of_categories); taking -= (taking % no_of_categories);
} }
} else { } else {
MS_LOG(ERROR) << "parameter numerator or denominator is illegal";
return FAILED;
RETURN_STATUS_UNEXPECTED("Parameter numerator or denominator is invalid.");
} }
} }
return UpdateTasks(tasks, taking); return UpdateTasks(tasks, taking);
} }


MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) {
Status ShardSample::SufExecute(ShardTaskList &tasks) {
if (sampler_type_ == kSubsetRandomSampler) { if (sampler_type_ == kSubsetRandomSampler) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
}
RETURN_IF_NOT_OK((*shuffle_op_)(tasks));
} }
return SUCCESS;
return Status::OK();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 0
- 12
mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc View File

@@ -38,12 +38,6 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, const json &schema) {
return std::make_shared<Schema>(object_schema); return std::make_shared<Schema>(object_schema);
} }


std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema) {
// validate check
json schema_json = nlohmann::detail::ToJsonImpl(schema);
return Build(std::move(desc), schema_json);
}

std::string Schema::GetDesc() const { return desc_; } std::string Schema::GetDesc() const { return desc_; }


json Schema::GetSchema() const { json Schema::GetSchema() const {
@@ -54,12 +48,6 @@ json Schema::GetSchema() const {
return str_schema; return str_schema;
} }


pybind11::object Schema::GetSchemaForPython() const {
json schema_json = GetSchema();
pybind11::object schema_py = nlohmann::detail::FromJsonImpl(schema_json);
return schema_py;
}

void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } void Schema::SetSchemaID(int64_t id) { schema_id_ = id; }


int64_t Schema::GetSchemaID() const { return schema_id_; } int64_t Schema::GetSchemaID() const { return schema_id_; }


+ 4
- 5
mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc View File

@@ -38,7 +38,7 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c
return std::min(static_cast<int64_t>(no_of_samples_), dataset_size); return std::min(static_cast<int64_t>(no_of_samples_), dataset_size);
} }


MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) {
Status ShardSequentialSample::Execute(ShardTaskList &tasks) {
int64_t taking; int64_t taking;
int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size()); int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size());
if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
@@ -58,16 +58,15 @@ MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) {
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);
} else { // shuffled } else { // shuffled
ShardTaskList new_tasks; ShardTaskList new_tasks;
if (taking > static_cast<int64_t>(tasks.permutation_.size())) {
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int64_t>(tasks.permutation_.size()),
"Taking is out of task range.");
total_no = static_cast<int64_t>(tasks.permutation_.size()); total_no = static_cast<int64_t>(tasks.permutation_.size());
for (size_t i = offset_; i < taking + offset_; ++i) { for (size_t i = offset_; i < taking + offset_; ++i) {
new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
} }
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);
} }
return SUCCESS;
return Status::OK();
} }


} // namespace mindrecord } // namespace mindrecord


+ 14
- 29
mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc View File

@@ -42,7 +42,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_);
} }


MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories; uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories;
std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
for (uint32_t i = 0; i < tasks.categories; i++) { for (uint32_t i = 0; i < tasks.categories; i++) {
@@ -62,17 +62,14 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) {
} }
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);


return SUCCESS;
return Status::OK();
} }


MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
if (no_of_samples_ == 0) { if (no_of_samples_ == 0) {
no_of_samples_ = static_cast<int>(tasks.Size()); no_of_samples_ = static_cast<int>(tasks.Size());
} }
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
auto shard_sample_cout = GetShardSampleCount(); auto shard_sample_cout = GetShardSampleCount();


// shuffle the files index // shuffle the files index
@@ -118,16 +115,14 @@ MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
new_tasks.AssignTask(tasks, tasks.permutation_[i]); new_tasks.AssignTask(tasks, tasks.permutation_[i]);
} }
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);
return Status::OK();
} }


MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
if (no_of_samples_ == 0) { if (no_of_samples_ == 0) {
no_of_samples_ = static_cast<int>(tasks.Size()); no_of_samples_ = static_cast<int>(tasks.Size());
} }
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
// reconstruct the permutation in file // reconstruct the permutation in file
// -- before -- // -- before --
// file1: [0, 1, 2] // file1: [0, 1, 2]
@@ -154,13 +149,12 @@ MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
new_tasks.AssignTask(tasks, tasks.permutation_[i]); new_tasks.AssignTask(tasks, tasks.permutation_[i]);
} }
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);
return Status::OK();
} }


MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
Status ShardShuffle::Execute(ShardTaskList &tasks) {
if (reshuffle_each_epoch_) shuffle_seed_++; if (reshuffle_each_epoch_) shuffle_seed_++;
if (tasks.categories < 1) {
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(tasks.categories >= 1, "Task category is invalid.");
if (shuffle_type_ == kShuffleSample) { // shuffle each sample if (shuffle_type_ == kShuffleSample) { // shuffle each sample
if (tasks.permutation_.empty() == true) { if (tasks.permutation_.empty() == true) {
tasks.MakePerm(); tasks.MakePerm();
@@ -169,10 +163,7 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
if (replacement_ == true) { if (replacement_ == true) {
ShardTaskList new_tasks; ShardTaskList new_tasks;
if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size()); if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
if (no_of_samples_ <= 0) {
MS_LOG(ERROR) << "no_of_samples need to be positive.";
return FAILED;
}
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
for (uint32_t i = 0; i < no_of_samples_; ++i) { for (uint32_t i = 0; i < no_of_samples_; ++i) {
new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
} }
@@ -190,20 +181,14 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) {
ShardTaskList::TaskListSwap(tasks, new_tasks); ShardTaskList::TaskListSwap(tasks, new_tasks);
} }
} else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) { } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) {
auto ret = ShuffleInfile(tasks);
if (ret != SUCCESS) {
return ret;
}
RETURN_IF_NOT_OK(ShuffleInfile(tasks));
} else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) { } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) {
auto ret = ShuffleFiles(tasks);
if (ret != SUCCESS) {
return ret;
}
RETURN_IF_NOT_OK(ShuffleFiles(tasks));
} }
} else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
return this->CategoryShuffle(tasks); return this->CategoryShuffle(tasks);
} }
return SUCCESS;
return Status::OK();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 0
- 18
mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc View File

@@ -35,19 +35,6 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, const json &stat
return std::make_shared<Statistics>(object_statistics); return std::make_shared<Statistics>(object_statistics);
} }


std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle statistics) {
// validate check
json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
if (!Validate(statistics_json)) {
return nullptr;
}
Statistics object_statistics;
object_statistics.desc_ = std::move(desc);
object_statistics.statistics_ = statistics_json;
object_statistics.statistics_id_ = -1;
return std::make_shared<Statistics>(object_statistics);
}

std::string Statistics::GetDesc() const { return desc_; } std::string Statistics::GetDesc() const { return desc_; }


json Statistics::GetStatistics() const { json Statistics::GetStatistics() const {
@@ -57,11 +44,6 @@ json Statistics::GetStatistics() const {
return str_statistics; return str_statistics;
} }


pybind11::object Statistics::GetStatisticsForPython() const {
json str_statistics = Statistics::GetStatistics();
return nlohmann::detail::FromJsonImpl(str_statistics);
}

void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }


int64_t Statistics::GetStatisticsID() const { return statistics_id_; } int64_t Statistics::GetStatisticsID() const { return statistics_id_; }


+ 2
- 8
mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc View File

@@ -85,15 +85,9 @@ uint32_t ShardTaskList::SizeOfRows() const {
return nRows; return nRows;
} }


// Returns a reference to the task stored at position `id` of task_list_.
// No bounds check is performed here; the caller must pass id < task_list_.size().
ShardTask &ShardTaskList::GetTaskByID(size_t id) { return task_list_[id]; }


// Returns the sample id stored at position `id` of sample_ids_.
// No bounds check is performed here; the caller must pass id < sample_ids_.size().
int ShardTaskList::GetTaskSampleByID(size_t id) { return sample_ids_[id]; }


int ShardTaskList::GetRandomTaskID() { int ShardTaskList::GetRandomTaskID() {
std::mt19937 gen = mindspore::dataset::GetRandomDevice(); std::mt19937 gen = mindspore::dataset::GetRandomDevice();


+ 1
- 1
mindspore/mindrecord/shardreader.py View File

@@ -70,7 +70,7 @@ class ShardReader:
Raises: Raises:
MRMLaunchError: If failed to launch worker threads. MRMLaunchError: If failed to launch worker threads.
""" """
ret = self._reader.launch(False)
ret = self._reader.launch()
if ret != ms.MSRStatus.SUCCESS: if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to launch worker threads.") logger.error("Failed to launch worker threads.")
raise MRMLaunchError raise MRMLaunchError


+ 4
- 26
mindspore/mindrecord/shardsegment.py View File

@@ -19,7 +19,6 @@ import mindspore._c_mindrecord as ms
from mindspore import log as logger from mindspore import log as logger
from .shardutils import populate_data, SUCCESS from .shardutils import populate_data, SUCCESS
from .shardheader import ShardHeader from .shardheader import ShardHeader
from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError


__all__ = ['ShardSegment'] __all__ = ['ShardSegment']


@@ -73,15 +72,8 @@ class ShardSegment:
Returns: Returns:
list[str], by which data could be grouped. list[str], by which data could be grouped.


Raises:
MRMFetchCandidateFieldsError: If failed to get candidate category fields.
""" """
ret, fields = self._segment.get_category_fields()
if ret != SUCCESS:
logger.error("Failed to get candidate category fields.")
raise MRMFetchCandidateFieldsError
return fields

return self._segment.get_category_fields()


def set_category_field(self, category_field): def set_category_field(self, category_field):
"""Select one category field to use.""" """Select one category field to use."""
@@ -94,14 +86,8 @@ class ShardSegment:
Returns: Returns:
str, description of group information. str, description of group information.


Raises:
MRMReadCategoryInfoError: If failed to read category information.
""" """
ret, category_info = self._segment.read_category_info()
if ret != SUCCESS:
logger.error("Failed to read category information.")
raise MRMReadCategoryInfoError
return category_info
return self._segment.read_category_info()


def read_at_page_by_id(self, category_id, page, num_row): def read_at_page_by_id(self, category_id, page, num_row):
""" """
@@ -116,13 +102,9 @@ class ShardSegment:
list[dict] list[dict]


Raises: Raises:
MRMFetchDataError: If failed to read by category id.
MRMUnsupportedSchemaError: If schema is invalid. MRMUnsupportedSchemaError: If schema is invalid.
""" """
ret, data = self._segment.read_at_page_by_id(category_id, page, num_row)
if ret != SUCCESS:
logger.error("Failed to read by category id.")
raise MRMFetchDataError
data = self._segment.read_at_page_by_id(category_id, page, num_row)
return [populate_data(raw, blob, self._columns, self._header.blob_fields, return [populate_data(raw, blob, self._columns, self._header.blob_fields,
self._header.schema) for blob, raw in data] self._header.schema) for blob, raw in data]


@@ -139,12 +121,8 @@ class ShardSegment:
list[dict] list[dict]


Raises: Raises:
MRMFetchDataError: If failed to read by category name.
MRMUnsupportedSchemaError: If schema is invalid. MRMUnsupportedSchemaError: If schema is invalid.
""" """
ret, data = self._segment.read_at_page_by_name(category_name, page, num_row)
if ret != SUCCESS:
logger.error("Failed to read by category name.")
raise MRMFetchDataError
data = self._segment.read_at_page_by_name(category_name, page, num_row)
return [populate_data(raw, blob, self._columns, self._header.blob_fields, return [populate_data(raw, blob, self._columns, self._header.blob_fields,
self._header.schema) for blob, raw in data] self._header.schema) for blob, raw in data]

+ 2
- 2
tests/ut/cpp/mindrecord/ut_common.cc View File

@@ -384,8 +384,8 @@ void ShardWriterImageNetOpenForAppend(string filename) {
{ {
MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
mindrecord::ShardWriter fw; mindrecord::ShardWriter fw;
auto ret = fw.OpenForAppend(filename);
if (ret == FAILED) {
auto status = fw.OpenForAppend(filename);
if (status.IsError()) {
return; return;
} }




+ 8
- 3
tests/ut/cpp/mindrecord/ut_shard.cc View File

@@ -121,14 +121,19 @@ TEST_F(TestShard, TestShardHeaderPart) {
re_statistics.push_back(*statistic); re_statistics.push_back(*statistic);
} }
ASSERT_EQ(re_statistics, validate_statistics); ASSERT_EQ(re_statistics, validate_statistics);
ASSERT_EQ(header_data.GetStatisticByID(-1).second, FAILED);
ASSERT_EQ(header_data.GetStatisticByID(10).second, FAILED);
std::shared_ptr<Statistics> statistics_ptr;

auto status = header_data.GetStatisticByID(-1, &statistics_ptr);
EXPECT_FALSE(status.IsOk());
status = header_data.GetStatisticByID(10, &statistics_ptr);
EXPECT_FALSE(status.IsOk());


// test add index fields // test add index fields
std::vector<std::pair<uint64_t, std::string>> fields; std::vector<std::pair<uint64_t, std::string>> fields;
std::pair<uint64_t, std::string> pair1(0, "name"); std::pair<uint64_t, std::string> pair1(0, "name");
fields.push_back(pair1); fields.push_back(pair1);
ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS);
status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());
std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields(); std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields();
ASSERT_EQ(resFields, fields); ASSERT_EQ(resFields, fields);
} }


+ 19
- 18
tests/ut/cpp/mindrecord/ut_shard_header_test.cc View File

@@ -79,36 +79,37 @@ TEST_F(TestShardHeader, AddIndexFields) {
std::pair<uint64_t, std::string> index_field2(schema_id1, "box"); std::pair<uint64_t, std::string> index_field2(schema_id1, "box");
fields.push_back(index_field1); fields.push_back(index_field1);
fields.push_back(index_field2); fields.push_back(index_field2);
MSRStatus res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, SUCCESS);
Status status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());

ASSERT_EQ(header_data.GetFields().size(), 2); ASSERT_EQ(header_data.GetFields().size(), 2);


fields.clear(); fields.clear();
std::pair<uint64_t, std::string> index_field3(schema_id1, "name"); std::pair<uint64_t, std::string> index_field3(schema_id1, "name");
fields.push_back(index_field3); fields.push_back(index_field3);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2); ASSERT_EQ(header_data.GetFields().size(), 2);


fields.clear(); fields.clear();
std::pair<uint64_t, std::string> index_field4(schema_id1, "names"); std::pair<uint64_t, std::string> index_field4(schema_id1, "names");
fields.push_back(index_field4); fields.push_back(index_field4);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2); ASSERT_EQ(header_data.GetFields().size(), 2);


fields.clear(); fields.clear();
std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name"); std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name");
fields.push_back(index_field5); fields.push_back(index_field5);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2); ASSERT_EQ(header_data.GetFields().size(), 2);


fields.clear(); fields.clear();
std::pair<uint64_t, std::string> index_field6(schema_id1, "label"); std::pair<uint64_t, std::string> index_field6(schema_id1, "label");
fields.push_back(index_field6); fields.push_back(index_field6);
res = header_data.AddIndexFields(fields);
ASSERT_EQ(res, FAILED);
status = header_data.AddIndexFields(fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data.GetFields().size(), 2); ASSERT_EQ(header_data.GetFields().size(), 2);


std::string desc_new = "this is a test1"; std::string desc_new = "this is a test1";
@@ -129,26 +130,26 @@ TEST_F(TestShardHeader, AddIndexFields) {
single_fields.push_back("name"); single_fields.push_back("name");
single_fields.push_back("name"); single_fields.push_back("name");
single_fields.push_back("box"); single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 1); ASSERT_EQ(header_data_new.GetFields().size(), 1);


single_fields.push_back("name"); single_fields.push_back("name");
single_fields.push_back("box"); single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 1); ASSERT_EQ(header_data_new.GetFields().size(), 1);


single_fields.clear(); single_fields.clear();
single_fields.push_back("names"); single_fields.push_back("names");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, FAILED);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_FALSE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 1); ASSERT_EQ(header_data_new.GetFields().size(), 1);


single_fields.clear(); single_fields.clear();
single_fields.push_back("box"); single_fields.push_back("box");
res = header_data_new.AddIndexFields(single_fields);
ASSERT_EQ(res, SUCCESS);
status = header_data_new.AddIndexFields(single_fields);
EXPECT_TRUE(status.IsOk());
ASSERT_EQ(header_data_new.GetFields().size(), 2); ASSERT_EQ(header_data_new.GetFields().size(), 2);
} }
} // namespace mindrecord } // namespace mindrecord


+ 8
- 8
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc View File

@@ -167,8 +167,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label"}; auto column_list = std::vector<std::string>{"label"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({file_name}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch(); dataset.Launch();


while (true) { while (true) {
@@ -188,16 +188,16 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"file_namex"}; auto column_list = std::vector<std::string>{"file_namex"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST);
auto status= dataset.Open({file_name}, true, 4, column_list);
EXPECT_FALSE(status.IsOk());
} }


TEST_F(TestShardReader, TestShardVersion) { TEST_F(TestShardReader, TestShardVersion) {
MS_LOG(INFO) << FormatInfo("Test shard version"); MS_LOG(INFO) << FormatInfo("Test shard version");
std::string file_name = "./imagenet.shard01"; std::string file_name = "./imagenet.shard01";
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({file_name}, true, 4);
EXPECT_TRUE(status.IsOk());
dataset.Launch(); dataset.Launch();


while (true) { while (true) {
@@ -219,8 +219,8 @@ TEST_F(TestShardReader, TestShardReaderDir) {
auto column_list = std::vector<std::string>{"file_name"}; auto column_list = std::vector<std::string>{"file_name"};


ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({file_name}, true, 4, column_list);
ASSERT_EQ(ret, FAILED);
auto status = dataset.Open({file_name}, true, 4, column_list);
EXPECT_FALSE(status.IsOk());
} }


TEST_F(TestShardReader, TestShardReaderConsumer) { TEST_F(TestShardReader, TestShardReaderConsumer) {


+ 82
- 45
tests/ut/cpp/mindrecord/ut_shard_segment_test.cc View File

@@ -61,35 +61,44 @@ TEST_F(TestShardSegment, TestShardSegment) {
ShardSegment dataset; ShardSegment dataset;
dataset.Open({file_name}, true, 4); dataset.Open({file_name}, true, 4);


auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields; MS_LOG(INFO) << "Get category field: " << fields;
} }


ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
ASSERT_TRUE(dataset.SetCategoryField("laabel_0") == FAILED);
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());
status = dataset.SetCategoryField("laabel_0");
EXPECT_FALSE(status.IsOk());


MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;


auto ret = dataset.ReadAtPageByName("822", 0, 10);
auto images = ret.second;
MS_LOG(INFO) << "category field: 822, images count: " << images.size() << ", image[0] size: " << images[0].size();
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;


auto ret1 = dataset.ReadAtPageByName("823", 0, 10);
auto images2 = ret1.second;
MS_LOG(INFO) << "category field: 823, images count: " << images2.size();
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageByName("822", 0, 10, &pages_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr->size() << ", image[0] size: " << ((*pages_ptr)[0]).size();


auto ret2 = dataset.ReadAtPageById(1, 0, 10);
auto images3 = ret2.second;
MS_LOG(INFO) << "category id: 1, images count: " << images3.size() << ", image[0] size: " << images3[0].size();
auto pages_ptr_1 = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageByName("823", 0, 10, &pages_ptr_1);
MS_LOG(INFO) << "category field: 823, images count: " << pages_ptr_1->size();


auto ret3 = dataset.ReadAllAtPageByName("822", 0, 10);
auto images4 = ret3.second;
MS_LOG(INFO) << "category field: 822, images count: " << images4.size();
auto pages_ptr_2 = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(1, 0, 10, &pages_ptr_2);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_2->size() << ", image[0] size: " << ((*pages_ptr_2)[0]).size();


auto ret4 = dataset.ReadAllAtPageById(1, 0, 10);
auto images5 = ret4.second;
MS_LOG(INFO) << "category id: 1, images count: " << images5.size();
auto pages_ptr_3 = std::make_shared<PAGES>();
status = dataset.ReadAllAtPageByName("822", 0, 10, &pages_ptr_3);
MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr_3->size();

auto pages_ptr_4 = std::make_shared<PAGES>();
status = dataset.ReadAllAtPageById(1, 0, 10, &pages_ptr_4);
MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_4->size();
} }


TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
@@ -99,21 +108,28 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) {
ShardSegment dataset; ShardSegment dataset;
dataset.Open({file_name}, true, 4); dataset.Open({file_name}, true, 4);


auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields; MS_LOG(INFO) << "Get category field: " << fields;
} }


string category_name = "82Cus"; string category_name = "82Cus";
string category_field = "laabel_0"; string category_field = "laabel_0";


ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
ASSERT_TRUE(dataset.SetCategoryField(category_field) == FAILED);
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());
status = dataset.SetCategoryField(category_field);
EXPECT_FALSE(status.IsOk());


MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;


auto ret = dataset.ReadAtPageByName(category_name, 0, 10);
EXPECT_TRUE(ret.first == FAILED);
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageByName(category_name, 0, 10, &pages_ptr);
EXPECT_FALSE(status.IsOk());
} }


TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
@@ -123,19 +139,25 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) {
ShardSegment dataset; ShardSegment dataset;
dataset.Open({file_name}, true, 4); dataset.Open({file_name}, true, 4);


auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields; MS_LOG(INFO) << "Get category field: " << fields;
} }


int64_t categoryId = 2251799813685247; int64_t categoryId = 2251799813685247;
MS_LOG(INFO) << "Input category id: " << categoryId; MS_LOG(INFO) << "Input category id: " << categoryId;


ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;


auto ret2 = dataset.ReadAtPageById(categoryId, 0, 10);
EXPECT_TRUE(ret2.first == FAILED);
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(categoryId, 0, 10, &pages_ptr);
EXPECT_FALSE(status.IsOk());
} }


TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
@@ -145,19 +167,27 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) {
ShardSegment dataset; ShardSegment dataset;
dataset.Open({file_name}, true, 4); dataset.Open({file_name}, true, 4);


auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields; MS_LOG(INFO) << "Get category field: " << fields;
} }


int64_t page_no = 2251799813685247; int64_t page_no = 2251799813685247;
MS_LOG(INFO) << "Input page no: " << page_no; MS_LOG(INFO) << "Input page no: " << page_no;


ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());



auto ret2 = dataset.ReadAtPageById(1, page_no, 10);
EXPECT_TRUE(ret2.first == FAILED);
std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;

auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(1, page_no, 10, &pages_ptr);
EXPECT_FALSE(status.IsOk());
} }


TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
@@ -167,19 +197,26 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) {
ShardSegment dataset; ShardSegment dataset;
dataset.Open({file_name}, true, 4); dataset.Open({file_name}, true, 4);


auto x = dataset.GetCategoryFields();
for (const auto &fields : x.second) {
auto fields_ptr = std::make_shared<vector<std::string>>();
auto status = dataset.GetCategoryFields(&fields_ptr);
for (const auto &fields : *fields_ptr) {
MS_LOG(INFO) << "Get category field: " << fields; MS_LOG(INFO) << "Get category field: " << fields;
} }


int64_t pageRows = 0; int64_t pageRows = 0;
MS_LOG(INFO) << "Input page rows: " << pageRows; MS_LOG(INFO) << "Input page rows: " << pageRows;


ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS);
MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second;
status = dataset.SetCategoryField("label");
EXPECT_TRUE(status.IsOk());

std::shared_ptr<std::string> category_ptr;
status = dataset.ReadCategoryInfo(&category_ptr);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Read category info: " << *category_ptr;


auto ret2 = dataset.ReadAtPageById(1, 0, pageRows);
EXPECT_TRUE(ret2.first == FAILED);
auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>();
status = dataset.ReadAtPageById(1, 0, pageRows, &pages_ptr);
EXPECT_FALSE(status.IsOk());
} }


} // namespace mindrecord } // namespace mindrecord


+ 49
- 29
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc View File

@@ -60,8 +60,8 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
std::string filename = "./OneSample.shard01"; std::string filename = "./OneSample.shard01";


ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({filename}, true, 4);
EXPECT_TRUE(status.IsOk());
dataset.Launch(); dataset.Launch();


while (true) { while (true) {
@@ -675,8 +675,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) {
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));


// write rawdata // write rawdata
MSRStatus res = fw.WriteRawData(rawdatas, bin_data);
ASSERT_EQ(res, SUCCESS);
auto status = fw.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
for (const auto &filename : file_names) { for (const auto &filename : file_names) {
auto filename_db = filename + ".db"; auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db)); remove(common::SafeCStr(filename_db));
@@ -716,7 +716,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
fields.push_back(index_field2); fields.push_back(index_field2);


// add index to shardHeader // add index to shardHeader
ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS);
auto status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Init Index Fields Already."; MS_LOG(INFO) << "Init Index Fields Already.";


// load meta data // load meta data
@@ -736,28 +737,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
} }


mindrecord::ShardWriter fw_init; mindrecord::ShardWriter fw_init;
ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS);
status = fw_init.Open(file_names);
EXPECT_TRUE(status.IsOk());


// set shardHeader // set shardHeader
ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS);
status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
EXPECT_TRUE(status.IsOk());



// write raw data // write raw data
ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS);
ASSERT_TRUE(fw_init.Commit() == SUCCESS);
status = fw_init.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
status = fw_init.Commit();
EXPECT_TRUE(status.IsOk());


// create the index file // create the index file
std::string filename = "./imagenet.shard01"; std::string filename = "./imagenet.shard01";
mindrecord::ShardIndexGenerator sg{filename}; mindrecord::ShardIndexGenerator sg{filename};
sg.Build(); sg.Build();
ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS);
status = sg.WriteToDatabase();
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Done create index"; MS_LOG(INFO) << "Done create index";


// read the mindrecord file // read the mindrecord file
filename = "./imagenet.shard01"; filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name", "data"}; auto column_list = std::vector<std::string>{"label", "file_name", "data"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
status = dataset.Open({filename}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch(); dataset.Launch();


int count = 0; int count = 0;
@@ -822,28 +829,34 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
} }


mindrecord::ShardWriter fw_init; mindrecord::ShardWriter fw_init;
ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS);
auto status = fw_init.Open(file_names);
EXPECT_TRUE(status.IsOk());



// set shardHeader // set shardHeader
ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS);
status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
EXPECT_TRUE(status.IsOk());


// write raw data // write raw data
ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS);
ASSERT_TRUE(fw_init.Commit() == SUCCESS);
status = fw_init.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
status = fw_init.Commit();
EXPECT_TRUE(status.IsOk());


// create the index file // create the index file
std::string filename = "./imagenet.shard01"; std::string filename = "./imagenet.shard01";
mindrecord::ShardIndexGenerator sg{filename}; mindrecord::ShardIndexGenerator sg{filename};
sg.Build(); sg.Build();
ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS);
status = sg.WriteToDatabase();
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Done create index"; MS_LOG(INFO) << "Done create index";


// read the mindrecord file // read the mindrecord file
filename = "./imagenet.shard01"; filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "file_name"}; auto column_list = std::vector<std::string>{"label", "file_name"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
status = dataset.Open({filename}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch(); dataset.Launch();


int count = 0; int count = 0;
@@ -896,7 +909,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
fields.push_back(index_field1); fields.push_back(index_field1);


// add index to shardHeader // add index to shardHeader
ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS);
auto status = header_data.AddIndexFields(fields);
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Init Index Fields Already."; MS_LOG(INFO) << "Init Index Fields Already.";


// load meta data // load meta data
@@ -916,28 +930,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
} }


mindrecord::ShardWriter fw_init; mindrecord::ShardWriter fw_init;
ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS);
status = fw_init.Open(file_names);
EXPECT_TRUE(status.IsOk());



// set shardHeader // set shardHeader
ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS);
status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
EXPECT_TRUE(status.IsOk());


// write raw data // write raw data
ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS);
ASSERT_TRUE(fw_init.Commit() == SUCCESS);
status = fw_init.WriteRawData(rawdatas, bin_data);
EXPECT_TRUE(status.IsOk());
status = fw_init.Commit();
EXPECT_TRUE(status.IsOk());


// create the index file // create the index file
std::string filename = "./imagenet.shard01"; std::string filename = "./imagenet.shard01";
mindrecord::ShardIndexGenerator sg{filename}; mindrecord::ShardIndexGenerator sg{filename};
sg.Build(); sg.Build();
ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS);
status = sg.WriteToDatabase();
EXPECT_TRUE(status.IsOk());
MS_LOG(INFO) << "Done create index"; MS_LOG(INFO) << "Done create index";


// read the mindrecord file // read the mindrecord file
filename = "./imagenet.shard01"; filename = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"label", "data"}; auto column_list = std::vector<std::string>{"label", "data"};
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4, column_list);
ASSERT_EQ(ret, SUCCESS);
status = dataset.Open({filename}, true, 4, column_list);
EXPECT_TRUE(status.IsOk());
dataset.Launch(); dataset.Launch();


int count = 0; int count = 0;
@@ -1043,8 +1063,8 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {


filename = "./TenSampleFortyShard.shard01"; filename = "./TenSampleFortyShard.shard01";
ShardReader dataset; ShardReader dataset;
MSRStatus ret = dataset.Open({filename}, true, 4);
ASSERT_EQ(ret, SUCCESS);
auto status = dataset.Open({filename}, true, 4);
EXPECT_TRUE(status.IsOk());
dataset.Launch(); dataset.Launch();


int count = 0; int count = 0;


+ 5
- 5
tests/ut/python/dataset/test_minddataset_exception.py View File

@@ -95,7 +95,7 @@ def test_invalid_mindrecord():
f.write('just for test') f.write('just for test')
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
num_readers = 4 num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid file content. path:"):
data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers) data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers)
num_iter = 0 num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -114,7 +114,7 @@ def test_minddataset_lack_db():
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
num_readers = 4 num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file:"):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
num_iter = 0 num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -133,7 +133,7 @@ def test_cv_minddataset_pk_sample_error_class_column():
columns_list = ["data", "file_name", "label"] columns_list = ["data", "file_name", "label"]
num_readers = 4 num_readers = 4
sampler = ds.PKSampler(5, None, True, 'no_exist_column') sampler = ds.PKSampler(5, None, True, 'no_exist_column')
with pytest.raises(Exception, match="MindRecordOp launch failed"):
with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler) data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
num_iter = 0 num_iter = 0
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -162,7 +162,7 @@ def test_cv_minddataset_reader_different_schema():
create_diff_schema_cv_mindrecord(1) create_diff_schema_cv_mindrecord(1)
columns_list = ["data", "label"] columns_list = ["data", "label"]
num_readers = 4 num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers) num_readers)
num_iter = 0 num_iter = 0
@@ -179,7 +179,7 @@ def test_cv_minddataset_reader_different_page_size():
create_diff_page_size_cv_mindrecord(1) create_diff_page_size_cv_mindrecord(1)
columns_list = ["data", "label"] columns_list = ["data", "label"]
num_readers = 4 num_readers = 4
with pytest.raises(Exception, match="MindRecordOp init failed"):
with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"):
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
num_readers) num_readers)
num_iter = 0 num_iter = 0


+ 4
- 5
tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py View File

@@ -19,7 +19,6 @@ import pytest
from mindspore import log as logger from mindspore import log as logger
from mindspore.mindrecord import Cifar100ToMR from mindspore.mindrecord import Cifar100ToMR
from mindspore.mindrecord import FileReader from mindspore.mindrecord import FileReader
from mindspore.mindrecord import MRMOpenError
from mindspore.mindrecord import SUCCESS from mindspore.mindrecord import SUCCESS


CIFAR100_DIR = "../data/mindrecord/testCifar100Data" CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
@@ -119,8 +118,8 @@ def test_cifar100_to_mindrecord_directory(fixture_file):
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when destination path is directory. when destination path is directory.
""" """
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="MindRecord file already existed, please delete file:"):
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, cifar100_transformer = Cifar100ToMR(CIFAR100_DIR,
CIFAR100_DIR) CIFAR100_DIR)
cifar100_transformer.transform() cifar100_transformer.transform()
@@ -130,8 +129,8 @@ def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when destination path equals source path. when destination path equals source path.
""" """
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="indRecord file already existed, please delete file:"):
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, cifar100_transformer = Cifar100ToMR(CIFAR100_DIR,
CIFAR100_DIR + "/train") CIFAR100_DIR + "/train")
cifar100_transformer.transform() cifar100_transformer.transform()

+ 5
- 5
tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py View File

@@ -19,7 +19,7 @@ import pytest
from mindspore import log as logger from mindspore import log as logger
from mindspore.mindrecord import Cifar10ToMR from mindspore.mindrecord import Cifar10ToMR
from mindspore.mindrecord import FileReader from mindspore.mindrecord import FileReader
from mindspore.mindrecord import MRMOpenError, SUCCESS
from mindspore.mindrecord import SUCCESS


CIFAR10_DIR = "../data/mindrecord/testCifar10Data" CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
MINDRECORD_FILE = "./cifar10.mindrecord" MINDRECORD_FILE = "./cifar10.mindrecord"
@@ -146,8 +146,8 @@ def test_cifar10_to_mindrecord_directory(fixture_file):
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when destination path is directory. when destination path is directory.
""" """
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="MindRecord file already existed, please delete file:"):
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR) cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR)
cifar10_transformer.transform() cifar10_transformer.transform()


@@ -157,8 +157,8 @@ def test_cifar10_to_mindrecord_filename_equals_cifar10():
test transform cifar10 dataset to mindrecord test transform cifar10 dataset to mindrecord
when destination path equals source path. when destination path equals source path.
""" """
with pytest.raises(MRMOpenError,
match="MindRecord File could not open successfully"):
with pytest.raises(RuntimeError,
match="MindRecord file already existed, please delete file:"):
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, cifar10_transformer = Cifar10ToMR(CIFAR10_DIR,
CIFAR10_DIR + "/data_batch_0") CIFAR10_DIR + "/data_batch_0")
cifar10_transformer.transform() cifar10_transformer.transform()

+ 38
- 49
tests/ut/python/mindrecord/test_mindrecord_exception.py View File

@@ -21,8 +21,7 @@ from utils import get_data


from mindspore import log as logger from mindspore import log as logger
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \
MRMFetchDataError
from mindspore.mindrecord import ParamValueError, MRMGetMetaError


CV_FILE_NAME = "./imagenet.mindrecord" CV_FILE_NAME = "./imagenet.mindrecord"
NLP_FILE_NAME = "./aclImdb.mindrecord" NLP_FILE_NAME = "./aclImdb.mindrecord"
@@ -106,21 +105,19 @@ def create_cv_mindrecord(files_num):


def test_lack_partition_and_db(): def test_lack_partition_and_db():
"""test file reader when mindrecord file does not exist.""" """test file reader when mindrecord file does not exist."""
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader('dummy.mindrecord') reader = FileReader('dummy.mindrecord')
reader.close() reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)


def test_lack_db(fixture_cv_file): def test_lack_db(fixture_cv_file):
"""test file reader when db file does not exist.""" """test file reader when db file does not exist."""
create_cv_mindrecord(1) create_cv_mindrecord(1)
os.remove("{}.db".format(CV_FILE_NAME)) os.remove("{}.db".format(CV_FILE_NAME))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME) reader = FileReader(CV_FILE_NAME)
reader.close() reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid database file:' in str(err.value)


def test_lack_some_partition_and_db(fixture_cv_file): def test_lack_some_partition_and_db(fixture_cv_file):
"""test file reader when some partition and db do not exist.""" """test file reader when some partition and db do not exist."""
@@ -129,11 +126,10 @@ def test_lack_some_partition_and_db(fixture_cv_file):
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
os.remove("{}".format(paths[3])) os.remove("{}".format(paths[3]))
os.remove("{}.db".format(paths[3])) os.remove("{}.db".format(paths[3]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0") reader = FileReader(CV_FILE_NAME + "0")
reader.close() reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)


def test_lack_some_partition_first(fixture_cv_file): def test_lack_some_partition_first(fixture_cv_file):
"""test file reader when first partition does not exist.""" """test file reader when first partition does not exist."""
@@ -141,11 +137,10 @@ def test_lack_some_partition_first(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
os.remove("{}".format(paths[0])) os.remove("{}".format(paths[0]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0") reader = FileReader(CV_FILE_NAME + "0")
reader.close() reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)


def test_lack_some_partition_middle(fixture_cv_file): def test_lack_some_partition_middle(fixture_cv_file):
"""test file reader when some partition does not exist.""" """test file reader when some partition does not exist."""
@@ -153,11 +148,10 @@ def test_lack_some_partition_middle(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
os.remove("{}".format(paths[1])) os.remove("{}".format(paths[1]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0") reader = FileReader(CV_FILE_NAME + "0")
reader.close() reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)


def test_lack_some_partition_last(fixture_cv_file): def test_lack_some_partition_last(fixture_cv_file):
"""test file reader when last partition does not exist.""" """test file reader when last partition does not exist."""
@@ -165,11 +159,10 @@ def test_lack_some_partition_last(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
os.remove("{}".format(paths[3])) os.remove("{}".format(paths[3]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0") reader = FileReader(CV_FILE_NAME + "0")
reader.close() reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)


def test_mindpage_lack_some_partition(fixture_cv_file): def test_mindpage_lack_some_partition(fixture_cv_file):
"""test page reader when some partition does not exist.""" """test page reader when some partition does not exist."""
@@ -177,10 +170,9 @@ def test_mindpage_lack_some_partition(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
os.remove("{}".format(paths[0])) os.remove("{}".format(paths[0]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
MindPage(CV_FILE_NAME + "0") MindPage(CV_FILE_NAME + "0")
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file path:' in str(err.value)


def test_lack_some_db(fixture_cv_file): def test_lack_some_db(fixture_cv_file):
"""test file reader when some db does not exist.""" """test file reader when some db does not exist."""
@@ -188,11 +180,10 @@ def test_lack_some_db(fixture_cv_file):
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)] for x in range(FILES_NUM)]
os.remove("{}.db".format(paths[3])) os.remove("{}.db".format(paths[3]))
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
reader = FileReader(CV_FILE_NAME + "0") reader = FileReader(CV_FILE_NAME + "0")
reader.close() reader.close()
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid database file:' in str(err.value)




def test_invalid_mindrecord(): def test_invalid_mindrecord():
@@ -200,10 +191,9 @@ def test_invalid_mindrecord():
with open(CV_FILE_NAME, 'w') as f: with open(CV_FILE_NAME, 'w') as f:
dummy = 's' * 100 dummy = 's' * 100
f.write(dummy) f.write(dummy)
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
FileReader(CV_FILE_NAME) FileReader(CV_FILE_NAME)
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Invalid file content. path:' in str(err.value)
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)


def test_invalid_db(fixture_cv_file): def test_invalid_db(fixture_cv_file):
@@ -212,27 +202,26 @@ def test_invalid_db(fixture_cv_file):
os.remove("imagenet.mindrecord.db") os.remove("imagenet.mindrecord.db")
with open('imagenet.mindrecord.db', 'w') as f: with open('imagenet.mindrecord.db', 'w') as f:
f.write('just for test') f.write('just for test')
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
FileReader('imagenet.mindrecord') FileReader('imagenet.mindrecord')
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
in str(err.value)
assert 'Unexpected error. Error in execute sql:' in str(err.value)


def test_overwrite_invalid_mindrecord(fixture_cv_file): def test_overwrite_invalid_mindrecord(fixture_cv_file):
"""test file writer when overwrite invalid mindreocrd file.""" """test file writer when overwrite invalid mindreocrd file."""
with open(CV_FILE_NAME, 'w') as f: with open(CV_FILE_NAME, 'w') as f:
f.write('just for test') f.write('just for test')
with pytest.raises(MRMOpenError) as err:
with pytest.raises(RuntimeError) as err:
create_cv_mindrecord(1) create_cv_mindrecord(1)
assert '[MRMOpenError]: MindRecord File could not open successfully.' \
assert 'Unexpected error. MindRecord file already existed, please delete file:' \
in str(err.value) in str(err.value)


def test_overwrite_invalid_db(fixture_cv_file): def test_overwrite_invalid_db(fixture_cv_file):
"""test file writer when overwrite invalid db file.""" """test file writer when overwrite invalid db file."""
with open('imagenet.mindrecord.db', 'w') as f: with open('imagenet.mindrecord.db', 'w') as f:
f.write('just for test') f.write('just for test')
with pytest.raises(MRMGenerateIndexError) as err:
with pytest.raises(RuntimeError) as err:
create_cv_mindrecord(1) create_cv_mindrecord(1)
assert '[MRMGenerateIndexError]: Failed to generate index.' in str(err.value)
assert 'Unexpected error. Failed to write data to db.' in str(err.value)


def test_read_after_close(fixture_cv_file): def test_read_after_close(fixture_cv_file):
"""test file reader when close read.""" """test file reader when close read."""
@@ -302,7 +291,7 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
with pytest.raises(ParamValueError): with pytest.raises(ParamValueError):
reader.read_at_page_by_name("822", 0, "qwer") reader.read_at_page_by_name("822", 0, "qwer")


with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"):
reader.read_at_page_by_id(99999, 0, 1) reader.read_at_page_by_id(99999, 0, 1)




@@ -320,10 +309,10 @@ def test_mindpage_filename_not_exist(fixture_cv_file):
info = reader.read_category_info() info = reader.read_category_info()
logger.info("category info: {}".format(info)) logger.info("category info: {}".format(info))


with pytest.raises(MRMFetchDataError):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"):
reader.read_at_page_by_id(9999, 0, 1) reader.read_at_page_by_id(9999, 0, 1)


with pytest.raises(MRMFetchDataError):
with pytest.raises(RuntimeError, match="Unexpected error. Invalid category name."):
reader.read_at_page_by_name("abc.jpg", 0, 1) reader.read_at_page_by_name("abc.jpg", 0, 1)


with pytest.raises(ParamValueError): with pytest.raises(ParamValueError):
@@ -475,7 +464,7 @@ def test_write_with_invalid_data():
mindrecord_file_name = "test.mindrecord" mindrecord_file_name = "test.mindrecord"


# field: file_name => filename # field: file_name => filename
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -510,7 +499,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# field: mask => masks # field: mask => masks
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -545,7 +534,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# field: data => image # field: data => image
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -580,7 +569,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# field: label => labels # field: label => labels
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -615,7 +604,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# field: score => scores # field: score => scores
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -650,7 +639,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# string type with int value # string type with int value
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -685,7 +674,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# field with int64 type, but the real data is string # field with int64 type, but the real data is string
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -720,7 +709,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# bytes field is string # bytes field is string
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -755,7 +744,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# field is not numpy type # field is not numpy type
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")


@@ -790,7 +779,7 @@ def test_write_with_invalid_data():
writer.commit() writer.commit()


# not enough field # not enough field
with pytest.raises(Exception, match="Failed to write dataset"):
with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name)
remove_one_file(mindrecord_file_name + ".db") remove_one_file(mindrecord_file_name + ".db")




Loading…
Cancel
Save