@@ -408,8 +408,13 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr<Datas
   }
   std::shared_ptr<MindRecordOp::Builder> builder = std::make_shared<MindRecordOp::Builder>();
-  (void)builder->SetDatasetFile(ToString(args["dataset_file"]));
+  bool load_dataset = ToBool(args["load_dataset"]);
+  if (load_dataset == true) {
+    (void)builder->SetDatasetFile({ToString(args["dataset_file"])});
+  } else {
+    (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"]));
+  }
+  (void)builder->SetLoadDataset(load_dataset);
   std::vector<std::string> in_col_names;
   if (!args["columns_list"].is_none()) {
     in_col_names = ToStringVector(args["columns_list"]);
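Editor's note, not part of the patch: the two shapes of `dataset_file` that reach this branch come straight from the Python `MindDataset` arguments; the paths below are placeholders.

```python
import mindspore.dataset as ds

# load_dataset == true branch: one file name stands in for the whole dataset.
data_set = ds.MindDataset(dataset_file="/data/cv.mindrecord0")

# load_dataset == false branch: exactly the listed shard files are read.
data_set = ds.MindDataset(dataset_file=["/data/cv.mindrecord0", "/data/cv.mindrecord1"])
```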
@@ -151,16 +151,17 @@ void bindDatasetOps(py::module *m) {
     });
   (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
-    .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) {
-      int64_t count = 0;
-      std::shared_ptr<mindrecord::ShardOperator> op;
-      if (py::hasattr(sampler, "_create_for_minddataset")) {
-        auto create = sampler.attr("_create_for_minddataset");
-        op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
-      }
-      THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count));
-      return count;
-    });
+    .def_static("get_num_rows",
+                [](const std::vector<std::string> &paths, bool load_dataset, const py::object &sampler) {
+                  int64_t count = 0;
+                  std::shared_ptr<mindrecord::ShardOperator> op;
+                  if (py::hasattr(sampler, "_create_for_minddataset")) {
+                    auto create = sampler.attr("_create_for_minddataset");
+                    op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
+                  }
+                  THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count));
+                  return count;
+                });
   (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp")
     .def_static("get_num_rows_and_classes",
@@ -40,7 +40,7 @@ using mindrecord::ShardOperator;
 using mindrecord::ShardReader;
 // Builder constructor.  Creates the builder object.
-MindRecordOp::Builder::Builder() : build_dataset_file_("") {
+MindRecordOp::Builder::Builder() : build_dataset_file_({}) {
   // Some arguments to the MindRecordOp constructor have a default argument that is taken
   // from the client config.
   // The user may choose to change these values for the construction of the StorageOp by
@@ -63,9 +63,9 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
                                  "Building a MindRecordOp that has not provided a file.");
   }
-  new_mind_record_op = std::make_shared<MindRecordOp>(build_num_mind_record_workers_, build_rows_per_buffer_,
-                                                      build_dataset_file_, build_op_connector_queue_size_,
-                                                      build_columns_to_load_, build_operators_, build_block_reader_);
+  new_mind_record_op = std::make_shared<MindRecordOp>(
+    build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_,
+    build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_);
   RETURN_IF_NOT_OK(new_mind_record_op->Init());
@@ -76,12 +76,14 @@ Status MindRecordOp::Builder::Build(std::shared_ptr<MindRecordOp> *ptr) {
 Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); }
 // Constructor of the MindRecordOp.
-MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
-                           int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
+MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer,
+                           std::vector<std::string> dataset_file, bool load_dataset, int32_t op_connector_queue_size,
+                           const std::vector<std::string> &columns_to_load,
                            const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader)
     : ParallelOp(num_mind_record_workers, op_connector_queue_size),
       rows_per_buffer_(rows_per_buffer),
       dataset_file_(dataset_file),
+      load_dataset_(load_dataset),
       columns_to_load_(columns_to_load),
       operators_(operators),
       num_mind_record_workers_(num_mind_record_workers),
@@ -101,9 +103,10 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf
 // Private helper method to encapsulate some common construction/reset tasks
 Status MindRecordOp::Init() {
   shard_reader_ = std::make_unique<ShardReader>();
-  auto rc = shard_reader_->Open(dataset_file_, num_mind_record_workers_, columns_to_load_, operators_, block_reader_);
+  auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_,
+                                block_reader_);
-  CHECK_FAIL_RETURN_UNEXPECTED(rc != MSRStatus::FAILED,
+  CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS,
                                "MindRecordOp init failed. Error message: " + ErrnoToMessage(rc));
   data_schema_ = std::make_unique<DataSchema>();
@@ -201,8 +204,12 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
     // Call the super class for displaying any common detailed info
     ParallelOp::Print(out, show_all);
     // Then show any custom derived-internal stuff
-    out << "\n1 Dataset file : " << dataset_file_ << "\nNumber of rows : " << num_rows_
-        << "\nRows per buffer : " << rows_per_buffer_ << "\nNumber of buffers : " << buffers_needed_
+    out << "\n Dataset file : ";
+    for (auto &file : dataset_file_) {
+      out << file << " ";
+    }
+    out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_
+        << "\nNumber of buffers : " << buffers_needed_
         << "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n";
   }
 }
@@ -668,10 +675,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
   return Status::OK();
 }
-Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
-                                    int64_t *count) {
+Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
+                                    const std::shared_ptr<ShardOperator> &op, int64_t *count) {
   std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
-  MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count);
+  MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count);
   if (rc == MSRStatus::FAILED) {
     RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
   }
@@ -77,8 +77,8 @@ class MindRecordOp : public ParallelOp {
       return *this;
     }
-    Builder &SetDatasetFile(const std::string &file) {
-      build_dataset_file_ = file;
+    Builder &SetDatasetFile(const std::vector<std::string> &files) {
+      build_dataset_file_ = files;
       return *this;
     }
@@ -97,6 +97,11 @@ class MindRecordOp : public ParallelOp {
       return *this;
     }
+    Builder &SetLoadDataset(bool load_dataset) {
+      build_load_dataset_ = load_dataset;
+      return *this;
+    }
     Status SanityCheck() const;
     static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; }
@@ -109,7 +114,8 @@ class MindRecordOp : public ParallelOp {
     int32_t builder_num_workers_;
     int32_t build_rows_per_buffer_;
     int32_t build_op_connector_queue_size_;
-    std::string build_dataset_file_;
+    std::vector<std::string> build_dataset_file_;
+    bool build_load_dataset_;
     std::vector<std::string> build_columns_to_load_;
     std::vector<std::shared_ptr<ShardOperator>> build_operators_;
     bool build_block_reader_;
@@ -119,12 +125,12 @@ class MindRecordOp : public ParallelOp {
   // @note The builder class should be used to call it
   // @param num_mind_record_workers - The number of workers for the op (run by ShardReader)
   // @param rows_per_buffer - The requested number of rows per buffer
-  // @param dataset_file - A shard file
+  // @param dataset_file - One dataset file, or the list of dataset files to load
   // @param op_connector_queue_size - The output connector queue size
   // @param columns_to_load - The list of columns to use (column name)
   // @param operators - ShardOperators for Shuffle, Category, Sample
-  MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::string dataset_file,
-               int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
+  MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector<std::string> dataset_file,
+               bool load_dataset, int32_t op_connector_queue_size, const std::vector<std::string> &columns_to_load,
                const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader);
   // Destructor
@@ -169,21 +175,22 @@ class MindRecordOp : public ParallelOp {
   // Getter method
   int32_t num_rows() const { return num_rows_; }
-  // Getter method
-  static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
-                               int64_t *count);
+  static Status CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset,
+                               const std::shared_ptr<ShardOperator> &op, int64_t *count);
   // Getter method
   int32_t rows_per_buffer() const { return rows_per_buffer_; }
   // Getter method
-  std::string dataset_file() const { return dataset_file_; }
+  std::vector<std::string> dataset_file() const { return dataset_file_; }
   // Getter method
   std::vector<std::string> columns_to_load() const { return columns_to_load_; }
   bool block_reader() const { return block_reader_; }
+  bool load_dataset() const { return load_dataset_; }
   Status Init();
   Status SetColumnsBlob();
@@ -246,7 +253,8 @@ class MindRecordOp : public ParallelOp {
   Status FetchBlockBuffer(const int32_t &buffer_id);
   int32_t rows_per_buffer_;                                // The number of requested rows per buffer.
-  std::string dataset_file_;                               // A dataset file
+  std::vector<std::string> dataset_file_;                  // dataset files
+  bool load_dataset_;                                      // load dataset from single file or not
   std::vector<std::string> columns_to_load_;               // Columns to load from dataset
   std::vector<std::shared_ptr<ShardOperator>> operators_;  // ShardOperators to use
   int32_t num_mind_record_workers_;                        // number of workers to be spawned by ShardReader
@@ -170,6 +170,9 @@ std::string ErrnoToMessage(MSRStatus status) {
     case IO_FAILED:
      return "io operate failed";
      break;
+    case MATCH_HEADER_FAILED:
+      return "match header failed";
+      break;
     default:
       return "invalid error no";
   }
@@ -84,7 +84,8 @@ void BindShardWriter(py::module *m) {
 void BindShardReader(const py::module *m) {
   (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local())
     .def(py::init<>())
-    .def("open", (MSRStatus(ShardReader::*)(const std::string &, const int &, const std::vector<std::string> &,
+    .def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &,
+                                            const std::vector<std::string> &,
                                             const std::vector<std::shared_ptr<ShardOperator>> &)) &
                    ShardReader::OpenPy)
     .def("launch", &ShardReader::Launch)
@@ -106,7 +107,8 @@ void BindShardIndexGenerator(const py::module *m) {
 void BindShardSegment(py::module *m) {
   (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local())
     .def(py::init<>())
-    .def("open", (MSRStatus(ShardSegment::*)(const std::string &, const int &, const std::vector<std::string> &,
+    .def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &,
+                                             const std::vector<std::string> &,
                                              const std::vector<std::shared_ptr<ShardOperator>> &)) &
                    ShardSegment::OpenPy)
     .def("get_category_fields",
@@ -72,7 +72,8 @@ enum MSRStatus {
   ILLEGAL_PARAMETERS,
   GET_PAGE_BY_GROUP_ID_FAILED,
   GET_SYSTEM_STATE_FAILED,
-  IO_FAILED
+  IO_FAILED,
+  MATCH_HEADER_FAILED
 };
 // convert error no to string message
@@ -35,10 +35,11 @@ class ShardHeader {
  public:
   ShardHeader();
-  MSRStatus Build(const std::string &file_path);
   ~ShardHeader() = default;
+  MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true);
+  static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path);
   /// \brief add the schema and save it
   /// \param[in] schema the schema needs to be added
   /// \return the last schema's id
@@ -126,7 +127,7 @@ class ShardHeader {
   MSRStatus FileToPages(const std::string dump_file_name);
  private:
-  MSRStatus InitializeHeader(const std::vector<json> &headers);
+  MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset);
   /// \brief get the headers from all the shard data
   /// \param[in] the shard data real path
@@ -137,9 +138,9 @@ class ShardHeader {
   MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
   /// \brief check the binary file status
-  MSRStatus CheckFileStatus(const std::string &path);
+  static MSRStatus CheckFileStatus(const std::string &path);
-  std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
+  static std::pair<MSRStatus, json> ValidateHeader(const std::string &path);
   void ParseHeader(const json &header);
@@ -149,7 +150,7 @@ class ShardHeader {
   MSRStatus CheckIndexField(const std::string &field, const json &schema);
-  void ParsePage(const json &page);
+  void ParsePage(const json &page, int shard_index, bool load_dataset);
   MSRStatus ParseStatistics(const json &statistics);
@@ -68,23 +68,25 @@ class ShardReader {
   virtual ~ShardReader();
   /// \brief open files and initialize reader, c++ API
-  /// \param[in] file_path the path of ONE file, any file in dataset is fine
+  /// \param[in] file_paths ONE file path (any file in the dataset, expanded via its header) or an explicit file list
+  /// \param[in] load_dataset true to expand a single file into the whole dataset, false to read only the listed files
   /// \param[in] n_consumer number of threads when reading
   /// \param[in] selected_columns column list to be populated
   /// \param[in] operators operators applied to data, operator type is shuffle, sample or category
   /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode
   /// \return MSRStatus the status of MSRStatus
-  MSRStatus Open(const std::string &file_path, int n_consumer = 4,
+  MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
                  const std::vector<std::string> &selected_columns = {},
                  const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const bool &block_reader = false);
   /// \brief open files and initialize reader, python API
-  /// \param[in] file_path the path of ONE file, any file in dataset is fine
+  /// \param[in] file_paths ONE file path (any file in the dataset, expanded via its header) or an explicit file list
+  /// \param[in] load_dataset true to expand a single file into the whole dataset, false to read only the listed files
   /// \param[in] n_consumer number of threads when reading
   /// \param[in] selected_columns column list to be populated
   /// \param[in] operators operators applied to data, operator type is shuffle, sample or category
   /// \return MSRStatus the status of MSRStatus
-  MSRStatus OpenPy(const std::string &file_path, const int &n_consumer = 4,
+  MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4,
                    const std::vector<std::string> &selected_columns = {},
                    const std::vector<std::shared_ptr<ShardOperator>> &operators = {});
@@ -114,11 +116,13 @@ class ShardReader {
   int GetShardCount() const;
   /// \brief get the number of rows in database
-  /// \param[in] file_path the path of ONE file, any file in dataset is fine
+  /// \param[in] file_paths ONE file path (any file in the dataset, expanded via its header) or an explicit file list
+  /// \param[in] load_dataset true to expand a single file into the whole dataset, false to read only the listed files
   /// \param[in] op smart pointer refer to ShardCategory or ShardSample object
   /// \param[out] count # of rows
   /// \return MSRStatus the status of MSRStatus
-  MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, int64_t *count);
+  MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
+                           const std::shared_ptr<ShardOperator> &op, int64_t *count);
   /// \brief shuffle task with incremental seed
   /// \return void
@@ -220,7 +224,7 @@ class ShardReader {
                             std::vector<std::vector<json>> &column_values);
   /// \brief initialize reader
-  MSRStatus Init(const std::string &file_path);
+  MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
   /// \brief validate column list
   MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns);
@@ -292,8 +296,9 @@ class ShardReader {
   void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
   /// \brief get number of classes
-  int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
+  int64_t GetNumClasses(const std::string &category_field);
+  std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
   /// \brief get exactly blob fields data by indices
   std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
                                                        std::vector<uint32_t> &ordered_selected_columns_index);
@@ -36,9 +36,23 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
       write_success_(true) {}
 MSRStatus ShardIndexGenerator::Build() {
+  auto ret = ShardHeader::BuildSingleHeader(file_path_);
+  if (ret.first != SUCCESS) {
+    return FAILED;
+  }
+  auto json_header = ret.second;
+  auto ret2 = GetParentDir(file_path_);
+  if (SUCCESS != ret2.first) {
+    return FAILED;
+  }
+  std::vector<std::string> real_addresses;
+  for (const auto &path : json_header["shard_addresses"]) {
+    std::string abs_path = ret2.second + string(path);
+    real_addresses.emplace_back(abs_path);
+  }
   ShardHeader header = ShardHeader();
-  if (header.Build(file_path_) != SUCCESS) {
-    MS_LOG(ERROR) << "Build shard schema failed.";
+  if (header.BuildDataset(real_addresses) == FAILED) {
     return FAILED;
   }
   shard_header_ = header;
@@ -47,20 +47,55 @@ ShardReader::ShardReader() {
   block_reader_ = false;
 }
-MSRStatus ShardReader::Init(const std::string &file_path) {
+std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
   if (!IsLegalFile(file_path)) {
+    return {FAILED, {}};
+  }
+  auto ret = ShardHeader::BuildSingleHeader(file_path);
+  if (ret.first != SUCCESS) {
+    return {FAILED, {}};
+  }
+  auto header = ret.second;
+  meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
+               {"version", header["version"]},         {"index_fields", header["index_fields"]},
+               {"schema", header["schema"]},           {"blob_fields", header["blob_fields"]}};
+  return {SUCCESS, header["shard_addresses"]};
+}
+MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
+  std::string file_path = file_paths[0];
+  json first_meta_data = json();
+  auto ret = GetMeta(file_path, first_meta_data);
+  if (ret.first != SUCCESS) {
     return FAILED;
   }
-  ShardHeader sh = ShardHeader();
-  if (sh.Build(file_path) == FAILED) {
+  if (file_paths.size() == 1 && load_dataset == true) {
+    auto ret2 = GetParentDir(file_path);
+    if (SUCCESS != ret2.first) {
+      return FAILED;
+    }
+    std::vector<std::string> real_addresses;
+    for (const auto &path : ret.second) {
+      std::string abs_path = ret2.second + string(path);
+      real_addresses.emplace_back(abs_path);
+    }
+    file_paths_ = real_addresses;
+  } else if (file_paths.size() >= 1 && load_dataset == false) {
+    file_paths_ = file_paths;
+  } else {
+    MS_LOG(ERROR) << "Error in parameter file_path or load_dataset.";
     return FAILED;
   }
-  shard_header_ = std::make_shared<ShardHeader>(sh);
-  header_size_ = shard_header_->GetHeaderSize();
-  page_size_ = shard_header_->GetPageSize();
-  file_paths_ = shard_header_->GetShardAddresses();
   for (const auto &file : file_paths_) {
+    json meta_data = json();
+    auto ret1 = GetMeta(file, meta_data);
+    if (ret1.first != SUCCESS) {
+      return FAILED;
+    }
+    if (meta_data != first_meta_data) {
+      MS_LOG(ERROR) << "Mindrecord files meta information is different.";
+      return FAILED;
+    }
     sqlite3 *db = nullptr;
     // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
     int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
@@ -91,7 +126,13 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
     }
     database_paths_.push_back(db);
   }
+  ShardHeader sh = ShardHeader();
+  if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) {
+    return FAILED;
+  }
+  shard_header_ = std::make_shared<ShardHeader>(sh);
+  header_size_ = shard_header_->GetHeaderSize();
+  page_size_ = shard_header_->GetPageSize();
   num_rows_ = 0;
   auto row_group_summary = ReadRowGroupSummary();
   for (const auto &rg : row_group_summary) {
@@ -248,7 +289,6 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
       fs->close();
       return FAILED;
     }
-
     json label_json = json::from_msgpack(label_raw);
     json tmp;
     if (!columns.empty()) {
@@ -713,15 +753,9 @@ MSRStatus ShardReader::Finish() {
   return SUCCESS;
 }
-int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) {
-  ShardHeader sh = ShardHeader();
-  if (sh.Build(file_path) == FAILED) {
-    return -1;
-  }
-  auto header = std::make_shared<ShardHeader>(sh);
-  auto file_paths = header->GetShardAddresses();
-  auto shard_count = file_paths.size();
-  auto index_fields = header->GetFields();
+int64_t ShardReader::GetNumClasses(const std::string &category_field) {
+  auto shard_count = file_paths_.size();
+  auto index_fields = shard_header_->GetFields();
   std::map<std::string, int64_t> map_schema_id_fields;
   for (auto &field : index_fields) {
@@ -742,7 +776,7 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
   std::set<std::string> categories;
   for (int x = 0; x < shard_count; x++) {
     sqlite3 *db = nullptr;
-    int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
+    int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
     if (SQLITE_OK != rc) {
       MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db);
       return -1;
@@ -756,16 +790,16 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
   return categories.size();
 }
-MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op,
-                                      int64_t *count) {
-  if (Init(file_path) == FAILED) {
+MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
+                                      const std::shared_ptr<ShardOperator> &op, int64_t *count) {
+  if (SUCCESS != Init(file_paths, load_dataset)) {
     return FAILED;
   }
   int64_t num_samples = num_rows_;
   if (std::dynamic_pointer_cast<ShardCategory>(op)) {
     auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
     std::string category_field = category_op->GetCategoryField();
-    auto num_classes = GetNumClasses(file_path, category_field);
+    auto num_classes = GetNumClasses(category_field);
     num_samples = category_op->GetNumSamples(num_rows_, num_classes);
   } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
     num_samples = op->GetNumSamples(num_rows_, 0);
@@ -779,12 +813,13 @@ MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::s
   return SUCCESS;
 }
-MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
+MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
                             const std::vector<std::string> &selected_columns,
                             const std::vector<std::shared_ptr<ShardOperator>> &operators, const bool &block_reader) {
   // Open file and set header by ShardReader
-  if (Init(file_path) == FAILED) {
-    return FAILED;
+  auto ret = Init(file_paths, load_dataset);
+  if (SUCCESS != ret) {
+    return ret;
   }
   auto thread_limit = GetMaxThreadNum();
   if (n_consumer > thread_limit) {
@@ -837,11 +872,11 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
   return SUCCESS;
 }
-MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consumer,
+MSRStatus ShardReader::OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer,
                               const std::vector<std::string> &selected_columns,
                               const std::vector<std::shared_ptr<ShardOperator>> &operators) {
   // Open file and set header by ShardReader
-  if (Init(file_path) == FAILED) {
+  if (SUCCESS != Init(file_paths, load_dataset)) {
     return FAILED;
   }
   // should remove blob field from selected_columns when call from python
@@ -174,12 +174,25 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
   if (!IsLegalFile(path)) {
     return FAILED;
   }
-  ShardHeader sh = ShardHeader();
-  if (sh.Build(path) == FAILED) {
+  auto ret1 = ShardHeader::BuildSingleHeader(path);
+  if (ret1.first != SUCCESS) {
     return FAILED;
   }
-  shard_header_ = std::make_shared<ShardHeader>(sh);
-  auto paths = shard_header_->GetShardAddresses();
+  auto json_header = ret1.second;
+  auto ret2 = GetParentDir(path);
+  if (SUCCESS != ret2.first) {
+    return FAILED;
+  }
+  std::vector<std::string> real_addresses;
+  for (const auto &path : json_header["shard_addresses"]) {
+    std::string abs_path = ret2.second + string(path);
+    real_addresses.emplace_back(abs_path);
+  }
+  ShardHeader header = ShardHeader();
+  if (header.BuildDataset(real_addresses) == FAILED) {
+    return FAILED;
+  }
+  shard_header_ = std::make_shared<ShardHeader>(header);
   MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
   if (ret == FAILED) {
     return FAILED;
@@ -188,7 +201,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
   if (ret == FAILED) {
     return FAILED;
   }
-  ret = Open(paths, true);
+  ret = Open(json_header["shard_addresses"], true);
   if (ret == FAILED) {
     MS_LOG(ERROR) << "Open file failed";
     return FAILED;
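Editor's note, not part of the patch: how this append path is reached from the public Python API, sketched with a placeholder file name and a schematic sample dict (field names depend on the file's schema).

```python
from mindspore.mindrecord import FileWriter

# OpenForAppend rebuilds the header from one file via BuildSingleHeader/
# BuildDataset, then reopens every shard listed under "shard_addresses".
writer = FileWriter.open_for_append("/data/cv.mindrecord0")
writer.write_raw_data([{"file_name": "extra.jpg", "label": 1, "data": b"\xff"}])
writer.commit()
```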
@@ -35,8 +35,9 @@ namespace mindrecord {
 std::atomic<bool> thread_status(false);
 ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
-MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
+MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
   shard_count_ = headers.size();
+  int shard_index = 0;
   bool first = true;
   for (const auto &header : headers) {
     if (first) {
@@ -54,7 +55,8 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers) {
       header_size_ = header["header_size"].get<uint64_t>();
       page_size_ = header["page_size"].get<uint64_t>();
     }
-    ParsePage(header["page"]);
+    ParsePage(header["page"], shard_index, load_dataset);
+    shard_index++;
   }
   return SUCCESS;
 }
@@ -136,40 +138,39 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path)
   return {SUCCESS, json_header};
 }
-MSRStatus ShardHeader::Build(const std::string &file_path) {
+std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) {
   auto ret = ValidateHeader(file_path);
   if (SUCCESS != ret.first) {
-    return FAILED;
-  }
-  json main_header = ret.second;
-  json addresses = main_header["shard_addresses"];
-  vector<string> real_addresses;
-  auto ret1 = GetParentDir(file_path);
-  if (SUCCESS != ret1.first) {
-    return FAILED;
+    return {FAILED, json()};
   }
-  std::string parent_dir = ret1.second;
+  json raw_header = ret.second;
+  json header = {{"shard_addresses", raw_header["shard_addresses"]},
+                 {"header_size", raw_header["header_size"]},
+                 {"page_size", raw_header["page_size"]},
+                 {"index_fields", raw_header["index_fields"]},
+                 {"blob_fields", raw_header["schema"][0]["blob_fields"]},
+                 {"schema", raw_header["schema"][0]["schema"]},
+                 {"version", raw_header["version"]}};
+  return {SUCCESS, header};
+}
-  for (const auto &addr : addresses) {
-    std::string absolute_path = parent_dir + string(addr);
-    real_addresses.emplace_back(absolute_path);
-  }
+MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) {
   uint32_t thread_num = std::thread::hardware_concurrency();
   if (thread_num == 0) thread_num = kThreadNumber;
   uint32_t work_thread_num = 0;
-  uint32_t addr_count = real_addresses.size();
-  int group_num = ceil(addr_count * 1.0 / thread_num);
+  uint32_t shard_count = file_paths.size();
+  int group_num = ceil(shard_count * 1.0 / thread_num);
   std::vector<std::thread> thread_set(thread_num);
-  std::vector<json> headers(addr_count);
+  std::vector<json> headers(shard_count);
   for (uint32_t x = 0; x < thread_num; ++x) {
     int start_num = x * group_num;
-    int end_num = ((x + 1) * group_num > addr_count) ? addr_count : (x + 1) * group_num;
+    int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num;
     if (start_num >= end_num) {
       continue;
     }
     thread_set[x] =
-      std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), real_addresses);
+      std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths);
     work_thread_num++;
   }
@@ -180,7 +181,7 @@ MSRStatus ShardHeader::Build(const std::string &file_path) {
     thread_status = false;
     return FAILED;
   }
-  if (SUCCESS != InitializeHeader(headers)) {
+  if (SUCCESS != InitializeHeader(headers, load_dataset)) {
     return FAILED;
   }
   return SUCCESS;
@@ -247,7 +248,8 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
   return SUCCESS;
 }
-void ShardHeader::ParsePage(const json &pages) {
+void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
+  // shard_index is only used when load_dataset is false
   if (pages_.empty() && shard_count_ <= kMaxShardCount) {
     pages_.resize(shard_count_);
   }
@@ -267,7 +269,11 @@ void ShardHeader::ParsePage(const json &pages) {
     std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id, start_row_id,
                                                                end_row_id, row_group_ids, page_size);
-    pages_[shard_id].push_back(std::move(parsed_page));
+    if (load_dataset == true) {
+      pages_[shard_id].push_back(std::move(parsed_page));
+    } else {
+      pages_[shard_index].push_back(std::move(parsed_page));
+    }
   }
 }
@@ -709,7 +715,7 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
   std::string line;
   while (std::getline(page_in_handle, line)) {
-    ParsePage(json::parse(line));
+    ParsePage(json::parse(line), -1, true);
   }
   page_in_handle.close();
@@ -2189,7 +2189,7 @@ class MindDataset(SourceDataset):
     A source dataset that reads from shard files and database.
     Args:
-        dataset_file (str): one of file names in dataset.
+        dataset_file (str, list[str]): One of the dataset files, or a list of dataset files to be read.
         columns_list (list[str], optional): List of columns to be read (default=None).
         num_parallel_workers (int, optional): The number of readers (default=None).
         shuffle (bool, optional): Whether or not to perform shuffle on the dataset
@@ -2214,6 +2214,10 @@ class MindDataset(SourceDataset):
                  shuffle=None, num_shards=None, shard_id=None,
                  block_reader=False, sampler=None):
         super().__init__(num_parallel_workers)
+        if isinstance(dataset_file, list):
+            self.load_dataset = False
+        else:
+            self.load_dataset = True
         self.dataset_file = dataset_file
         self.columns_list = columns_list
         self.global_shuffle = shuffle
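Editor's note, not part of the patch: the resulting user-visible behavior, with placeholder paths.

```python
import mindspore.dataset as ds

# A list input flips load_dataset to False, so exactly these shard files are
# read; a plain string keeps the old expand-from-header behavior.
data_set = ds.MindDataset(dataset_file=["/data/cv.mindrecord0", "/data/cv.mindrecord1"])
print(data_set.get_dataset_size())  # routes through MindRecordOp.get_num_rows below
```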
@@ -2256,6 +2260,7 @@ class MindDataset(SourceDataset):
     def get_args(self):
         args = super().get_args()
         args["dataset_file"] = self.dataset_file
+        args["load_dataset"] = self.load_dataset
         args["columns_list"] = self.columns_list
         args["global_shuffle"] = self.global_shuffle
         args["partitions"] = self.partitions
@@ -2272,8 +2277,11 @@ class MindDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler)
+        if self.load_dataset:
+            dataset_file = [self.dataset_file]
+        else:
+            dataset_file = self.dataset_file
+        num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler)
         if self.partitions is not None and self.partitions[0] > 0:
             if num_rows % self.partitions[0] == 0:
                 num_rows = num_rows // self.partitions[0]
@@ -529,8 +529,11 @@ def check_minddataset(method):
         dataset_file = param_dict.get('dataset_file')
         if dataset_file is None:
             raise ValueError("dataset_file is not provided.")
-        check_dataset_file(dataset_file)
+        if isinstance(dataset_file, list):
+            for f in dataset_file:
+                check_dataset_file(f)
+        else:
+            check_dataset_file(dataset_file)
         check_param_type(nreq_param_int, param_dict, int)
         check_param_type(nreq_param_list, param_dict, list)
@@ -28,7 +28,7 @@ class FileReader:
     Class to read MindRecord File series.
     Args:
-        file_name (str): File name of MindRecord File.
+        file_name (str, list[str]): One of the MindRecord files, or a list of MindRecord files.
         num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
             It should not be smaller than 1 or larger than the number of CPU.
         columns (list[str], optional): List of fields which correspond data would be read (default=None).
@@ -38,8 +38,11 @@ class FileReader:
         ParamValueError: If file_name, num_consumer or columns is invalid.
     """
     def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
-        check_filename(file_name)
-        self._file_name = file_name
+        if isinstance(file_name, list):
+            for f in file_name:
+                check_filename(f)
+        else:
+            check_filename(file_name)
         if num_consumer is not None:
             if isinstance(num_consumer, int):
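Editor's note, not part of the patch: FileReader usage with the new list form (paths are placeholders).

```python
from mindspore.mindrecord import FileReader

# Read an explicit list of shard files instead of one representative file.
reader = FileReader(file_name=["/data/cv.mindrecord0", "/data/cv.mindrecord1"], num_consumer=4)
for item in reader.get_next():
    print(item)  # each item is a dict mapping column name to value
reader.close()
```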
@@ -28,7 +28,7 @@ class MindPage:
     Class to read MindRecord File series in pagination.
     Args:
-        file_name (str): File name of MindRecord File.
+        file_name (str, list[str]): One of the MindRecord files, or a list of MindRecord files.
         num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
             It should not be smaller than 1 or larger than the number of CPU.
@@ -37,8 +37,11 @@ class MindPage:
         MRMInitSegmentError: If failed to initialize ShardSegment.
     """
     def __init__(self, file_name, num_consumer=4):
-        check_filename(file_name)
-        self._file_name = file_name
+        if isinstance(file_name, list):
+            for f in file_name:
+                check_filename(f)
+        else:
+            check_filename(file_name)
         if num_consumer is not None:
             if isinstance(num_consumer, int):
@@ -35,7 +35,7 @@ class ShardReader:
         Open file and prepare to read MindRecord File.
         Args:
-            file_name (str): File name of MindRecord File.
+            file_name (str, list[str]): One MindRecord file name, or a list of file names.
            num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
            columns (list[str]): List of fields which correspond data would be read.
            operator(int): Reserved parameter for operators. Default: None.
@@ -48,7 +48,12 @@ class ShardReader:
         """
         columns = columns if columns else []
         operator = operator if operator else []
-        ret = self._reader.open(file_name, num_consumer, columns, operator)
+        if isinstance(file_name, list):
+            load_dataset = False
+        else:
+            load_dataset = True
+            file_name = [file_name]
+        ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
         if ret != ms.MSRStatus.SUCCESS:
             logger.error("Failed to open {}.".format(file_name))
             raise MRMOpenError
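Editor's note, not part of the patch: the normalization rule applied above, restated in isolation so the contract with the C++ binding is explicit (`_normalize` is a hypothetical helper, not code from this change).

```python
def _normalize(file_name):
    """Mirror of ShardReader.open's input handling."""
    if isinstance(file_name, list):
        return file_name, False   # explicit list: load exactly these shards
    return [file_name], True      # single name: expand via the file's header

paths, load_dataset = _normalize("/data/cv.mindrecord0")
assert paths == ["/data/cv.mindrecord0"] and load_dataset is True
```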
@@ -40,7 +40,7 @@ class ShardSegment:
         Initialize the ShardSegment.
         Args:
-            file_name (str): File name of MindRecord File.
+            file_name (str, list[str]): One MindRecord file name, or a list of file names.
            num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
            columns (list[str]): List of fields which correspond data would be read.
            operator(int): Reserved parameter for operators. Default: None.
@@ -53,7 +53,12 @@ class ShardSegment:
         """
         self._columns = columns if columns else []
         operator = operator if operator else []
-        ret = self._segment.open(file_name, num_consumer, self._columns, operator)
+        if isinstance(file_name, list):
+            load_dataset = False
+        else:
+            load_dataset = True
+            file_name = [file_name]
+        ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator)
         if ret != SUCCESS:
             logger.error("Failed to open {}.".format(file_name))
             raise MRMOpenError
@@ -62,7 +62,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBasic) {
   std::shared_ptr<MindRecordOp> my_mindrecord_op;
   MindRecordOp::Builder builder;
-  builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
+  builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
+    .SetLoadDataset(true)
     .SetRowsPerBuffer(3)
     .SetNumMindRecordWorkers(4)
     .SetColumnsToLoad(column_list);
@@ -132,7 +133,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordSample) {
   std::shared_ptr<MindRecordOp> my_mindrecord_op;
   MindRecordOp::Builder builder;
-  builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0")
+  builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"})
+    .SetLoadDataset(true)
     .SetRowsPerBuffer(3)
     .SetNumMindRecordWorkers(4)
     .SetColumnsToLoad(column_list)
| @@ -203,7 +205,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordShuffle) { | |||||
| std::shared_ptr<MindRecordOp> my_mindrecord_op; | std::shared_ptr<MindRecordOp> my_mindrecord_op; | ||||
| MindRecordOp::Builder builder; | MindRecordOp::Builder builder; | ||||
| builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") | |||||
| builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) | |||||
| .SetLoadDataset(true) | |||||
| .SetRowsPerBuffer(3) | .SetRowsPerBuffer(3) | ||||
| .SetNumMindRecordWorkers(4) | .SetNumMindRecordWorkers(4) | ||||
| .SetColumnsToLoad(column_list) | .SetColumnsToLoad(column_list) | ||||
| @@ -277,7 +280,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordCategory) { | |||||
| std::shared_ptr<MindRecordOp> my_mindrecord_op; | std::shared_ptr<MindRecordOp> my_mindrecord_op; | ||||
| MindRecordOp::Builder builder; | MindRecordOp::Builder builder; | ||||
| builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") | |||||
| builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) | |||||
| .SetLoadDataset(true) | |||||
| .SetRowsPerBuffer(3) | .SetRowsPerBuffer(3) | ||||
| .SetNumMindRecordWorkers(4) | .SetNumMindRecordWorkers(4) | ||||
| .SetColumnsToLoad(column_list) | .SetColumnsToLoad(column_list) | ||||
| @@ -345,7 +349,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordRepeat) { | |||||
| std::shared_ptr<MindRecordOp> my_mindrecord_op; | std::shared_ptr<MindRecordOp> my_mindrecord_op; | ||||
| MindRecordOp::Builder builder; | MindRecordOp::Builder builder; | ||||
| builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") | |||||
| builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) | |||||
| .SetLoadDataset(true) | |||||
| .SetRowsPerBuffer(3) | .SetRowsPerBuffer(3) | ||||
| .SetNumMindRecordWorkers(4) | .SetNumMindRecordWorkers(4) | ||||
| .SetColumnsToLoad(column_list); | .SetColumnsToLoad(column_list); | ||||
| @@ -426,7 +431,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordBlockReaderRepeat) { | |||||
| std::shared_ptr<MindRecordOp> my_mindrecord_op; | std::shared_ptr<MindRecordOp> my_mindrecord_op; | ||||
| MindRecordOp::Builder builder; | MindRecordOp::Builder builder; | ||||
| builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") | |||||
| builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) | |||||
| .SetLoadDataset(true) | |||||
| .SetRowsPerBuffer(3) | .SetRowsPerBuffer(3) | ||||
| .SetNumMindRecordWorkers(4) | .SetNumMindRecordWorkers(4) | ||||
| .SetBlockReader() | .SetBlockReader() | ||||
| @@ -507,7 +513,8 @@ TEST_F(MindDataTestMindRecordOp, TestMindRecordInvalidColumnList) { | |||||
| std::shared_ptr<MindRecordOp> my_mindrecord_op; | std::shared_ptr<MindRecordOp> my_mindrecord_op; | ||||
| MindRecordOp::Builder builder; | MindRecordOp::Builder builder; | ||||
| builder.SetDatasetFile(mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0") | |||||
| builder.SetDatasetFile({mindrecord_root_path_ + "/testMindDataSet/testImageNetData/imagenet.mindrecord0"}) | |||||
| .SetLoadDataset(true) | |||||
| .SetRowsPerBuffer(3) | .SetRowsPerBuffer(3) | ||||
| .SetNumMindRecordWorkers(4) | .SetNumMindRecordWorkers(4) | ||||
| .SetColumnsToLoad(column_list); | .SetColumnsToLoad(column_list); | ||||
| @@ -63,7 +63,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { | |||||
| std::vector<std::shared_ptr<ShardOperator>> ops; | std::vector<std::shared_ptr<ShardOperator>> ops; | ||||
| ops.push_back(std::make_shared<ShardSample>(kSampleCount)); | ops.push_back(std::make_shared<ShardSample>(kSampleCount)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -89,7 +89,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) { | |||||
| ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); | ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -115,7 +115,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { | |||||
| ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); | ops.push_back(std::make_shared<ShardSample>(kNum, kDen)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -144,7 +144,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { | |||||
| ASSERT_TRUE(partitions.second == 2); | ASSERT_TRUE(partitions.second == 2); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -168,7 +168,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { | |||||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2)); | ops.push_back(std::make_shared<ShardPkSample>("label", 2)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -193,7 +193,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { | |||||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0)); | ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -223,7 +223,7 @@ TEST_F(TestShardOperator, TestShardCategory) { | |||||
| ops.push_back(std::make_shared<ShardCategory>(categories)); | ops.push_back(std::make_shared<ShardCategory>(categories)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -254,7 +254,7 @@ TEST_F(TestShardOperator, TestShardShuffle) { | |||||
| ops.push_back(std::make_shared<ShardShuffle>(1)); | ops.push_back(std::make_shared<ShardShuffle>(1)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 16, column_list, ops); | |||||
| dataset.Open({file_name}, true, 16, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -279,7 +279,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { | |||||
| ops.push_back(std::make_shared<ShardShuffle>(1)); | ops.push_back(std::make_shared<ShardShuffle>(1)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -306,7 +306,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { | |||||
| ops.push_back(std::make_shared<ShardSample>(kSampleSize)); | ops.push_back(std::make_shared<ShardSample>(kSampleSize)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -333,7 +333,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { | |||||
| ops.push_back(std::make_shared<ShardSample>(35)); | ops.push_back(std::make_shared<ShardSample>(35)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -357,11 +357,11 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { | |||||
| ops.push_back(std::make_shared<ShardShuffle>(1)); | ops.push_back(std::make_shared<ShardShuffle>(1)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| ShardReader compare_dataset; | ShardReader compare_dataset; | ||||
| compare_dataset.Open(file_name, 4, column_list); | |||||
| compare_dataset.Open({file_name}, true, 4, column_list); | |||||
| compare_dataset.Launch(); | compare_dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -396,7 +396,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { | |||||
| ops.push_back(std::make_shared<ShardShuffle>(21)); | ops.push_back(std::make_shared<ShardShuffle>(21)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -430,7 +430,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { | |||||
| ops.push_back(std::make_shared<ShardCategory>(categories)); | ops.push_back(std::make_shared<ShardCategory>(categories)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -464,7 +464,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) { | |||||
| ops.push_back(std::make_shared<ShardCategory>(categories)); | ops.push_back(std::make_shared<ShardCategory>(categories)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -502,7 +502,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { | |||||
| ops.push_back(std::make_shared<ShardShuffle>(100)); | ops.push_back(std::make_shared<ShardShuffle>(100)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int i = 0; | int i = 0; | ||||
| @@ -55,7 +55,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { | |||||
| auto column_list = std::vector<std::string>{"file_name"}; | auto column_list = std::vector<std::string>{"file_name"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list); | |||||
| dataset.Open({file_name}, true, 4, column_list); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -78,7 +78,7 @@ TEST_F(TestShardReader, TestShardReaderSample) { | |||||
| std::vector<std::shared_ptr<ShardOperator>> ops; | std::vector<std::shared_ptr<ShardOperator>> ops; | ||||
| ops.push_back(std::make_shared<ShardSample>(17)); | ops.push_back(std::make_shared<ShardSample>(17)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -103,7 +103,7 @@ TEST_F(TestShardReader, TestShardReaderBlock) { | |||||
| ops.push_back(std::make_shared<ShardSample>(3)); | ops.push_back(std::make_shared<ShardSample>(3)); | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| const bool kBlockReader = true; | const bool kBlockReader = true; | ||||
| dataset.Open(file_name, 4, column_list, ops, kBlockReader); | |||||
| dataset.Open({file_name}, true, 4, column_list, ops, kBlockReader); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -123,7 +123,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) { | |||||
| MS_LOG(INFO) << FormatInfo("Test read imageNet"); | MS_LOG(INFO) << FormatInfo("Test read imageNet"); | ||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name); | |||||
| dataset.Open({file_name}, true); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -143,7 +143,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label"}; | auto column_list = std::vector<std::string>{"label"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(file_name, 4, column_list); | |||||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| dataset.Launch(); | dataset.Launch(); | ||||
| @@ -164,7 +164,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"file_namex"}; | auto column_list = std::vector<std::string>{"file_namex"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(file_name, 4, column_list); | |||||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST); | ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST); | ||||
| } | } | ||||
| @@ -172,7 +172,7 @@ TEST_F(TestShardReader, TestShardVersion) { | |||||
| MS_LOG(INFO) << FormatInfo("Test shard version"); | MS_LOG(INFO) << FormatInfo("Test shard version"); | ||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(file_name, 4); | |||||
| MSRStatus ret = dataset.Open({file_name}, true, 4); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| dataset.Launch(); | dataset.Launch(); | ||||
| @@ -195,7 +195,7 @@ TEST_F(TestShardReader, TestShardReaderDir) { | |||||
| auto column_list = std::vector<std::string>{"file_name"}; | auto column_list = std::vector<std::string>{"file_name"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(file_name, 4, column_list); | |||||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, FAILED); | ASSERT_EQ(ret, FAILED); | ||||
| } | } | ||||
| @@ -205,7 +205,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { | |||||
| auto column_list = std::vector<std::string>{"file_name"}; | auto column_list = std::vector<std::string>{"file_name"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| dataset.Open(file_name, -481565535, column_list); | |||||
| dataset.Open({file_name}, true, -481565535, column_list); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -59,7 +59,7 @@ TEST_F(TestShardSegment, TestShardSegment) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open(file_name, 4); | |||||
| dataset.Open({file_name}, true, 4); | |||||
| auto x = dataset.GetCategoryFields(); | auto x = dataset.GetCategoryFields(); | ||||
| for (const auto &fields : x.second) { | for (const auto &fields : x.second) { | ||||
| @@ -97,7 +97,7 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open(file_name, 4); | |||||
| dataset.Open({file_name}, true, 4); | |||||
| auto x = dataset.GetCategoryFields(); | auto x = dataset.GetCategoryFields(); | ||||
| for (const auto &fields : x.second) { | for (const auto &fields : x.second) { | ||||
| @@ -121,7 +121,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open(file_name, 4); | |||||
| dataset.Open({file_name}, true, 4); | |||||
| auto x = dataset.GetCategoryFields(); | auto x = dataset.GetCategoryFields(); | ||||
| for (const auto &fields : x.second) { | for (const auto &fields : x.second) { | ||||
| @@ -143,7 +143,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open(file_name, 4); | |||||
| dataset.Open({file_name}, true, 4); | |||||
| auto x = dataset.GetCategoryFields(); | auto x = dataset.GetCategoryFields(); | ||||
| for (const auto &fields : x.second) { | for (const auto &fields : x.second) { | ||||
| @@ -165,7 +165,7 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open(file_name, 4); | |||||
| dataset.Open({file_name}, true, 4); | |||||
| auto x = dataset.GetCategoryFields(); | auto x = dataset.GetCategoryFields(); | ||||
| for (const auto &fields : x.second) { | for (const auto &fields : x.second) { | ||||
| @@ -60,7 +60,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { | |||||
| std::string filename = "./OneSample.shard01"; | std::string filename = "./OneSample.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(filename, 4); | |||||
| MSRStatus ret = dataset.Open({filename}, true, 4); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| dataset.Launch(); | dataset.Launch(); | ||||
| @@ -756,7 +756,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { | |||||
| filename = "./imagenet.shard01"; | filename = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label", "file_name", "data"}; | auto column_list = std::vector<std::string>{"label", "file_name", "data"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(filename, 4, column_list); | |||||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| dataset.Launch(); | dataset.Launch(); | ||||
| @@ -842,7 +842,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) { | |||||
| filename = "./imagenet.shard01"; | filename = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label", "file_name"}; | auto column_list = std::vector<std::string>{"label", "file_name"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(filename, 4, column_list); | |||||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| dataset.Launch(); | dataset.Launch(); | ||||
| @@ -936,7 +936,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { | |||||
| filename = "./imagenet.shard01"; | filename = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label", "data"}; | auto column_list = std::vector<std::string>{"label", "data"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(filename, 4, column_list); | |||||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| dataset.Launch(); | dataset.Launch(); | ||||
| @@ -1043,7 +1043,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { | |||||
| filename = "./TenSampleFortyShard.shard01"; | filename = "./TenSampleFortyShard.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open(filename, 4); | |||||
| MSRStatus ret = dataset.Open({filename}, true, 4); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| dataset.Launch(); | dataset.Launch(); | ||||
| @@ -32,6 +32,8 @@ from mindspore.mindrecord import FileWriter | |||||
| FILES_NUM = 4 | FILES_NUM = 4 | ||||
| CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" | CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" | ||||
| CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord" | |||||
| CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord" | |||||
| CV_DIR_NAME = "../data/mindrecord/testImageNetData" | CV_DIR_NAME = "../data/mindrecord/testImageNetData" | ||||
| NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" | NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord" | ||||
| NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" | NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos" | ||||
| @@ -111,7 +113,6 @@ def test_cv_minddataset_writer_tutorial(): | |||||
| os.remove("{}".format(x)) | os.remove("{}".format(x)) | ||||
| os.remove("{}.db".format(x)) | os.remove("{}.db".format(x)) | ||||
| def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): | def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): | ||||
| """tutorial for cv minddataset.""" | """tutorial for cv minddataset.""" | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| @@ -247,6 +248,126 @@ def test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_rem | |||||
| assert num_iter == 20 | assert num_iter == 20 | ||||
| def test_cv_minddataset_reader_file_list(add_and_remove_cv_file): | |||||
| """tutorial for cv minderdataset.""" | |||||
| columns_list = ["data", "file_name", "label"] | |||||
| num_readers = 4 | |||||
| data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)], columns_list, num_readers) | |||||
| assert data_set.get_dataset_size() == 10 | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | |||||
| assert num_iter == 10 | |||||
| def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file): | |||||
| """tutorial for cv minderdataset.""" | |||||
| columns_list = ["data", "file_name", "label"] | |||||
| num_readers = 4 | |||||
| data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers) | |||||
| assert data_set.get_dataset_size() < 10 | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | |||||
| assert num_iter < 10 | |||||
| def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file): | |||||
| """tutorial for cv minderdataset.""" | |||||
| if os.path.exists(CV1_FILE_NAME): | |||||
| os.remove(CV1_FILE_NAME) | |||||
| if os.path.exists("{}.db".format(CV1_FILE_NAME)): | |||||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||||
| if os.path.exists(CV2_FILE_NAME): | |||||
| os.remove(CV2_FILE_NAME) | |||||
| if os.path.exists("{}.db".format(CV2_FILE_NAME)): | |||||
| os.remove("{}.db".format(CV2_FILE_NAME)) | |||||
| writer = FileWriter(CV1_FILE_NAME, 1) | |||||
| data = get_data(CV_DIR_NAME) | |||||
| cv_schema_json = {"id": {"type": "int32"}, | |||||
| "file_name": {"type": "string"}, | |||||
| "label": {"type": "int32"}, | |||||
| "data": {"type": "bytes"}} | |||||
| writer.add_schema(cv_schema_json, "CV1_schema") | |||||
| writer.add_index(["file_name", "label"]) | |||||
| writer.write_raw_data(data) | |||||
| writer.commit() | |||||
| writer = FileWriter(CV2_FILE_NAME, 1) | |||||
| data = get_data(CV_DIR_NAME) | |||||
| cv_schema_json = {"id": {"type": "int32"}, | |||||
| "file_name": {"type": "string"}, | |||||
| "label": {"type": "int32"}, | |||||
| "data": {"type": "bytes"}} | |||||
| writer.add_schema(cv_schema_json, "CV2_schema") | |||||
| writer.add_index(["file_name", "label"]) | |||||
| writer.write_raw_data(data) | |||||
| writer.commit() | |||||
| columns_list = ["data", "file_name", "label"] | |||||
| num_readers = 4 | |||||
| data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME], columns_list, num_readers) | |||||
| assert data_set.get_dataset_size() == 30 | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | |||||
| assert num_iter == 30 | |||||
| if os.path.exists(CV1_FILE_NAME): | |||||
| os.remove(CV1_FILE_NAME) | |||||
| if os.path.exists("{}.db".format(CV1_FILE_NAME)): | |||||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||||
| if os.path.exists(CV2_FILE_NAME): | |||||
| os.remove(CV2_FILE_NAME) | |||||
| if os.path.exists("{}.db".format(CV2_FILE_NAME)): | |||||
| os.remove("{}.db".format(CV2_FILE_NAME)) | |||||
| def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file): | |||||
| """tutorial for cv minddataset: mix shards from two datasets in one list.""" | |||||
| paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0')) | |||||
| for x in range(FILES_NUM)] | |||||
| for x in paths: | |||||
| os.remove("{}".format(x)) if os.path.exists("{}".format(x)) else None | |||||
| os.remove("{}.db".format(x)) if os.path.exists("{}.db".format(x)) else None | |||||
| writer = FileWriter(CV1_FILE_NAME, FILES_NUM) | |||||
| data = get_data(CV_DIR_NAME) | |||||
| cv_schema_json = {"id": {"type": "int32"}, | |||||
| "file_name": {"type": "string"}, | |||||
| "label": {"type": "int32"}, | |||||
| "data": {"type": "bytes"}} | |||||
| writer.add_schema(cv_schema_json, "CV1_schema") | |||||
| writer.add_index(["file_name", "label"]) | |||||
| writer.write_raw_data(data) | |||||
| writer.commit() | |||||
| columns_list = ["data", "file_name", "label"] | |||||
| num_readers = 4 | |||||
| data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] + [CV1_FILE_NAME + str(x) for x in range(2, 4)], columns_list, num_readers) | |||||
| assert data_set.get_dataset_size() < 20 | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"]))) | |||||
| logger.info("-------------- item[data]: {} -----------------------------".format(item["data"])) | |||||
| logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"])) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | |||||
| assert num_iter < 20 | |||||
| for x in paths: | |||||
| os.remove("{}".format(x)) | |||||
| os.remove("{}.db".format(x)) | |||||
| def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): | def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file): | ||||
| """tutorial for cv minderdataset.""" | """tutorial for cv minderdataset.""" | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| @@ -22,6 +22,7 @@ import mindspore.dataset as ds | |||||
| from mindspore.mindrecord import FileWriter | from mindspore.mindrecord import FileWriter | ||||
| CV_FILE_NAME = "./imagenet.mindrecord" | CV_FILE_NAME = "./imagenet.mindrecord" | ||||
| CV1_FILE_NAME = "./imagenet1.mindrecord" | |||||
| def create_cv_mindrecord(files_num): | def create_cv_mindrecord(files_num): | ||||
| @@ -37,6 +38,31 @@ def create_cv_mindrecord(files_num): | |||||
| writer.commit() | writer.commit() | ||||
| def create_diff_schema_cv_mindrecord(files_num): | |||||
| """tutorial for cv dataset writer.""" | |||||
| if os.path.exists(CV1_FILE_NAME): | |||||
| os.remove(CV1_FILE_NAME) | |||||
| if os.path.exists("{}.db".format(CV1_FILE_NAME)): | |||||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||||
| writer = FileWriter(CV1_FILE_NAME, files_num) | |||||
| cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} | |||||
| data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}] | |||||
| writer.add_schema(cv_schema_json, "img_schema") | |||||
| writer.add_index(["file_name_1", "label"]) | |||||
| writer.write_raw_data(data) | |||||
| writer.commit() | |||||
| def create_diff_page_size_cv_mindrecord(files_num): | |||||
| """tutorial for cv dataset writer.""" | |||||
| if os.path.exists(CV1_FILE_NAME): | |||||
| os.remove(CV1_FILE_NAME) | |||||
| if os.path.exists("{}.db".format(CV1_FILE_NAME)): | |||||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||||
| writer = FileWriter(CV1_FILE_NAME, files_num) | |||||
| writer.set_page_size(1 << 26)  # 64MB | |||||
| cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}} | |||||
| data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}] | |||||
| writer.add_schema(cv_schema_json, "img_schema") | |||||
| writer.add_index(["file_name", "label"]) | |||||
| writer.write_raw_data(data) | |||||
| writer.commit() | |||||
| def test_cv_lack_json(): | def test_cv_lack_json(): | ||||
| """tutorial for cv minderdataset.""" | """tutorial for cv minderdataset.""" | ||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| @@ -111,3 +137,34 @@ def test_cv_minddataset_pk_sample_exclusive_shuffle(): | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| def test_cv_minddataset_reader_different_schema(): | |||||
| """reading a file list whose files were written with different schemas should fail.""" | |||||
| create_cv_mindrecord(1) | |||||
| create_diff_schema_cv_mindrecord(1) | |||||
| columns_list = ["data", "label"] | |||||
| num_readers = 4 | |||||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | |||||
| num_readers) | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| num_iter += 1 | |||||
| os.remove(CV_FILE_NAME) | |||||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||||
| os.remove(CV1_FILE_NAME) | |||||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||||
| def test_cv_minddataset_reader_different_page_size(): | |||||
| """reading a file list whose files were written with different page sizes should fail.""" | |||||
| create_cv_mindrecord(1) | |||||
| create_diff_page_size_cv_mindrecord(1) | |||||
| columns_list = ["data", "label"] | |||||
| num_readers = 4 | |||||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | |||||
| num_readers) | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| num_iter += 1 | |||||
| os.remove(CV_FILE_NAME) | |||||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||||
| os.remove(CV1_FILE_NAME) | |||||
| os.remove("{}.db".format(CV1_FILE_NAME)) | |||||
| @@ -202,6 +202,16 @@ def test_cv_file_reader_tutorial(): | |||||
| assert count == 10 | assert count == 10 | ||||
| reader.close() | reader.close() | ||||
| def test_cv_file_reader_file_list(): | |||||
| """tutorial for cv file partial reader.""" | |||||
| reader = FileReader([CV_FILE_NAME + str(x) for x in range(FILES_NUM)]) | |||||
| count = 0 | |||||
| for index, x in enumerate(reader.get_next()): | |||||
| assert len(x) == 3 | |||||
| count = count + 1 | |||||
| logger.info("#item{}: {}".format(index, x)) | |||||
| assert count == 10 | |||||
| reader.close() | |||||
| def test_cv_file_reader_partial_tutorial(): | def test_cv_file_reader_partial_tutorial(): | ||||
| """tutorial for cv file partial reader.""" | """tutorial for cv file partial reader.""" | ||||
| reader = FileReader(CV_FILE_NAME + "0") | reader = FileReader(CV_FILE_NAME + "0") | ||||