From: @cathwong (tags/v1.1.0)
| @@ -2,9 +2,11 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| if (ENABLE_PYTHON) | |||
| add_library(APItoPython OBJECT | |||
| python/de_pipeline.cc | |||
| python/pybind_register.cc | |||
| python/bindings.cc | |||
| python/pybind_conversion.cc | |||
| python/bindings/dataset/include/datasets_bindings.cc | |||
| python/bindings/dataset/include/iterator_bindings.cc | |||
| python/bindings/dataset/include/schema_bindings.cc | |||
| python/bindings/dataset/engine/cache/bindings.cc | |||
| python/bindings/dataset/core/bindings.cc | |||
| python/bindings/dataset/callback/bindings.cc | |||
| @@ -115,7 +115,8 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum | |||
| #ifndef ENABLE_ANDROID | |||
| // Function to return a transferred Node that transfers data through a device. | |||
| bool Dataset::DeviceQueue(bool send_epoch_end) { | |||
| bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t num_epochs, bool send_epoch_end, | |||
| int32_t total_batches, bool create_data_info_queue) { | |||
| Status rc; | |||
| // Build and launch tree | |||
| @@ -126,11 +127,12 @@ bool Dataset::DeviceQueue(bool send_epoch_end) { | |||
| return false; | |||
| } | |||
| // Add TransferNode IR on top of dataset d | |||
| auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end); | |||
| // Add TransferNode IR on top of dataset | |||
| auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, send_epoch_end, | |||
| total_batches, create_data_info_queue); | |||
| // Get ToDevice consumer | |||
| auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1); | |||
| auto consumer = std::make_unique<ToDevice>(num_epochs); | |||
| ToDevice *consumer_ = consumer.get(); | |||
| rc = consumer->Init(ds); | |||
| if (rc.IsError()) { | |||
| @@ -199,127 +201,55 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); } | |||
| int64_t Dataset::GetDatasetSize() { | |||
| int64_t dataset_size; | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->Init(this->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->GetDatasetSize(&dataset_size); | |||
| return rc.IsError() ? -1 : dataset_size; | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1); | |||
| return dataset_size; | |||
| } | |||
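The repeated Init/log/return blocks in each getter are folded into a RETURN_SECOND_IF_ERROR macro. Its definition is not part of this diff; the following is only an assumed sketch of what it likely expands to, shown to make the refactor easier to read:

// Assumed shape of the helper: evaluate the Status expression and, on error,
// log it and return the caller-supplied fallback value.
#define RETURN_SECOND_IF_ERROR(_stmt, _fallback) \
  do {                                           \
    Status __rc = (_stmt);                       \
    if (__rc.IsError()) {                        \
      MS_LOG(ERROR) << __rc;                     \
      return (_fallback);                        \
    }                                            \
  } while (false)

The same pattern repeats below for GetOutputTypes, GetOutputShapes, GetNumClasses, GetColumnNames, GetClassIndexing, GetBatchSize, and GetRepeatCount, with -1, {}, or 0 as the fallback.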
| std::vector<DataType> Dataset::GetOutputTypes() { | |||
| std::vector<DataType> types; | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed."; | |||
| return types; | |||
| } | |||
| rc = tree_getters_->Init(this->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed."; | |||
| return types; | |||
| } | |||
| rc = tree_getters_->GetOutputTypes(&types); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed."; | |||
| types.clear(); | |||
| return types; | |||
| } | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputTypes(&types), {}); | |||
| return types; | |||
| } | |||
| std::vector<TensorShape> Dataset::GetOutputShapes() { | |||
| std::vector<TensorShape> shapes; | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed."; | |||
| return shapes; | |||
| } | |||
| rc = tree_getters_->Init(this->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed."; | |||
| return shapes; | |||
| } | |||
| rc = tree_getters_->GetOutputShapes(&shapes); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed."; | |||
| shapes.clear(); | |||
| return shapes; | |||
| } | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputShapes(&shapes), {}); | |||
| return shapes; | |||
| } | |||
| int64_t Dataset::GetNumClasses() { | |||
| int64_t num_classes; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->Init(ds->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->GetNumClasses(&num_classes); | |||
| return rc.IsError() ? -1 : num_classes; | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetNumClasses(&num_classes), -1); | |||
| return num_classes; | |||
| } | |||
| std::vector<std::string> Dataset::GetColumnNames() { | |||
| std::vector<std::string> col_names; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed."; | |||
| return std::vector<std::string>(); | |||
| } | |||
| rc = tree_getters_->Init(ds->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed."; | |||
| return std::vector<std::string>(); | |||
| } | |||
| rc = tree_getters_->GetColumnNames(&col_names); | |||
| return rc.IsError() ? std::vector<std::string>() : col_names; | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetColumnNames(&col_names), {}); | |||
| return col_names; | |||
| } | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() { | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed."; | |||
| return output_class_indexing; | |||
| } | |||
| rc = tree_getters_->Init(ds->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed."; | |||
| return output_class_indexing; | |||
| } | |||
| rc = tree_getters_->GetClassIndexing(&output_class_indexing); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed."; | |||
| output_class_indexing.clear(); | |||
| return output_class_indexing; | |||
| } | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {}); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetClassIndexing(&output_class_indexing), {}); | |||
| return output_class_indexing; | |||
| } | |||
| @@ -501,9 +431,13 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset( | |||
| std::function<TensorRow(TensorRow)> element_length_function, | |||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, | |||
| bool drop_remainder) { | |||
| auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, | |||
| bucket_batch_sizes, element_length_function, pad_info, | |||
| pad_to_bucket_boundary, drop_remainder); | |||
| std::shared_ptr<TensorOp> c_func = nullptr; | |||
| if (element_length_function != nullptr) { | |||
| c_func = std::make_shared<CFuncOp>(element_length_function); | |||
| } | |||
| auto ds = | |||
| std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, bucket_batch_sizes, | |||
| c_func, pad_info, pad_to_bucket_boundary, drop_remainder); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| @@ -522,7 +456,9 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase | |||
| FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate, | |||
| std::vector<std::string> input_columns) { | |||
| auto ds = std::make_shared<FilterNode>(input->IRNode(), predicate, input_columns); | |||
| std::shared_ptr<TensorOp> c_func = nullptr; | |||
| if (predicate) c_func = std::make_shared<CFuncOp>(predicate); | |||
| auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, input_columns); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
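In both BucketBatchByLengthDataset and FilterDataset the user callable is now wrapped in a CFuncOp only when one is actually supplied, so the IR node receives nullptr rather than an empty functor when no function was given. An illustrative call against the FilterDataset constructor shown above; the predicate body and column name are placeholders:

std::function<TensorRow(TensorRow)> predicate = [](TensorRow row) -> TensorRow {
  // A real predicate returns a row holding a boolean tensor that marks rows to keep.
  return row;
};
auto filtered = std::make_shared<FilterDataset>(input, predicate,
                                                std::vector<std::string>{"label"});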
| @@ -604,40 +540,20 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) { | |||
| #endif | |||
| int64_t Dataset::GetBatchSize() { | |||
| int64_t batch_size; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->Init(ds->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->GetBatchSize(&batch_size); | |||
| return rc.IsError() ? -1 : batch_size; | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1); | |||
| return batch_size; | |||
| } | |||
| int64_t Dataset::GetRepeatCount() { | |||
| int64_t repeat_count; | |||
| auto ds = shared_from_this(); | |||
| Status rc; | |||
| std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>(); | |||
| rc = runtime_context->Init(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->Init(ds->IRNode()); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed."; | |||
| return -1; | |||
| } | |||
| rc = tree_getters_->GetRepeatCount(&repeat_count); | |||
| return rc.IsError() ? 0 : repeat_count; | |||
| RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0); | |||
| RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0); | |||
| return repeat_count; | |||
| } | |||
| std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) { | |||
| @@ -720,62 +636,65 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai | |||
| SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {} | |||
| // SchemaObj init function | |||
| bool SchemaObj::init() { | |||
| if (schema_file_ != "") { | |||
| Status SchemaObj::init() { | |||
| if (!schema_file_.empty()) { | |||
| Path schema_file(schema_file_); | |||
| if (!schema_file.Exists()) { | |||
| MS_LOG(ERROR) << "The file " << schema_file << " does not exist or permission denied!"; | |||
| return false; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(), | |||
| "The file " + schema_file_ + " does not exist or permission denied!"); | |||
| nlohmann::json js; | |||
| try { | |||
| std::ifstream in(schema_file_); | |||
| in >> js; | |||
| if (js.find("columns") == js.end()) { | |||
| MS_LOG(ERROR) << "\"columns\" node is required in the schema json file."; | |||
| return false; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(), | |||
| "\"columns\" node is required in the schema json file."); | |||
| } catch (const std::exception &err) { | |||
| MS_LOG(ERROR) << "Schema file failed to load"; | |||
| return false; | |||
| RETURN_STATUS_SYNTAX_ERROR("Schema file failed to load"); | |||
| } | |||
| return from_json(js); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| // Function to add a column to schema with a mstype de_type and known shape | |||
| Status SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) { | |||
| DataType data_type = dataset::MSTypeToDEType(de_type); | |||
| return add_column(name, data_type.ToString(), shape); | |||
| } | |||
| // Function to add a column to schema with a mstype de_type | |||
| bool SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) { | |||
| // Function to add a column to schema with a string de_type and known shape | |||
| Status SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) { | |||
| DataType data_type(de_type); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown."); | |||
| nlohmann::json new_column; | |||
| new_column["name"] = name; | |||
| // if de_type is mstype | |||
| DataType data_type = dataset::MSTypeToDEType(de_type); | |||
| new_column["type"] = data_type.ToString(); | |||
| if (shape.size() > 0) { | |||
| new_column["shape"] = shape; | |||
| new_column["rank"] = shape.size(); | |||
| } else { | |||
| new_column["rank"] = 1; | |||
| } | |||
| new_column["shape"] = shape; | |||
| new_column["rank"] = shape.size(); | |||
| columns_.push_back(new_column); | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| // Function to add a column to schema with a mstype de_type and without shape | |||
| Status SchemaObj::add_column(std::string name, TypeId de_type) { | |||
| DataType data_type = dataset::MSTypeToDEType(de_type); | |||
| return add_column(name, data_type.ToString()); | |||
| } | |||
| // Function to add a column to schema with a string de_type | |||
| bool SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) { | |||
| // Function to add a column to schema with a string de_type and without shape | |||
| Status SchemaObj::add_column(std::string name, std::string de_type) { | |||
| DataType data_type(de_type); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown."); | |||
| nlohmann::json new_column; | |||
| new_column["name"] = name; | |||
| DataType data_type(de_type); | |||
| new_column["type"] = data_type.ToString(); | |||
| if (shape.size() > 0) { | |||
| new_column["shape"] = shape; | |||
| new_column["rank"] = shape.size(); | |||
| } else { | |||
| new_column["rank"] = 1; | |||
| } | |||
| new_column["rank"] = 1; | |||
| columns_.push_back(new_column); | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
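Because every add_column overload now returns Status instead of bool, failures can be propagated with the same RETURN_IF_NOT_OK used in parse_column and from_json below. A hedged usage sketch; the column names, types, and shape are illustrative:

Status BuildExampleSchema(const std::shared_ptr<SchemaObj> &schema) {
  RETURN_IF_NOT_OK(schema->add_column("image", "uint8", {28, 28, 1}));
  RETURN_IF_NOT_OK(schema->add_column("label", "int32"));  // shape-less overload, rank recorded as 1
  return Status::OK();
}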
| std::string SchemaObj::to_json() { | |||
| @@ -792,7 +711,7 @@ std::string SchemaObj::to_json() { | |||
| return json_file.dump(2); | |||
| } | |||
| bool SchemaObj::parse_column(nlohmann::json columns) { | |||
| Status SchemaObj::parse_column(nlohmann::json columns) { | |||
| std::string name, de_type; | |||
| std::vector<int32_t> shape; | |||
| @@ -802,15 +721,13 @@ bool SchemaObj::parse_column(nlohmann::json columns) { | |||
| for (auto column : columns) { | |||
| auto key_name = column.find("name"); | |||
| if (key_name == column.end()) { | |||
| MS_LOG(ERROR) << "Column's name is missing"; | |||
| return false; | |||
| RETURN_STATUS_SYNTAX_ERROR("Column's name is missing"); | |||
| } | |||
| name = *key_name; | |||
| auto key_type = column.find("type"); | |||
| if (key_type == column.end()) { | |||
| MS_LOG(ERROR) << "Column's type is missing"; | |||
| return false; | |||
| RETURN_STATUS_SYNTAX_ERROR("Column's type is missing"); | |||
| } | |||
| de_type = *key_type; | |||
| @@ -819,17 +736,14 @@ bool SchemaObj::parse_column(nlohmann::json columns) { | |||
| if (key_shape != column.end()) { | |||
| shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); | |||
| } | |||
| if (!add_column(name, de_type, shape)) { | |||
| return false; | |||
| } | |||
| RETURN_IF_NOT_OK(add_column(name, de_type, shape)); | |||
| } | |||
| } else if (columns.type() == nlohmann::json::value_t::object) { | |||
| for (const auto &it_child : columns.items()) { | |||
| name = it_child.key(); | |||
| auto key_type = it_child.value().find("type"); | |||
| if (key_type == it_child.value().end()) { | |||
| MS_LOG(ERROR) << "Column's type is missing"; | |||
| return false; | |||
| RETURN_STATUS_SYNTAX_ERROR("Column's type is missing"); | |||
| } | |||
| de_type = *key_type; | |||
| @@ -839,43 +753,45 @@ bool SchemaObj::parse_column(nlohmann::json columns) { | |||
| shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end()); | |||
| } | |||
| if (!add_column(name, de_type, shape)) { | |||
| return false; | |||
| } | |||
| RETURN_IF_NOT_OK(add_column(name, de_type, shape)); | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "columns must be dict or list, columns contain name, type, shape(optional)."; | |||
| return false; | |||
| RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional)."); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| bool SchemaObj::from_json(nlohmann::json json_obj) { | |||
| Status SchemaObj::from_json(nlohmann::json json_obj) { | |||
| for (const auto &it_child : json_obj.items()) { | |||
| if (it_child.key() == "datasetType") { | |||
| dataset_type_ = it_child.value(); | |||
| } else if (it_child.key() == "numRows") { | |||
| num_rows_ = it_child.value(); | |||
| } else if (it_child.key() == "columns") { | |||
| if (!parse_column(it_child.value())) { | |||
| MS_LOG(ERROR) << "parse columns failed"; | |||
| return false; | |||
| } | |||
| RETURN_IF_NOT_OK(parse_column(it_child.value())); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unknown field " << it_child.key(); | |||
| return false; | |||
| RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key()); | |||
| } | |||
| } | |||
| if (columns_.empty()) { | |||
| MS_LOG(ERROR) << "Columns are missing."; | |||
| return false; | |||
| RETURN_STATUS_SYNTAX_ERROR("Columns are missing."); | |||
| } | |||
| if (num_rows_ <= 0) { | |||
| MS_LOG(ERROR) << "numRows must be greater than 0"; | |||
| return false; | |||
| if (num_rows_ < 0) { | |||
| RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0"); | |||
| } | |||
| return true; | |||
| return Status::OK(); | |||
| } | |||
| Status SchemaObj::FromJSONString(const std::string &json_string) { | |||
| try { | |||
| nlohmann::json js = nlohmann::json::parse(json_string); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(), | |||
| "\"columns\" node is required in the schema json JSON."); | |||
| RETURN_IF_NOT_OK(from_json(js)); | |||
| } catch (const std::exception &err) { | |||
| RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse"); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
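FromJSONString gives callers a way to populate a schema directly from a JSON string rather than a schema file, with the same "columns" requirement enforced up front. A hedged sketch, assuming an already-constructed SchemaObj named schema; the JSON literal is illustrative:

Status rc = schema->FromJSONString(
  R"({"numRows": 4, "columns": {"col1": {"type": "int32", "shape": [2]}}})");
if (rc.IsError()) {
  MS_LOG(ERROR) << "Schema could not be parsed from the JSON string.";
}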
| // OTHER FUNCTIONS | |||
| @@ -1,136 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| #include "pybind11/stl_bind.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/api/python/de_pipeline.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PYBIND_REGISTER( | |||
| DEPipeline, 0, ([](const py::module *m) { | |||
| (void)py::class_<DEPipeline>(*m, "DEPipeline") | |||
| .def(py::init<>()) | |||
| .def( | |||
| "AddNodeToTree", | |||
| [](DEPipeline &de, const OpName &op_name, const py::dict &args) { | |||
| py::dict out; | |||
| THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); | |||
| return out; | |||
| }, | |||
| py::return_value_policy::reference) | |||
| .def_static("AddChildToParentNode", | |||
| [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { | |||
| THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); | |||
| }) | |||
| .def("AssignRootNode", | |||
| [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) | |||
| .def("SetBatchParameters", | |||
| [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) | |||
| .def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); }) | |||
| .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) | |||
| .def("GetColumnNames", | |||
| [](DEPipeline &de) { | |||
| py::list out; | |||
| THROW_IF_ERROR(de.GetColumnNames(&out)); | |||
| return out; | |||
| }) | |||
| .def("GetNextAsMap", | |||
| [](DEPipeline &de) { | |||
| py::dict out; | |||
| THROW_IF_ERROR(de.GetNextAsMap(&out)); | |||
| return out; | |||
| }) | |||
| .def("GetNextAsList", | |||
| [](DEPipeline &de) { | |||
| py::list out; | |||
| THROW_IF_ERROR(de.GetNextAsList(&out)); | |||
| return out; | |||
| }) | |||
| .def("GetOutputShapes", | |||
| [](DEPipeline &de) { | |||
| py::list out; | |||
| THROW_IF_ERROR(de.GetOutputShapes(&out)); | |||
| return out; | |||
| }) | |||
| .def("GetOutputTypes", | |||
| [](DEPipeline &de) { | |||
| py::list out; | |||
| THROW_IF_ERROR(de.GetOutputTypes(&out)); | |||
| return out; | |||
| }) | |||
| .def("GetDataInfo", | |||
| [](DEPipeline &de) { | |||
| py::list types, shapes; | |||
| THROW_IF_ERROR(de.GetDataInfo(&types, &shapes)); | |||
| return py::make_tuple(types, shapes); | |||
| }) | |||
| .def("GetDatasetSize", &DEPipeline::GetDatasetSize) | |||
| .def("GetBatchSize", &DEPipeline::GetBatchSize) | |||
| .def("GetNumClasses", &DEPipeline::GetNumClasses) | |||
| .def("GetRepeatCount", &DEPipeline::GetRepeatCount) | |||
| .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); }) | |||
| .def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); }) | |||
| .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) { | |||
| THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); | |||
| return true; | |||
| }); | |||
| })); | |||
| PYBIND_REGISTER(OpName, 0, ([](const py::module *m) { | |||
| (void)py::enum_<OpName>(*m, "OpName", py::arithmetic()) | |||
| .value("SHUFFLE", OpName::kShuffle) | |||
| .value("BATCH", OpName::kBatch) | |||
| .value("BUCKETBATCH", OpName::kBucketBatch) | |||
| .value("BARRIER", OpName::kBarrier) | |||
| .value("MINDRECORD", OpName::kMindrecord) | |||
| .value("CACHE", OpName::kCache) | |||
| .value("REPEAT", OpName::kRepeat) | |||
| .value("SKIP", OpName::kSkip) | |||
| .value("TAKE", OpName::kTake) | |||
| .value("ZIP", OpName::kZip) | |||
| .value("CONCAT", OpName::kConcat) | |||
| .value("MAP", OpName::kMap) | |||
| .value("FILTER", OpName::kFilter) | |||
| .value("DEVICEQUEUE", OpName::kDeviceQueue) | |||
| .value("GENERATOR", OpName::kGenerator) | |||
| .export_values() | |||
| .value("RENAME", OpName::kRename) | |||
| .value("TFREADER", OpName::kTfReader) | |||
| .value("PROJECT", OpName::kProject) | |||
| .value("IMAGEFOLDER", OpName::kImageFolder) | |||
| .value("MNIST", OpName::kMnist) | |||
| .value("MANIFEST", OpName::kManifest) | |||
| .value("VOC", OpName::kVoc) | |||
| .value("COCO", OpName::kCoco) | |||
| .value("CIFAR10", OpName::kCifar10) | |||
| .value("CIFAR100", OpName::kCifar100) | |||
| .value("RANDOMDATA", OpName::kRandomData) | |||
| .value("BUILDVOCAB", OpName::kBuildVocab) | |||
| .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab) | |||
| .value("CELEBA", OpName::kCelebA) | |||
| .value("TEXTFILE", OpName::kTextFile) | |||
| .value("EPOCHCTRL", OpName::kEpochCtrl) | |||
| .value("CSV", OpName::kCsv) | |||
| .value("CLUE", OpName::kClue); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -19,8 +19,10 @@ | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/core/client.h" // DE client | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "pybind11/numpy.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/api/python/de_pipeline.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -0,0 +1,551 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| #include "pybind11/stl_bind.h" | |||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/callback/py_ds_callback.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| // IR non-leaf nodes | |||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||
| // IR non-leaf nodes - disabled for android | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" | |||
| #endif | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/core/data_type.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/random.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| // IR leaf nodes | |||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||
| // IR leaf nodes disabled for android | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) { | |||
| (void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset") | |||
| .def("SetNumWorkers", | |||
| [](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) { | |||
| return num_workers ? self->SetNumWorkers(*num_workers) : self; | |||
| }) | |||
| .def( | |||
| "Zip", | |||
| [](std::shared_ptr<DatasetNode> self, py::list datasets) { | |||
| auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets))); | |||
| THROW_IF_ERROR(zip->ValidateParams()); | |||
| return zip; | |||
| }, | |||
| py::arg("datasets")); | |||
| })); | |||
| // PYBIND FOR LEAF NODES | |||
| // (In alphabetical order) | |||
| PYBIND_REGISTER( | |||
| CelebANode, 2, ([](const py::module *m) { | |||
| (void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", "to create a CelebANode") | |||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, bool decode, | |||
| std::optional<py::list> extensions, std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode, | |||
| toStringSet(extensions), toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(celebA->ValidateParams()); | |||
| return celebA; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) { | |||
| (void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node", | |||
| "to create a Cifar10Node") | |||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), | |||
| toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(cifar10->ValidateParams()); | |||
| return cifar10; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) { | |||
| (void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node", | |||
| "to create a Cifar100Node") | |||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), | |||
| toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(cifar100->ValidateParams()); | |||
| return cifar100; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER( | |||
| CLUENode, 2, ([](const py::module *m) { | |||
| (void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode") | |||
| .def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle, | |||
| int32_t num_shards, int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| std::shared_ptr<CLUENode> clue_node = | |||
| std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle), | |||
| num_shards, shard_id, toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(clue_node->ValidateParams()); | |||
| return clue_node; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER( | |||
| CocoNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", "to create a CocoNode") | |||
| .def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, bool decode, | |||
| std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>( | |||
| dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(coco->ValidateParams()); | |||
| return coco; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode") | |||
| .def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults, | |||
| std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle, | |||
| int32_t num_shards, int32_t shard_id, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), | |||
| column_names, num_samples, toShuffleMode(shuffle), | |||
| num_shards, shard_id, toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(csv->ValidateParams()); | |||
| return csv; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>( | |||
| *m, "GeneratorNode", "to create a GeneratorNode") | |||
| .def(py::init([](py::function generator_function, const std::vector<std::string> &column_names, | |||
| const std::vector<DataType> &column_types) { | |||
| auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types); | |||
| THROW_IF_ERROR(gen->ValidateParams()); | |||
| return gen; | |||
| })) | |||
| .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) { | |||
| auto gen = std::make_shared<GeneratorNode>(generator_function, schema); | |||
| THROW_IF_ERROR(gen->ValidateParams()); | |||
| return gen; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>( | |||
| *m, "ImageFolderNode", "to create an ImageFolderNode") | |||
| .def(py::init([](std::string dataset_dir, bool decode, std::optional<py::handle> sampler, | |||
| std::optional<py::list> extensions, std::optional<py::dict> class_indexing, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| bool recursive = true; | |||
| auto imagefolder = std::make_shared<ImageFolderNode>( | |||
| dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions), | |||
| toStringMap(class_indexing), toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(imagefolder->ValidateParams()); | |||
| return imagefolder; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode", | |||
| "to create a ManifestNode") | |||
| .def(py::init([](std::string dataset_file, std::string usage, std::optional<py::handle> sampler, | |||
| std::optional<py::dict> class_indexing, bool decode, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler), | |||
| toStringMap(class_indexing), decode, | |||
| toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(manifest->ValidateParams()); | |||
| return manifest; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode", | |||
| "to create a MindDataNode") | |||
| .def(py::init([](std::string dataset_file, std::optional<py::list> columns_list, | |||
| std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) { | |||
| nlohmann::json padded_sample_json; | |||
| std::map<std::string, std::string> sample_bytes; | |||
| THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes)); | |||
| auto minddata = | |||
| std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list), | |||
| toSamplerObj(sampler, true), padded_sample_json, num_padded); | |||
| minddata->SetSampleBytes(&sample_bytes); | |||
| THROW_IF_ERROR(minddata->ValidateParams()); | |||
| return minddata; | |||
| })) | |||
| .def(py::init([](py::list dataset_file, std::optional<py::list> columns_list, | |||
| std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) { | |||
| nlohmann::json padded_sample_json; | |||
| std::map<std::string, std::string> sample_bytes; | |||
| THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes)); | |||
| auto minddata = | |||
| std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list), | |||
| toSamplerObj(sampler, true), padded_sample_json, num_padded); | |||
| minddata->SetSampleBytes(&sample_bytes); | |||
| THROW_IF_ERROR(minddata->ValidateParams()); | |||
| return minddata; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode", | |||
| "to create an MnistNode") | |||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), | |||
| toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(mnist->ValidateParams()); | |||
| return mnist; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER( | |||
| RandomNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode") | |||
| .def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto random_node = | |||
| std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(random_node->ValidateParams()); | |||
| return random_node; | |||
| })) | |||
| .def(py::init([](int32_t total_rows, std::string schema, std::optional<py::list> columns_list, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| auto random_node = | |||
| std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(random_node->ValidateParams()); | |||
| return random_node; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode", | |||
| "to create a TextFileNode") | |||
| .def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards, | |||
| int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>( | |||
| toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id, | |||
| toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(textfile_node->ValidateParams()); | |||
| return textfile_node; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER( | |||
| TFRecordNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode", | |||
| "to create a TFRecordNode") | |||
| .def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list, | |||
| std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards, | |||
| std::optional<int32_t> shard_id, bool shard_equal_rows, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| if (!num_samples) { | |||
| num_samples = 0; | |||
| } | |||
| std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( | |||
| toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle), | |||
| *num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(tfrecord->ValidateParams()); | |||
| return tfrecord; | |||
| })) | |||
| .def(py::init([](py::list dataset_files, std::string schema, std::optional<py::list> columns_list, | |||
| std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards, | |||
| std::optional<int32_t> shard_id, bool shard_equal_rows, | |||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| if (!num_samples) { | |||
| num_samples = 0; | |||
| } | |||
| std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( | |||
| toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle), | |||
| *num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(tfrecord->ValidateParams()); | |||
| return tfrecord; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode") | |||
| .def( | |||
| py::init([](std::string dataset_dir, std::string task, std::string usage, | |||
| std::optional<py::dict> class_indexing, bool decode, | |||
| std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| std::shared_ptr<VOCNode> voc = | |||
| std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode, | |||
| toSamplerObj(sampler), toDatasetCache(std::move(cc))); | |||
| THROW_IF_ERROR(voc->ValidateParams()); | |||
| return voc; | |||
| })); | |||
| })); | |||
| // PYBIND FOR NON-LEAF NODES | |||
| // (In alphabetical order) | |||
| PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode", | |||
| "to create a BatchNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder, | |||
| bool pad, py::list in_col_names, py::list out_col_names, py::list col_order, | |||
| py::object size_obj, py::object map_obj, py::dict pad_info) { | |||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info; | |||
| if (pad) { | |||
| THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info)); | |||
| } | |||
| py::function size_func = | |||
| py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function(); | |||
| py::function map_func = | |||
| py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function(); | |||
| auto batch = std::make_shared<BatchNode>( | |||
| self, batch_size, drop_remainder, pad, toStringVector(in_col_names), | |||
| toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info); | |||
| THROW_IF_ERROR(batch->ValidateParams()); | |||
| return batch; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>( | |||
| *m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names, | |||
| std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes, | |||
| py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary, | |||
| bool drop_remainder) { | |||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info; | |||
| THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info)); | |||
| auto bucket_batch = std::make_shared<BucketBatchByLengthNode>( | |||
| dataset, toStringVector(column_names), bucket_boundaries, bucket_batch_sizes, | |||
| toPyFuncOp(std::move(element_length_function), DataType::DE_INT32), c_pad_info, | |||
| pad_to_bucket_boundary, drop_remainder); | |||
| THROW_IF_ERROR(bucket_batch->ValidateParams()); | |||
| return bucket_batch; | |||
| }), | |||
| py::arg("dataset"), py::arg("column_names"), py::arg("bucket_boundaries"), | |||
| py::arg("bucket_batch_sizes"), py::arg("element_length_function") = py::none(), | |||
| py::arg("pad_info"), py::arg("pad_to_bucket_boundary"), py::arg("drop_remainder")); | |||
| })); | |||
| PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>( | |||
| *m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab, | |||
| const std::vector<std::string> &col_names, uint32_t vocab_size, | |||
| float character_coverage, SentencePieceModel model_type, | |||
| const std::unordered_map<std::string, std::string> ¶ms) { | |||
| auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>( | |||
| self, vocab, col_names, vocab_size, character_coverage, model_type, params); | |||
| THROW_IF_ERROR(build_sentence_vocab->ValidateParams()); | |||
| return build_sentence_vocab; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>( | |||
| *m, "BuildVocabNode", "to create a BuildVocabNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns, | |||
| py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) { | |||
| auto build_vocab = | |||
| std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range), | |||
| top_k, toStringVector(special_tokens), special_first); | |||
| THROW_IF_ERROR(build_vocab->ValidateParams()); | |||
| return build_vocab; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode", | |||
| "to create a ConcatNode") | |||
| .def( | |||
| py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, std::optional<py::handle> sampler, | |||
| py::list children_flag_and_nums, py::list children_start_end_index) { | |||
| auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler), | |||
| toPairVector(children_flag_and_nums), | |||
| toPairVector(children_start_end_index)); | |||
| THROW_IF_ERROR(concat->ValidateParams()); | |||
| return concat; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode", | |||
| "to create a FilterNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate, | |||
| std::vector<std::string> input_columns) { | |||
| auto filter = | |||
| std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns); | |||
| THROW_IF_ERROR(filter->ValidateParams()); | |||
| return filter; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> operations, | |||
| std::optional<py::list> input_columns, std::optional<py::list> output_columns, | |||
| std::optional<py::list> project_columns, | |||
| std::optional<std::shared_ptr<CacheClient>> cc, | |||
| std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) { | |||
| auto map = std::make_shared<MapNode>( | |||
| self, std::move(toTensorOperations(operations)), toStringVector(input_columns), | |||
| toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)), | |||
| std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end())); | |||
| THROW_IF_ERROR(map->ValidateParams()); | |||
| return map; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode", | |||
| "to create a ProjectNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) { | |||
| auto project = std::make_shared<ProjectNode>(self, toStringVector(columns)); | |||
| THROW_IF_ERROR(project->ValidateParams()); | |||
| return project; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode", | |||
| "to create a RenameNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> input_columns, | |||
| std::optional<py::list> output_columns) { | |||
| auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns), | |||
| toStringVector(output_columns)); | |||
| THROW_IF_ERROR(rename->ValidateParams()); | |||
| return rename; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode", | |||
| "to create a RepeatNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) { | |||
| auto repeat = std::make_shared<RepeatNode>(input, count); | |||
| THROW_IF_ERROR(repeat->ValidateParams()); | |||
| return repeat; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode", | |||
| "to create a ShuffleNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) { | |||
| auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch); | |||
| THROW_IF_ERROR(shuffle->ValidateParams()); | |||
| return shuffle; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode", | |||
| "to create a SkipNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) { | |||
| auto skip = std::make_shared<SkipNode>(self, count); | |||
| THROW_IF_ERROR(skip->ValidateParams()); | |||
| return skip; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode", | |||
| "to create a SyncWaitNode") | |||
| .def( | |||
| py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) { | |||
| py::function callback_func = | |||
| py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function(); | |||
| auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback_func); | |||
| THROW_IF_ERROR(sync_wait->ValidateParams()); | |||
| return sync_wait; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode", | |||
| "to create a TakeNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) { | |||
| auto take = std::make_shared<TakeNode>(self, count); | |||
| THROW_IF_ERROR(take->ValidateParams()); | |||
| return take; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode", | |||
| "to create a TransferNode") | |||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type, | |||
| bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) { | |||
| auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end, | |||
| total_batch, create_data_info_queue); | |||
| THROW_IF_ERROR(transfer->ValidateParams()); | |||
| return transfer; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode") | |||
| .def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) { | |||
| auto zip = std::make_shared<ZipNode>(datasets); | |||
| THROW_IF_ERROR(zip->ValidateParams()); | |||
| return zip; | |||
| })); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
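All of the bindings above are installed through PYBIND_REGISTER rather than one monolithic module init, which keeps each node's registration next to the class it exposes and presumably lets base classes (priority 0 and 1) be registered before the nodes derived from them (priority 2). The macro's definition lives in pybind_register.h and is not part of this diff; the following is only an assumed sketch of the registrar pattern it likely follows, not the actual implementation:

#include <functional>
#include <map>
#include <utility>
#include "pybind11/pybind11.h"
namespace py = pybind11;

// Assumed registry behind PYBIND_REGISTER: static construction stores each binding
// callback with its priority, and module initialization later replays the callbacks
// in ascending priority order against the real py::module.
class PybindRegistry {
 public:
  PybindRegistry(int priority, std::function<void(py::module *)> fn) {
    Get().emplace(priority, std::move(fn));
  }
  static std::multimap<int, std::function<void(py::module *)>> &Get() {
    static std::multimap<int, std::function<void(py::module *)>> registry;
    return registry;
  }
};
// A PYBIND_REGISTER(Name, priority, body) invocation would then reduce to something like:
//   static PybindRegistry g_pybind_reg_Name(priority, body);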
| @@ -0,0 +1,168 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pybind11/pybind11.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||
| #include "minddata/dataset/engine/python_runtime_context.h" | |||
| #include "minddata/dataset/engine/consumers/python_tree_consumer.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) { | |||
| (void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer"); | |||
| })); | |||
| PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) { | |||
| (void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>( | |||
| *m, "PythonIteratorConsumer") | |||
| .def(py::init<int32_t>()) | |||
| .def("Init", [](PythonIteratorConsumer &self, | |||
| std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||
| .def("GetNextAsMap", | |||
| [](PythonIteratorConsumer &self) { | |||
| py::dict output; | |||
| THROW_IF_ERROR(self.GetNextAsDict(&output)); | |||
| return output; | |||
| }) | |||
| .def("GetNextAsList", [](PythonIteratorConsumer &self) { | |||
| py::list output; | |||
| THROW_IF_ERROR(self.GetNextAsList(&output)); | |||
| return output; | |||
| }); | |||
| })); | |||
| PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) { | |||
| (void)py::class_<PythonTreeGetters, TreeConsumer, std::shared_ptr<PythonTreeGetters>>(*m, | |||
| "TreeGetters") | |||
| .def(py::init<>()) | |||
| .def("Init", | |||
| [](PythonTreeGetters &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||
| .def("GetOutputShapes", | |||
| [](PythonTreeGetters &self) { | |||
| std::vector<TensorShape> shapes; | |||
| THROW_IF_ERROR(self.GetOutputShapes(&shapes)); | |||
| return shapesToListOfShape(shapes); | |||
| }) | |||
| .def("GetOutputTypes", | |||
| [](PythonTreeGetters &self) { | |||
| std::vector<DataType> types; | |||
| THROW_IF_ERROR(self.GetOutputTypes(&types)); | |||
| return typesToListOfType(types); | |||
| }) | |||
| .def("GetNumClasses", | |||
| [](PythonTreeGetters &self) { | |||
| int64_t num_classes; | |||
| THROW_IF_ERROR(self.GetNumClasses(&num_classes)); | |||
| return num_classes; | |||
| }) | |||
| .def("GetRepeatCount", | |||
| [](PythonTreeGetters &self) { | |||
| int64_t repeat_count; | |||
| THROW_IF_ERROR(self.GetRepeatCount(&repeat_count)); | |||
| return repeat_count; | |||
| }) | |||
| .def("GetBatchSize", | |||
| [](PythonTreeGetters &self) { | |||
| int64_t batch_size; | |||
| THROW_IF_ERROR(self.GetBatchSize(&batch_size)); | |||
| return batch_size; | |||
| }) | |||
| .def("GetColumnNames", | |||
| [](PythonTreeGetters &self) { | |||
| std::vector<std::string> col_names; | |||
| THROW_IF_ERROR(self.GetColumnNames(&col_names)); | |||
| return col_names; | |||
| }) | |||
| .def("GetClassIndexing", | |||
| [](PythonTreeGetters &self) { | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing; | |||
| THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing)); | |||
| return output_class_indexing; | |||
| }) | |||
| .def("GetDatasetSize", | |||
| [](PythonTreeGetters &self) { | |||
| int64_t dataset_size; | |||
| THROW_IF_ERROR(self.GetDatasetSize(&dataset_size)); | |||
| return dataset_size; | |||
| }) | |||
| .def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; }); | |||
| })); | |||
| PYBIND_REGISTER(PythonRuntimeContext, 2, ([](const py::module *m) { | |||
| (void)py::class_<PythonRuntimeContext, std::shared_ptr<PythonRuntimeContext>>(*m, | |||
| "PythonRuntimeContext") | |||
| .def(py::init<>()) | |||
| .def("Init", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Init()); }) | |||
| .def("AssignConsumer", &PythonRuntimeContext::AssignConsumer) | |||
| .def("Terminate", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Terminate()); }) | |||
| .def("GetConsumer", &PythonRuntimeContext::GetPythonConsumer, py::return_value_policy::reference) | |||
| .def("__deepcopy__", [](py::object &runtime_context, py::dict memo) { return runtime_context; }); | |||
| })); | |||
| PYBIND_REGISTER(PythonBuildVocabConsumer, 1, ([](const py::module *m) { | |||
| (void)py::class_<PythonBuildVocabConsumer, TreeConsumer, std::shared_ptr<PythonBuildVocabConsumer>>( | |||
| *m, "PythonBuildVocabConsumer") | |||
| .def(py::init<>()) | |||
| .def("Init", [](PythonBuildVocabConsumer &self, | |||
| std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||
| .def("Start", [](PythonBuildVocabConsumer &self) { THROW_IF_ERROR(self.Start()); }); | |||
| })); | |||
| PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) { | |||
| (void)py::class_<ToDevice, TreeConsumer, std::shared_ptr<ToDevice>>(*m, "ToDevice") | |||
| .def(py::init<int32_t>()) | |||
| .def("Init", [](ToDevice &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||
| .def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); }) | |||
| .def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); }) | |||
| .def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); }) | |||
| .def("GetDataInfo", | |||
| [](ToDevice &self) { | |||
| std::vector<DataType> types_c; | |||
| std::vector<TensorShape> shapes_c; | |||
| { | |||
| py::gil_scoped_release rel; | |||
| THROW_IF_ERROR(self.GetDataInfo(&types_c, &shapes_c)); | |||
| } | |||
| py::list types, shapes; | |||
| for (auto el : types_c) { | |||
| types.append(el.AsNumpyType()); | |||
| } | |||
| for (auto el : shapes_c) { | |||
| py::list shape = el.AsPyList(); | |||
| shapes.append(shape); | |||
| } | |||
| return py::make_tuple(types, shapes); | |||
| }) | |||
| .def("__deepcopy__", [](py::object &to_device, py::dict memo) { return to_device; }); | |||
| })); | |||
| PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) { | |||
| (void)py::class_<PythonSaveToDisk, TreeConsumer, std::shared_ptr<PythonSaveToDisk>>( | |||
| *m, "PythonSaveToDisk") | |||
| .def(py::init([](std::string &dataset_path, int32_t numFiles, std::string &datasetType) { | |||
| auto save = std::make_shared<PythonSaveToDisk>(dataset_path, numFiles, datasetType); | |||
| THROW_IF_ERROR(save->ValidateParams()); | |||
| return save; | |||
| })) | |||
| .def("Init", | |||
| [](PythonSaveToDisk &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||
| .def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); }); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
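For illustration only (not part of the patch): the call sequence the Python layer drives through the consumer bindings above, written directly against the C++ classes; the AssignConsumer signature is assumed from the binding.

// Sketch: run one epoch through a PythonIteratorConsumer owned by a PythonRuntimeContext.
Status RunOneEpochSketch(std::shared_ptr<DatasetNode> root) {
  auto context = std::make_shared<PythonRuntimeContext>();
  RETURN_IF_NOT_OK(context->Init());
  auto consumer = std::make_shared<PythonIteratorConsumer>(1);  // num_epochs = 1
  RETURN_IF_NOT_OK(consumer->Init(root));                       // compiles the IR tree
  context->AssignConsumer(consumer);                            // context takes ownership (assumed signature)
  py::dict row;
  RETURN_IF_NOT_OK(consumer->GetNextAsDict(&row));              // the fetch releases the GIL internally
  return context->Terminate();
}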
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| #include "pybind11/stl_bind.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/core/constants.h" | |||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PYBIND_REGISTER( | |||
| SchemaObj, 0, ([](const py::module *m) { | |||
| (void)py::class_<SchemaObj, std::shared_ptr<SchemaObj>>(*m, "SchemaObj", "to create a SchemaObj") | |||
| .def(py::init([](std::string schema_file) { | |||
| auto schema = std::make_shared<SchemaObj>(schema_file); | |||
| THROW_IF_ERROR(schema->init()); | |||
| return schema; | |||
| })) | |||
| .def("add_column", [](SchemaObj &self, std::string name, TypeId de_type, | |||
| std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); }) | |||
| .def("add_column", [](SchemaObj &self, std::string name, std::string de_type, | |||
| std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); }) | |||
| .def("add_column", | |||
| [](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | |||
| .def("add_column", [](SchemaObj &self, std::string name, | |||
| std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | |||
| .def("to_json", &SchemaObj::to_json) | |||
| .def("to_string", &SchemaObj::to_string) | |||
| .def("from_string", | |||
| [](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.FromJSONString(json_string)); }) | |||
| .def("set_dataset_type", [](SchemaObj &self, std::string dataset_type) { self.set_dataset_type(dataset_type); }) | |||
| .def("set_num_rows", [](SchemaObj &self, int32_t num_rows) { self.set_num_rows(num_rows); }) | |||
| .def("get_num_rows", &SchemaObj::get_num_rows) | |||
| .def("__deepcopy__", [](py::object &schema, py::dict memo) { return schema; }); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
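For illustration only (not part of the patch): the same SchemaObj C++ API that this binding exposes, used directly; the type-name strings below are assumptions for the example.

// Sketch: build a schema without a schema file, add two columns, and dump it as json.
std::shared_ptr<SchemaObj> MakeSchemaSketch() {
  auto schema = std::make_shared<SchemaObj>("");  // no schema file
  THROW_IF_ERROR(schema->init());
  THROW_IF_ERROR(schema->add_column("image", "uint8", std::vector<int32_t>{-1}));  // unknown length
  THROW_IF_ERROR(schema->add_column("label", "int32"));
  schema->set_num_rows(60000);
  MS_LOG(INFO) << schema->to_json();
  return schema;
}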
| @@ -17,7 +17,6 @@ | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/api/python/de_pipeline.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h" | |||
| #include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h" | |||
| @@ -1,265 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <stack> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/client.h" // DE client | |||
| #include "minddata/dataset/engine/dataset_iterator.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "pybind11/numpy.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| using json = nlohmann::json; | |||
| using DsOpPtr = std::shared_ptr<DatasetOp>; | |||
| class CacheClient; | |||
| // enum for the dataset operator names | |||
| enum OpName { | |||
| kShuffle, | |||
| kMindrecord, | |||
| kBatch, | |||
| kBucketBatch, | |||
| kBarrier, | |||
| kCache, | |||
| kRepeat, | |||
| kSkip, | |||
| kTake, | |||
| kZip, | |||
| kConcat, | |||
| kMap, | |||
| kFilter, | |||
| kDeviceQueue, | |||
| kGenerator, | |||
| kRename, | |||
| kTfReader, | |||
| kProject, | |||
| kImageFolder, | |||
| kMnist, | |||
| kManifest, | |||
| kVoc, | |||
| kCoco, | |||
| kCifar10, | |||
| kCifar100, | |||
| kCelebA, | |||
| kRandomData, | |||
| kTextFile, | |||
| kBuildVocab, | |||
| kClue, | |||
| kEpochCtrl, | |||
| kSentencePieceVocab, | |||
| kCsv | |||
| }; | |||
| // The C++ binder class that we expose to the python script. | |||
| class DEPipeline { | |||
| public: | |||
| DEPipeline(); | |||
| ~DEPipeline(); | |||
| // Function to add a Node to the Execution Tree. | |||
| Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); | |||
| // Function to add a child and parent relationship. | |||
| static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); | |||
| // Function to assign the node as root. | |||
| Status AssignRootNode(const DsOpPtr &dataset_op); | |||
| // Function to get the column names in the last node in the tree in order | |||
| Status GetColumnNames(py::list *output); | |||
| // Function to prepare the tree for execution | |||
| Status PrepareTree(const int32_t num_epochs); | |||
| // Function to launch the tree execution. | |||
| Status LaunchTreeExec(); | |||
| // Get a row of data as dictionary of column name to the value. | |||
| Status GetNextAsMap(py::dict *output); | |||
| // Get a row of data as list. | |||
| Status GetNextAsList(py::list *output); | |||
| Status GetOutputShapes(py::list *output); | |||
| Status GetOutputTypes(py::list *output); | |||
| Status GetDataInfo(py::list *types, py::list *shapes); | |||
| Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type); | |||
| int GetDatasetSize() const; | |||
| int GetBatchSize() const; | |||
| int GetRepeatCount() const; | |||
| Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| template <typename T, typename S> | |||
| Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | |||
| std::unique_ptr<S> *s, bool need_convert = false); | |||
| Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map, | |||
| const TensorRow &row, json *schema, std::vector<std::string> *index_fields); | |||
| Status FetchDataFromTensorRow(const TensorRow &row, | |||
| const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data, | |||
| std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data); | |||
| Status BuildMindrecordSamplerChain(const py::handle &handle, | |||
| std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators, | |||
| int num_padded); | |||
| Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||
| std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||
| Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| void PrintTree(); | |||
| int32_t GetNumClasses() const; | |||
| Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status SetBatchParameters(const py::dict &args); | |||
| Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status StopSend(); | |||
| Status ContinueSend(); | |||
| Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||
| std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||
| private: | |||
| // Execution tree that links the dataset operators. | |||
| std::shared_ptr<ExecutionTree> tree_; | |||
| std::unique_ptr<DatasetIterator> iterator_; | |||
| static Status ParsePadInfo(py::handle value, PadInfo *pad_info); | |||
| /// \brief Helper function to inject a cache operator over top of the current operation being built. | |||
| /// \param[in] cache_client The client to use for caching | |||
| /// \param[in] num_workers The number of workers to use in the cache op | |||
| /// \param[in] input_op The operator to build the cache on top of | |||
| /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be | |||
| /// the cache operator | |||
| /// \return Status return code | |||
| Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op, | |||
| std::shared_ptr<DatasetOp> *cache_op); | |||
| /// \brief Helper function to inject a shuffle operator over top of the current operation being built. | |||
| /// \param[in] shuffle_size The size to use in the shuffle buffer | |||
| /// \param[in] input_op The operator to build shuffle on top of | |||
| /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be | |||
| /// the shuffle operator | |||
| /// \return Status return code | |||
| Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op, | |||
| std::shared_ptr<DatasetOp> *shuffle_op); | |||
| /// \brief Helper function to compute the shuffle size | |||
| /// \param[in] num_files The number of files in the dataset | |||
| /// \param[in] num_devices The number of devices in the dataset | |||
| /// \param[in] num_rows The number of rows in the dataset | |||
| /// \param[in] total_rows An upper bound on the total rows in the dataset | |||
| /// \param[out] shuffle_size The resultant computed shuffle size | |||
| /// \return Status return code | |||
| Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||
| int64_t *shuffle_size); | |||
| int batch_size_; | |||
| int repeat_num_; | |||
| int num_rows_; | |||
| int num_classes_; | |||
| int temp_batch_size_; | |||
| bool temp_drop_remainder_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||
| @@ -0,0 +1,265 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| float toFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); } | |||
| int toInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); } | |||
| int64_t toInt64(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); } | |||
| bool toBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); } | |||
| std::string toString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); } | |||
| std::set<std::string> toStringSet(const std::optional<py::list> list) { | |||
| std::set<std::string> set; | |||
| if (list) { | |||
| for (auto l : *list) { | |||
| if (!l.is_none()) { | |||
| (void)set.insert(py::str(l)); | |||
| } | |||
| } | |||
| } | |||
| return set; | |||
| } | |||
| std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict) { | |||
| std::map<std::string, int32_t> map; | |||
| if (dict) { | |||
| for (auto p : *dict) { | |||
| (void)map.emplace(toString(p.first), toInt(p.second)); | |||
| } | |||
| } | |||
| return map; | |||
| } | |||
| std::vector<std::string> toStringVector(const std::optional<py::list> list) { | |||
| std::vector<std::string> vector; | |||
| if (list) { | |||
| for (auto l : *list) { | |||
| if (l.is_none()) | |||
| vector.emplace_back(""); | |||
| else | |||
| vector.push_back(py::str(l)); | |||
| } | |||
| } | |||
| return vector; | |||
| } | |||
| std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple) { | |||
| std::pair<int64_t, int64_t> pair; | |||
| if (tuple) { | |||
| pair = std::make_pair(toInt64((*tuple)[0]), toInt64((*tuple)[1])); | |||
| } | |||
| return pair; | |||
| } | |||
| std::vector<std::pair<int, int>> toPairVector(const py::list list) { | |||
| std::vector<std::pair<int, int>> vector; | |||
| if (list) { | |||
| for (auto data : list) { | |||
| auto l = data.cast<py::tuple>(); | |||
| if (l[1].is_none()) | |||
| vector.emplace_back(toInt64(l[0]), 0); | |||
| else | |||
| vector.emplace_back(toInt64(l[0]), toInt64(l[1])); | |||
| } | |||
| } | |||
| return vector; | |||
| } | |||
| std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations) { | |||
| std::vector<std::shared_ptr<TensorOperation>> vector; | |||
| if (operations) { | |||
| for (auto op : *operations) { | |||
| std::shared_ptr<TensorOp> tensor_op; | |||
| if (py::isinstance<TensorOp>(op)) { | |||
| tensor_op = op.cast<std::shared_ptr<TensorOp>>(); | |||
| } else if (py::isinstance<py::function>(op)) { | |||
| tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>()); | |||
| } else { | |||
| THROW_IF_ERROR( | |||
| []() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }()); | |||
| } | |||
| vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op)); | |||
| } | |||
| } | |||
| return vector; | |||
| } | |||
| std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) { | |||
| std::vector<std::shared_ptr<DatasetNode>> vector; | |||
| vector.push_back(self); | |||
| if (datasets) { | |||
| for (auto ds : *datasets) { | |||
| if (py::isinstance<DatasetNode>(ds)) { | |||
| vector.push_back(ds.cast<std::shared_ptr<DatasetNode>>()); | |||
| } else { | |||
| THROW_IF_ERROR( | |||
| []() { RETURN_STATUS_UNEXPECTED("Error: datasets is not recognised (not a DatasetNode instance)."); }()); | |||
| } | |||
| } | |||
| } | |||
| return vector; | |||
| } | |||
| std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset) { | |||
| if (py_sampler) { | |||
| std::shared_ptr<SamplerObj> sampler_obj; | |||
| if (!isMindDataset) { | |||
| // Common Sampler | |||
| std::shared_ptr<SamplerRT> sampler; | |||
| auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create"); | |||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||
| sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler)); | |||
| } else { | |||
| // Mindrecord Sampler | |||
| std::shared_ptr<mindrecord::ShardOperator> sampler; | |||
| auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create_for_minddataset"); | |||
| sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||
| sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler)); | |||
| } | |||
| return sampler_obj; | |||
| } else { | |||
| THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: sampler input is not SamplerRT."); }()); | |||
| } | |||
| return nullptr; | |||
| } | |||
| // Here we take in a Python object that holds a reference to a C++ object | |||
| std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc) { | |||
| if (cc) { | |||
| std::shared_ptr<DatasetCache> built_cache; | |||
| // Wrap the existing CacheClient in a pre-built cache object | |||
| built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc.value())); | |||
| return built_cache; | |||
| } else { | |||
| // don't need to check here as cache is not enabled. | |||
| return nullptr; | |||
| } | |||
| } | |||
| ShuffleMode toShuffleMode(const int32_t shuffle) { | |||
| if (shuffle == 0) return ShuffleMode::kFalse; | |||
| if (shuffle == 1) return ShuffleMode::kFiles; | |||
| if (shuffle == 2) return ShuffleMode::kGlobal; | |||
| return ShuffleMode(); | |||
| } | |||
| std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases) { | |||
| std::vector<std::shared_ptr<CsvBase>> vector; | |||
| if (csv_bases) { | |||
| for (auto base : *csv_bases) { | |||
| if (py::isinstance<py::int_>(base)) { | |||
| vector.push_back(std::make_shared<CsvRecord<int>>(CsvType::INT, toInt(base))); | |||
| } else if (py::isinstance<py::float_>(base)) { | |||
| vector.push_back(std::make_shared<CsvRecord<float>>(CsvType::FLOAT, toFloat(base))); | |||
| } else if (py::isinstance<py::str>(base)) { | |||
| vector.push_back(std::make_shared<CsvRecord<std::string>>(CsvType::STRING, toString(base))); | |||
| } else { | |||
| THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: each default value must be int, float, or string"); }()); | |||
| } | |||
| } | |||
| } | |||
| return vector; | |||
| } | |||
| Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json, | |||
| std::map<std::string, std::string> *sample_bytes) { | |||
| for (const py::handle &key : padded_sample) { | |||
| if (py::isinstance<py::bytes>(padded_sample[key])) { | |||
| (*sample_bytes)[py::str(key).cast<std::string>()] = padded_sample[key].cast<std::string>(); | |||
| // py::str(key) entering here would lose its key name, so we create an unused placeholder key for it in the json to pass ValidateParams | |||
| (*padded_sample_json)[py::str(key).cast<std::string>()] = nlohmann::json::object(); | |||
| } else { | |||
| nlohmann::json obj_json; | |||
| if (padded_sample[key].is_none()) { | |||
| obj_json = nullptr; | |||
| } else if (py::isinstance<py::int_>(padded_sample[key])) { | |||
| obj_json = padded_sample[key].cast<int64_t>(); | |||
| } else if (py::isinstance<py::float_>(padded_sample[key])) { | |||
| obj_json = padded_sample[key].cast<double>(); | |||
| } else if (py::isinstance<py::str>(padded_sample[key])) { | |||
| obj_json = padded_sample[key].cast<std::string>(); // also catch py::bytes | |||
| } else { | |||
| MS_LOG(ERROR) << "Failed to convert Python object to json: " << py::cast<std::string>(padded_sample[key]); | |||
| RETURN_STATUS_SYNTAX_ERROR("Failed to convert Python object to json"); | |||
| } | |||
| (*padded_sample_json)[py::str(key).cast<std::string>()] = obj_json; | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info) { | |||
| for (auto p : value) { | |||
| if (!p.second.is_none()) { | |||
| auto tp = py::reinterpret_borrow<py::tuple>(p.second); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); | |||
| TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); | |||
| std::shared_ptr<Tensor> pad_val = nullptr; | |||
| if (py::isinstance<py::str>(tp[1])) { | |||
| std::string pad_val_string = tp[1].is_none() ? "" : toString(tp[1]); | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| Tensor::CreateFromVector(std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar(), &pad_val), | |||
| "Cannot create pad_value Tensor"); | |||
| } else { | |||
| float pad_val_float = tp[1].is_none() ? 0 : toFloat(tp[1]); | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val), | |||
| "Cannot create pad_value Tensor"); | |||
| pad_val->SetItemAt<float>({}, pad_val_float); | |||
| } | |||
| (void)pad_info->insert({toString(p.first), {shape, pad_val}}); | |||
| } else { // tuple is None | |||
| (void)pad_info->insert({toString(p.first), {TensorShape({}), nullptr}}); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type) { | |||
| std::shared_ptr<TensorOp> py_func; | |||
| if (!func.is_none()) { | |||
| py::function py_function = func.cast<py::function>(); | |||
| py_func = std::make_shared<PyFuncOp>(py_function, data_type); | |||
| } else { | |||
| py_func = nullptr; | |||
| } | |||
| return py_func; | |||
| } | |||
| py::list shapesToListOfShape(std::vector<TensorShape> shapes) { | |||
| py::list shape_list; | |||
| for (const auto &shape : shapes) { | |||
| shape_list.append(shape.AsVector()); | |||
| } | |||
| return shape_list; | |||
| } | |||
| py::list typesToListOfType(std::vector<DataType> types) { | |||
| py::list type_list; | |||
| for (const auto &type : types) { | |||
| type_list.append(type.AsNumpyType()); | |||
| } | |||
| return type_list; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
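For illustration only (not part of the patch): how a binding lambda might use the helpers above to turn Python arguments into C++ values; the argument names here are hypothetical.

// Sketch: convert a py::list of column names, a py::dict class index, and a shuffle flag.
void ConvertArgsSketch(const py::list &columns, const py::dict &class_index, int32_t shuffle_flag) {
  std::vector<std::string> cols = toStringVector(columns);           // None entries become ""
  std::map<std::string, int32_t> index = toStringMap(class_index);   // dict -> std::map
  ShuffleMode mode = toShuffleMode(shuffle_flag);                    // 0/1/2 -> kFalse/kFiles/kGlobal
  MS_LOG(INFO) << "columns=" << cols.size() << " classes=" << index.size()
               << " shuffle=" << static_cast<int>(mode);
}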
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/stl.h" | |||
| #include "pybind11/stl_bind.h" | |||
| #include "minddata/dataset/include/datasets.h" | |||
| #include "minddata/dataset/include/samplers.h" | |||
| #include "minddata/dataset/include/transforms.h" | |||
| #include "minddata/dataset/api/python/pybind_register.h" | |||
| #include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||
| #include "minddata/dataset/kernels/py_func_op.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| float toFloat(const py::handle &handle); | |||
| int toInt(const py::handle &handle); | |||
| int64_t toInt64(const py::handle &handle); | |||
| bool toBool(const py::handle &handle); | |||
| std::string toString(const py::handle &handle); | |||
| std::set<std::string> toStringSet(const std::optional<py::list> list); | |||
| std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict); | |||
| std::vector<std::string> toStringVector(const std::optional<py::list> list); | |||
| std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple); | |||
| std::vector<std::pair<int, int>> toPairVector(const py::list list); | |||
| std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations); | |||
| std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets); | |||
| std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false); | |||
| std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc); | |||
| ShuffleMode toShuffleMode(const int32_t shuffle); | |||
| std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases); | |||
| std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type); | |||
| Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json, | |||
| std::map<std::string, std::string> *sample_bytes); | |||
| Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info); | |||
| py::list shapesToListOfShape(std::vector<TensorShape> shapes); | |||
| py::list typesToListOfType(std::vector<DataType> types); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_ | |||
| @@ -190,6 +190,23 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | |||
| return sampler; | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // PreBuiltSamplerObj | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) | |||
| : sp_(std::move(sampler)), sp_minddataset_(nullptr) {} | |||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) | |||
| : sp_(nullptr), sp_minddataset_(std::move(sampler)) {} | |||
| #endif | |||
| bool PreBuiltSamplerObj::ValidateParams() { return true; } | |||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; } | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | |||
| #endif | |||
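For illustration only (not part of the patch): wrapping an already-built runtime sampler in the new PreBuiltSamplerObj, which is what toSamplerObj() does for samplers created on the Python side.

// Sketch: wrap a runtime SamplerRT so it can be attached to an IR dataset node.
std::shared_ptr<SamplerObj> WrapSamplerSketch(std::shared_ptr<SamplerRT> runtime_sampler) {
  auto wrapped = std::make_shared<PreBuiltSamplerObj>(std::move(runtime_sampler));
  return wrapped->ValidateParams() ? wrapped : nullptr;  // ValidateParams() is always true for pre-built samplers
}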
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | |||
| // runtime mindrecord sampler object | |||
| @@ -222,6 +222,13 @@ Status OneHotOperation::ValidateParams() { | |||
| std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } | |||
| // PreBuiltOperation | |||
| PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {} | |||
| Status PreBuiltOperation::ValidateParams() { return Status::OK(); } | |||
| std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; } | |||
| // RandomApplyOperation | |||
| RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob) | |||
| : transforms_(transforms), prob_(prob) {} | |||
| @@ -18,6 +18,7 @@ set(SRC_FILES_LIST | |||
| dataset_iterator.cc | |||
| tree_adapter.cc | |||
| runtime_context.cc | |||
| python_runtime_context.cc | |||
| consumers/tree_consumer.cc | |||
| ) | |||
| if (ENABLE_PYTHON) | |||
| @@ -32,15 +32,37 @@ Status PythonIteratorConsumer::GetNextAsList(py::list *out) { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) { | |||
| std::unordered_map<std::string, TensorPtr> row; | |||
| std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec; | |||
| Status s; | |||
| { | |||
| py::gil_scoped_release gil_release; | |||
| RETURN_IF_NOT_OK(GetNextAsMap(&row)); | |||
| s = GetNextAsOrderedPair(&vec); | |||
| } | |||
| for (auto el : row) { | |||
| (*out)[common::SafeCStr(el.first)] = el.second; | |||
| RETURN_IF_NOT_OK(s); | |||
| // Generate the Python dict; a Python dict maintains its insertion order | |||
| for (const auto &pair : vec) { | |||
| (*out)[common::SafeCStr(pair.first)] = pair.second; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status PythonBuildVocabConsumer::Start() { | |||
| py::gil_scoped_release gil_release; | |||
| return BuildVocabConsumer::Start(); | |||
| } | |||
| Status PythonSaveToDisk::Save() { | |||
| py::gil_scoped_release gil_release; | |||
| return SaveToDisk::Save(); | |||
| } | |||
| PythonSaveToDisk::PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType) | |||
| : SaveToDisk(datasetPath, numFiles, datasetType) {} | |||
| Status PythonTreeGetters::GetRow(TensorRow *r) { | |||
| py::gil_scoped_release gil_release; | |||
| return TreeGetters::GetRow(r); | |||
| } | |||
| } // namespace mindspore::dataset | |||
| @@ -44,5 +44,21 @@ class PythonIteratorConsumer : public IteratorConsumer { | |||
| /// \return Status error code | |||
| Status GetNextAsDict(py::dict *out); | |||
| }; | |||
| class PythonBuildVocabConsumer : public BuildVocabConsumer { | |||
| public: | |||
| Status Start() override; | |||
| }; | |||
| class PythonSaveToDisk : public SaveToDisk { | |||
| public: | |||
| PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType); | |||
| Status Save() override; | |||
| }; | |||
| class PythonTreeGetters : public TreeGetters { | |||
| public: | |||
| Status GetRow(TensorRow *r) override; | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ | |||
| @@ -23,6 +23,7 @@ | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | |||
| #include "minddata/dataset/engine/tree_adapter.h" | |||
| #include "minddata/dataset/engine/opt/pre/getter_pass.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/mindrecord/include/shard_header.h" | |||
| @@ -35,7 +36,7 @@ namespace mindspore::dataset { | |||
| TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | |||
| Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); } | |||
| Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); } | |||
| Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->ServiceStop(); } | |||
| // IteratorConsumer | |||
| Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) { | |||
| @@ -73,6 +74,38 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> | |||
| return Status::OK(); | |||
| } | |||
| Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty."); | |||
| TensorRow curr_row; | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row)); | |||
| RETURN_OK_IF_TRUE(curr_row.empty()); | |||
| size_t num_cols = curr_row.size(); // curr_row is non-empty here, so num_cols > 0. | |||
| // order the column names according to their ids | |||
| if (column_order_.empty()) { | |||
| const int32_t invalid_col_id = -1; | |||
| column_order_.resize(num_cols, {std::string(), invalid_col_id}); | |||
| for (const auto &itr : tree_adapter_->GetColumnNameMap()) { | |||
| int32_t ind = itr.second; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds."); | |||
| column_order_[ind] = std::make_pair(itr.first, ind); | |||
| } | |||
| // error check: make sure the ids in col_name_id_map are continuous and start from 0 | |||
| for (const auto &col : column_order_) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous."); | |||
| } | |||
| } | |||
| vec->reserve(num_cols); | |||
| std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec), | |||
| [curr_row](const auto &col) { return std::make_pair(col.first, curr_row[col.second]); }); | |||
| return Status::OK(); | |||
| } | |||
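For illustration only (not part of the patch): a standalone version of the ordering step above, assuming the usual <string>, <vector> and <unordered_map> includes; it shows the invariant being enforced, namely that column ids form a contiguous 0..n-1 range.

// Sketch: order column names by id and reject maps whose ids are not continuous.
std::vector<std::string> OrderColumnsSketch(const std::unordered_map<std::string, int32_t> &name_id_map) {
  std::vector<std::string> ordered(name_id_map.size());
  std::vector<bool> seen(name_id_map.size(), false);
  for (const auto &kv : name_id_map) {
    if (kv.second < 0 || kv.second >= static_cast<int32_t>(ordered.size())) return {};  // id out of bounds
    ordered[kv.second] = kv.first;
    seen[kv.second] = true;
  }
  for (bool s : seen) {
    if (!s) return {};  // ids are not continuous
  }
  return ordered;
}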
| // ToDevice | |||
| Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); } | |||
| @@ -81,7 +114,6 @@ Status ToDevice::Send() { | |||
| RETURN_IF_NOT_OK(tree_adapter_->Launch()); | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| RETURN_IF_NOT_OK(root->GetNextBuffer(&db)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -101,9 +133,36 @@ Status ToDevice::Stop() { | |||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp"); | |||
| op->StopSend(); | |||
| return Status::OK(); | |||
| } | |||
| Status ToDevice::GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes) { | |||
| // tree_.root() must be DeviceQueueOp | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp"); | |||
| DATA_INFO data_info; | |||
| RETURN_IF_NOT_OK(op->GetDataInfo(&data_info)); | |||
| for (auto el : data_info) { | |||
| types->push_back(el.first); | |||
| shapes->push_back(el.second); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status ToDevice::Terminate() { | |||
| #ifdef ENABLE_TDTQUE | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopWaiting only supported by DeviceQueueOp"); | |||
| op->StopWaiting(); | |||
| #endif | |||
| return TreeConsumer::Terminate(); | |||
| } | |||
| #ifndef ENABLE_ANDROID | |||
| // SaveToDisk | |||
| Status SaveToDisk::ValidateParams() { | |||
| @@ -282,50 +341,50 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row, | |||
| if (column_type == DataType::DE_INT8) { | |||
| std::unique_ptr<int32_t> data; | |||
| std::unique_ptr<int8_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_INT16) { | |||
| std::unique_ptr<int32_t> data; | |||
| std::unique_ptr<int16_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_UINT16) { | |||
| std::unique_ptr<int32_t> data; | |||
| std::unique_ptr<uint16_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_UINT8) { | |||
| std::unique_ptr<uint8_t> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_INT32) { | |||
| std::unique_ptr<int32_t> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_UINT32) { | |||
| std::unique_ptr<int64_t> data; | |||
| std::unique_ptr<uint32_t> dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_INT64) { | |||
| std::unique_ptr<int64_t> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_FLOAT32) { | |||
| std::unique_ptr<float> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_FLOAT64) { | |||
| std::unique_ptr<double> data, dummy; | |||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||
| RETURN_IF_NOT_OK(s); | |||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | |||
| } else if (column_type == DataType::DE_STRING) { | |||
| @@ -346,7 +405,7 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row, | |||
| } | |||
| template <typename T, typename S> | |||
| Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||
| Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | |||
| std::unique_ptr<S> *s, bool need_convert) { | |||
| if (nullptr == src) { | |||
| @@ -379,47 +438,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape & | |||
| } | |||
| #endif | |||
| TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) { | |||
| tree_adapter_ = std::make_unique<TreeAdapter>(); | |||
| } | |||
| TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); } | |||
| Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) { | |||
| if (init_flag_) { | |||
| return Status::OK(); | |||
| } | |||
| Status s = tree_adapter_->Compile(std::move(d), 1); | |||
| if (!s.IsError()) { | |||
| init_flag_ = true; | |||
| } | |||
| return s; | |||
| } | |||
| bool TreeGetters::isInitialized() { return init_flag_; } | |||
| Status TreeGetters::GetRow(TensorRow *row) { | |||
| if (row_flag_ == false) { | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(row)); | |||
| row_flag_ = true; | |||
| } | |||
| root_ = std::move(d); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); } | |||
| Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||
| if (dataset_size_ == -1) { | |||
| RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize))); | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| RETURN_UNEXPECTED_IF_NULL(root); | |||
| RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); | |||
| dataset_size_ = *dataset_size; | |||
| if (*dataset_size == -1) { | |||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||
| int64_t num_rows = 0; | |||
| TensorRow row = row_; | |||
| while (row.size() != 0) { | |||
| num_rows++; | |||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | |||
| if (*dataset_size == -1) { // run through the tree and get everything | |||
| TensorRow row; | |||
| RETURN_IF_NOT_OK(GetRow(&row)); | |||
| int64_t row_cnt = 0; | |||
| while (!row.empty()) { | |||
| ++row_cnt; | |||
| RETURN_IF_NOT_OK(GetRow(&row)); | |||
| } | |||
| dataset_size_ = num_rows; | |||
| *dataset_size = row_cnt; | |||
| } | |||
| dataset_size_ = *dataset_size; // cache the result for later calls | |||
| } | |||
| *dataset_size = dataset_size_; | |||
| @@ -427,68 +471,88 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||
| } | |||
| Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) { | |||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||
| for (auto ts : row_) { | |||
| DataType dt = ts->type(); | |||
| types->push_back(dt); | |||
| } | |||
| RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType))); | |||
| if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_)); | |||
| std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types), | |||
| [](const TensorPtr &t) { return t->type(); }); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) { | |||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||
| for (auto ts : row_) { | |||
| TensorShape t = ts->shape(); | |||
| shapes->push_back(t); | |||
| } | |||
| RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType))); | |||
| if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_)); | |||
| std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes), | |||
| [](const TensorPtr &t) { return t->shape(); }); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetBatchSize(int64_t *batch_size) { | |||
| RETURN_IF_NOT_OK(InternalInit()); | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| RETURN_UNEXPECTED_IF_NULL(root); | |||
| *batch_size = root->GetTreeBatchSize(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size."); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetRepeatCount(int64_t *repeat_count) { | |||
| RETURN_IF_NOT_OK(InternalInit()); | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| RETURN_UNEXPECTED_IF_NULL(root); | |||
| *repeat_count = root->GetTreeRepeatCount(); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetNumClasses(int64_t *num_classes) { | |||
| RETURN_IF_NOT_OK(InternalInit()); | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| RETURN_UNEXPECTED_IF_NULL(root); | |||
| RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetColumnNames(std::vector<std::string> *output) { | |||
| RETURN_IF_NOT_OK(InternalInit()); | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| RETURN_UNEXPECTED_IF_NULL(root); | |||
| std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map(); | |||
| if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty."); | |||
| std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(), | |||
| column_name_id_map.end()); | |||
| std::sort(column_name_id_vector.begin(), column_name_id_vector.end(), | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map is empty."); | |||
| std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end()); | |||
| std::sort(col_name_id_vec.begin(), col_name_id_vec.end(), | |||
| [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) { | |||
| return a.second < b.second; | |||
| }); | |||
| for (auto item : column_name_id_vector) { | |||
| (*output).push_back(item.first); | |||
| } | |||
| std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output), | |||
| [](const std::pair<std::string, int32_t> &p) { return p.first; }); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) { | |||
| RETURN_IF_NOT_OK(InternalInit()); | |||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||
| RETURN_UNEXPECTED_IF_NULL(root); | |||
| RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing)); | |||
| return Status::OK(); | |||
| } | |||
| Status TreeGetters::InternalInit(int8_t type) { | |||
| if (init_flag_) return Status::OK(); | |||
| tree_adapter_->SetPrePassOverride([&type](OptPass pre) { | |||
| pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type))); | |||
| return pre; | |||
| }); | |||
| Status s = tree_adapter_->Compile(std::move(root_), 1); | |||
| if (!s.IsError()) init_flag_ = true; | |||
| return s; | |||
| } | |||
| Status TreeGetters::InternalInit() { | |||
| if (init_flag_) return Status::OK(); | |||
| Status s = tree_adapter_->Compile(std::move(root_), 1); | |||
| if (!s.IsError()) init_flag_ = true; | |||
| return s; | |||
| } | |||
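For illustration only (not part of the patch): the lazy flow the reworked TreeGetters now follows; Init() only records the IR root, and the first getter call compiles the tree through InternalInit().

// Sketch: query the dataset size through a TreeGetters instance.
int64_t QuerySizeSketch(std::shared_ptr<DatasetNode> root) {
  TreeGetters getters;
  if (getters.Init(root).IsError()) return -1;             // stores root_, no compilation yet
  int64_t size = -1;
  if (getters.GetDatasetSize(&size).IsError()) return -1;  // triggers InternalInit(kDatasetSize)
  return size;
}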
| Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); } | |||
| Status BuildVocabConsumer::Start() { | |||
| @@ -41,7 +41,7 @@ class TreeConsumer { | |||
| /// \return Status error code. | |||
| virtual Status Init(std::shared_ptr<DatasetNode> d); | |||
| Status Terminate(); | |||
| virtual Status Terminate(); | |||
| protected: | |||
| /// The class owns the tree_adapter that handles execution tree operations. | |||
| @@ -72,6 +72,11 @@ class IteratorConsumer : public TreeConsumer { | |||
| /// \return Status error code | |||
| Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out); | |||
| /// Returns the next row as a vector of (column name, Tensor) pairs, ordered by column id | |||
| /// \param[out] vec vector of string-Tensor pairs | |||
| /// \return Status error code | |||
| Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec); | |||
| protected: | |||
| /// Method to return the name of the consumer | |||
| /// \return string | |||
| @@ -79,6 +84,7 @@ class IteratorConsumer : public TreeConsumer { | |||
| private: | |||
| int32_t num_epochs_; | |||
| std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id | |||
| }; | |||
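A hedged usage sketch of the new GetNextAsOrderedPair API (the `consumer` variable is an assumption; only the signature above is taken from this change):

// Assumes `consumer` is an initialized IteratorConsumer; RETURN_IF_NOT_OK and MS_LOG are the
// existing MindSpore helpers used elsewhere in this codebase.
std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> row;
RETURN_IF_NOT_OK(consumer->GetNextAsOrderedPair(&row));
for (const auto &col : row) {
  MS_LOG(DEBUG) << "column: " << col.first;  // col.second holds the Tensor for this column
}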
| #ifndef ENABLE_ANDROID | |||
| @@ -101,7 +107,7 @@ class SaveToDisk : public TreeConsumer { | |||
| /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows | |||
| /// will have been written to disk) | |||
| /// \return Status error code | |||
| Status Save(); | |||
| virtual Status Save(); | |||
| protected: | |||
| /// Method to return the name of the consumer | |||
| @@ -110,7 +116,7 @@ class SaveToDisk : public TreeConsumer { | |||
| private: | |||
| template <typename T, typename S> | |||
| Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||
| Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | |||
| std::unique_ptr<S> *s, bool need_convert = false); | |||
| @@ -131,24 +137,29 @@ class SaveToDisk : public TreeConsumer { | |||
| /// Consumer that iterates over the dataset and send it to a device | |||
| class ToDevice : public TreeConsumer { | |||
| public: | |||
| explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) | |||
| : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | |||
| explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | |||
| ~ToDevice() = default; | |||
| Status Init(std::shared_ptr<DatasetNode> d) override; | |||
| Status Terminate() override; | |||
| /// Send the data to device | |||
| /// \return Status error code | |||
| Status Send(); | |||
| virtual Status Send(); | |||
| /// Stop to send data to device | |||
| /// \return Status error code | |||
| Status Stop(); | |||
| virtual Status Stop(); | |||
| /// Continue to send data to device | |||
| /// \return Status error code | |||
| Status Continue(); | |||
| virtual Status Continue(); | |||
| /// Get data info from TDT | |||
| /// \return Status error code | |||
| virtual Status GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes); | |||
| protected: | |||
| /// Method to return the name of the consumer | |||
| @@ -156,8 +167,6 @@ class ToDevice : public TreeConsumer { | |||
| std::string Name() override { return "ToDevice"; } | |||
| private: | |||
| std::string device_type_; | |||
| bool send_epoch_end_; | |||
| int32_t num_epochs_; | |||
| }; | |||
| @@ -167,6 +176,7 @@ class TreeGetters : public TreeConsumer { | |||
| TreeGetters(); | |||
| ~TreeGetters() = default; | |||
| Status Init(std::shared_ptr<DatasetNode> d) override; | |||
| Status GetDatasetSize(int64_t *size); | |||
| Status GetOutputTypes(std::vector<DataType> *types); | |||
| Status GetOutputShapes(std::vector<TensorShape> *shapes); | |||
| @@ -175,15 +185,17 @@ class TreeGetters : public TreeConsumer { | |||
| Status GetNumClasses(int64_t *num_classes); | |||
| Status GetColumnNames(std::vector<std::string> *output); | |||
| Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); | |||
| bool isInitialized(); | |||
| std::string Name() override { return "TreeGetters"; } | |||
| Status GetRow(TensorRow *r); | |||
| virtual Status GetRow(TensorRow *r); | |||
| private: | |||
| std::shared_ptr<DatasetNode> root_; | |||
| int64_t dataset_size_; | |||
| TensorRow row_; | |||
| TensorRow first_row_; | |||
| bool init_flag_; // indicate whether the tree has been initialized | |||
| bool row_flag_; // indicate whether the first row has been stored in first_row_ | |||
| Status InternalInit(int8_t type); | |||
| Status InternalInit(); | |||
| }; | |||
| class BuildVocabConsumer : public TreeConsumer { | |||
| @@ -197,7 +209,7 @@ class BuildVocabConsumer : public TreeConsumer { | |||
| /// Start consuming | |||
| /// \return Status error code | |||
| Status Start(); | |||
| virtual Status Start(); | |||
| protected: | |||
| /// Method to return the name of the consumer | |||
| @@ -44,9 +44,9 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) { | |||
| } | |||
| // Constructor of the ConcatOp. | |||
| ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, | |||
| std::vector<std::pair<int, int>> children_flag_and_nums, | |||
| std::vector<std::pair<int, int>> children_start_end_index) | |||
| ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler, | |||
| const std::vector<std::pair<int, int>> &children_flag_and_nums, | |||
| const std::vector<std::pair<int, int>> &children_start_end_index) | |||
| : PipelineOp(op_connector_size), | |||
| children_num_(0), | |||
| sampler_(sampler), | |||
| @@ -70,9 +70,9 @@ class ConcatOp : public PipelineOp { | |||
| // @note The builder class should be used to call it | |||
| // @param op_connector_size - connector size | |||
| explicit ConcatOp(int32_t op_connector_size); | |||
| explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, | |||
| std::vector<std::pair<int, int>> children_flag_and_nums, | |||
| std::vector<std::pair<int, int>> children_start_end_index); | |||
| ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler, | |||
| const std::vector<std::pair<int, int>> &children_flag_and_nums, | |||
| const std::vector<std::pair<int, int>> &children_start_end_index); | |||
| // Destructor | |||
| ~ConcatOp() = default; | |||
| @@ -346,6 +346,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// \return Name of the current Op | |||
| virtual std::string Name() const = 0; | |||
| /// Op name and ID getter | |||
| /// \return Name and ID of the current Op | |||
| std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; } | |||
| /// Execution Tree getter | |||
| /// \return Pointer to the ExecutionTree the current op belongs to, no ownership | |||
| ExecutionTree *Tree() { return tree_; } | |||
| @@ -205,7 +205,6 @@ Status DeviceQueueOp::SendDataToAscend() { | |||
| } | |||
| tree_->SetFinished(); | |||
| MS_LOG(INFO) << "Device queue total batch is " << send_batch; | |||
| return Status::OK(); | |||
| } | |||
| @@ -39,10 +39,10 @@ using mindspore::device::GpuBufferMgr; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>; | |||
| using DATA_INFO_QUEUE = Queue<DATA_INFO>; | |||
| const int kDataInfoQueueCapacity = 128; | |||
| class DeviceQueueOp : public PipelineOp { | |||
| public: | |||
| static const uint32_t INVALID_HANDLE = 0xffffffffUL; | |||
| @@ -184,7 +184,6 @@ class DeviceQueueOp : public PipelineOp { | |||
| #ifdef ENABLE_TDTQUE | |||
| Status SendDataToAscend(); | |||
| bool ascend_keep_waiting_; | |||
| #endif | |||
| #ifdef ENABLE_GPUQUE | |||
| @@ -169,7 +169,7 @@ Status MapOp::operator()() { | |||
| } | |||
| // The operator class just starts off threads by calling the tree_ function | |||
| rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); | |||
| rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1), NameWithID()); | |||
| // Synchronize with TaskManager | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(rc); | |||
| @@ -704,6 +704,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) { | |||
| } | |||
| if (image_ids_.size() == 0) { | |||
| RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); | |||
| } else { | |||
| num_rows = image_ids_.size(); | |||
| } | |||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| @@ -480,13 +480,13 @@ Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows = num_rows_, sample_size; | |||
| int64_t num_rows = num_rows_; | |||
| if (num_rows_ <= 0) { | |||
| std::shared_ptr<ShardOperator> op; | |||
| // The last operator is parent sampler | |||
| std::shared_ptr<ShardOperator> op = operators_.back(); | |||
| RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_)); | |||
| } | |||
| sample_size = operators_[0]->GetNumSamples(num_rows, 0); | |||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||
| *dataset_size = num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| @@ -1067,6 +1067,19 @@ Status TFReaderOp::PrepareNodePostAction() { | |||
| return Status::OK(); | |||
| } | |||
| // Get the file list of the specific shard ID | |||
| Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) { | |||
| if (!shard_filenames->empty()) { | |||
| RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n"); | |||
| } | |||
| for (int index = 0; index < dataset_files_list_.size(); index++) { | |||
| if (index % num_devices_ == device_id_) { | |||
| shard_filenames->push_back(dataset_files_list_.at(index)); | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
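To make the sharding behaviour concrete, here is a small standalone sketch (illustrative only, not part of the patch) that assigns files to a shard the same way GetShardFileList does, i.e. by taking the file index modulo the number of devices:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> ShardFiles(const std::vector<std::string> &files, int32_t num_devices,
                                    int32_t device_id) {
  std::vector<std::string> shard_files;
  for (size_t index = 0; index < files.size(); index++) {
    // A file belongs to this shard when its index falls on this device's stride.
    if (static_cast<int32_t>(index % num_devices) == device_id) {
      shard_files.push_back(files[index]);
    }
  }
  return shard_files;
}

int main() {
  std::vector<std::string> files = {"a.tfrecord", "b.tfrecord", "c.tfrecord", "d.tfrecord"};
  for (const auto &f : ShardFiles(files, /*num_devices=*/2, /*device_id=*/1)) {
    std::cout << f << "\n";  // prints: b.tfrecord, d.tfrecord
  }
  return 0;
}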
| // Get Dataset size | |||
| Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| @@ -1080,7 +1093,9 @@ Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { | |||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | |||
| num_rows = num_rows_per_shard_; | |||
| } else { | |||
| RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_)); | |||
| std::vector<std::string> shard_file_list; | |||
| RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list)); | |||
| RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list)); | |||
| } | |||
| } | |||
| sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); | |||
| @@ -400,6 +400,11 @@ class TFReaderOp : public ParallelOp { | |||
| // @return - Status | |||
| Status ComputeColMap() override; | |||
| // Private function for computing the file list of the specific shard ID. This is needed because, in a distributed | |||
| // scenario, data is divided into shards by row when equal_rows_per_shard is true, but by file otherwise. | |||
| // @return - Status - the status code returned. | |||
| Status GetShardFileList(std::vector<std::string> *shard_filenames); | |||
| int32_t device_id_; | |||
| int32_t num_devices_; | |||
| int64_t rows_per_buffer_; | |||
| @@ -536,6 +536,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) { | |||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | |||
| num_rows = static_cast<int64_t>(op->image_ids_.size()); | |||
| } | |||
| } else { | |||
| num_rows = image_ids_.size(); | |||
| } | |||
| sample_size = sampler_->CalculateNumSamples(num_rows); | |||
| *dataset_size = sample_size; | |||
| @@ -141,8 +141,6 @@ Status ExecutionTree::Launch() { | |||
| " Expected state: " + std::to_string(static_cast<int>(kDeTStateReady)); | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| std::ostringstream ss; | |||
| ss << *this; | |||
| // Profiling infrastructures need to be initialized before Op launching | |||
| if (profiling_manager_->IsProfilingEnable()) { | |||
| @@ -152,6 +150,8 @@ Status ExecutionTree::Launch() { | |||
| RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor()); | |||
| } | |||
| std::ostringstream ss; | |||
| ss << *this; | |||
| MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str(); | |||
| for (auto itr = this->begin(); itr != this->end(); ++itr) { | |||
| // An inlined operator is one that has an output connector size of 0, and it does not | |||
| @@ -160,7 +160,7 @@ Status ExecutionTree::Launch() { | |||
| // the launching tree/user thread. Do not exec any thread for an inlined op. | |||
| itr->state_ = DatasetOp::OpState::kDeOpRunning; | |||
| if (!itr->inlined()) { | |||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr))); | |||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask(itr->NameWithID(), std::ref(*itr))); | |||
| // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp | |||
| } | |||
| } | |||
| @@ -189,10 +189,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_ | |||
| // Given the number of workers, launches the worker entry function for each. Essentially a | |||
| // wrapper for the TaskGroup handling that is stored inside the execution tree. | |||
| Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func) { | |||
| Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) { | |||
| // Launch the workers | |||
| for (int32_t i = 0; i < num_workers; ++i) { | |||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i))); | |||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -150,7 +150,7 @@ class ExecutionTree { | |||
| // @param num_workers - The number of workers to launch | |||
| // @param func - The function entry point that workers will execute | |||
| // @param name - The name to attach to the launched worker tasks | |||
| // @return Status - The error code returned | |||
| Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func); | |||
| Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = ""); | |||
| // Getter method | |||
| // @return shared_ptr to the root operator | |||
| @@ -1,4 +1,5 @@ | |||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| add_library(engine-ir-cache OBJECT | |||
| dataset_cache_impl.cc) | |||
| pre_built_dataset_cache.cc | |||
| dataset_cache_impl.cc) | |||
| @@ -18,8 +18,8 @@ | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| namespace mindspore::dataset { | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// Method to initialize the DatasetCache by creating an instance of a CacheClient | |||
| /// \return Status Error code | |||
| Status DatasetCacheImpl::Build() { | |||
| @@ -40,5 +40,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindspore::dataset | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -24,8 +24,8 @@ | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | |||
| namespace mindspore::dataset { | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// DatasetCache is the IR of CacheClient | |||
| class DatasetCacheImpl : public DatasetCache { | |||
| public: | |||
| @@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache { | |||
| std::optional<int32_t> num_connections_; | |||
| std::optional<int32_t> prefetch_sz_; | |||
| }; | |||
| } // namespace mindspore::dataset | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ | |||
| @@ -0,0 +1,40 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// Method to initialize the DatasetCache by creating an instance of a CacheClient | |||
| /// \return Status Error code | |||
| Status PreBuiltDatasetCache::Build() { | |||
| // we actually want to keep a reference to the runtime object so it can be shared by different pipelines | |||
| return Status::OK(); | |||
| } | |||
| Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||
| std::shared_ptr<CacheOp> cache_op = nullptr; | |||
| RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op)); | |||
| *ds = cache_op; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||
| #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// PreBuiltDatasetCache is the IR of a pre-built CacheClient | |||
| class PreBuiltDatasetCache : public DatasetCache { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param cc a pre-built cache client | |||
| explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc) : cache_client_(std::move(cc)) {} | |||
| /// Method to initialize the DatasetCache by creating an instance of a CacheClient | |||
| /// \return Status Error code | |||
| Status Build() override; | |||
| Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override; | |||
| Status ValidateParams() override { return Status::OK(); } | |||
| private: | |||
| std::shared_ptr<CacheClient> cache_client_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_ | |||
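A hedged sketch of how the new PreBuiltDatasetCache could be used (the helper name AttachPreBuiltCache is hypothetical; the constructor, Build() and CreateCacheOp() calls are taken from the header above):

// Illustrative only: wrap an existing CacheClient in the IR cache and produce a CacheOp.
Status AttachPreBuiltCache(const std::shared_ptr<CacheClient> &client, int32_t num_workers,
                           std::shared_ptr<DatasetOp> *cache_op) {
  auto cache = std::make_shared<PreBuiltDatasetCache>(client);
  RETURN_IF_NOT_OK(cache->Build());  // no-op here: the pre-built client is simply reused
  RETURN_IF_NOT_OK(cache->CreateCacheOp(num_workers, cache_op));
  return Status::OK();
}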
| @@ -31,7 +31,7 @@ namespace dataset { | |||
| BucketBatchByLengthNode::BucketBatchByLengthNode( | |||
| std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | |||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | |||
| std::function<TensorRow(TensorRow)> element_length_function, | |||
| std::shared_ptr<TensorOp> element_length_function, | |||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, | |||
| bool drop_remainder) | |||
| : column_names_(column_names), | |||
| @@ -47,16 +47,13 @@ BucketBatchByLengthNode::BucketBatchByLengthNode( | |||
| std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| std::shared_ptr<TensorOp> c_func; | |||
| if (element_length_function_ != nullptr) { | |||
| c_func = std::make_shared<CFuncOp>(element_length_function_); | |||
| } else { | |||
| c_func = nullptr; | |||
| bucket_boundaries_.insert(bucket_boundaries_.begin(), 0); | |||
| node_ops.push_back(std::make_shared<BucketBatchByLengthOp>( | |||
| column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_, | |||
| pad_to_bucket_boundary_, drop_remainder_, connector_que_size_)); | |||
| if (bucket_boundaries_[0] == 0) { | |||
| bucket_boundaries_.erase(bucket_boundaries_.begin()); | |||
| } | |||
| node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_, | |||
| c_func, pad_info_, pad_to_bucket_boundary_, | |||
| drop_remainder_, connector_que_size_)); | |||
| return node_ops; | |||
| } | |||
| @@ -33,7 +33,7 @@ class BucketBatchByLengthNode : public DatasetNode { | |||
| /// \brief Constructor | |||
| BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | |||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | |||
| std::function<TensorRow(TensorRow)> element_length_function = nullptr, | |||
| std::shared_ptr<TensorOp> element_length_function = nullptr, | |||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, | |||
| bool pad_to_bucket_boundary = false, bool drop_remainder = false); | |||
| @@ -52,7 +52,7 @@ class BucketBatchByLengthNode : public DatasetNode { | |||
| std::vector<std::string> column_names_; | |||
| std::vector<int32_t> bucket_boundaries_; | |||
| std::vector<int32_t> bucket_batch_sizes_; | |||
| std::function<TensorRow(TensorRow)> element_length_function_; | |||
| std::shared_ptr<TensorOp> element_length_function_; | |||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_; | |||
| bool pad_to_bucket_boundary_; | |||
| bool drop_remainder_; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | |||
| @@ -27,7 +28,15 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Function to build ConcatOp | |||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; } | |||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets, | |||
| const std::shared_ptr<SamplerObj> &sampler, | |||
| const std::vector<std::pair<int, int>> &children_flag_and_nums, | |||
| const std::vector<std::pair<int, int>> &children_start_end_index) | |||
| : sampler_(sampler), | |||
| children_flag_and_nums_(children_flag_and_nums), | |||
| children_start_end_index_(children_start_end_index) { | |||
| this->children = datasets; | |||
| } | |||
| Status ConcatNode::ValidateParams() { | |||
| if (children.size() < 2) { | |||
| @@ -42,14 +51,25 @@ Status ConcatNode::ValidateParams() { | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| if ((children_flag_and_nums_.empty() && !children_start_end_index_.empty()) || | |||
| (!children_flag_and_nums_.empty() && children_start_end_index_.empty())) { | |||
| std::string err_msg = "ConcatNode: children_flag_and_nums and children_start_end_index should be used together"; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) { | |||
| node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_)); | |||
| } else { | |||
| node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_, | |||
| children_start_end_index_)); | |||
| } | |||
| node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_)); | |||
| return node_ops; | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| @@ -29,7 +30,10 @@ namespace dataset { | |||
| class ConcatNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets); | |||
| explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets, | |||
| const std::shared_ptr<SamplerObj> &sampler = nullptr, | |||
| const std::vector<std::pair<int, int>> &children_flag_and_nums = {}, | |||
| const std::vector<std::pair<int, int>> &children_start_end_index = {}); | |||
| /// \brief Destructor | |||
| ~ConcatNode() = default; | |||
| @@ -41,6 +45,11 @@ class ConcatNode : public DatasetNode { | |||
| /// \brief Parameters validation | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| private: | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | |||
| std::vector<std::pair<int, int>> children_start_end_index_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -240,6 +240,7 @@ DatasetNode::DatasetNode() { | |||
| rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| connector_que_size_ = cfg->op_connector_size(); | |||
| worker_connector_size_ = cfg->worker_connector_size(); | |||
| build_status = Status::OK(); // remove me after changing return val of Build() | |||
| } | |||
| // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. | |||
| @@ -254,5 +255,13 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { | |||
| // This method will only be called if its derived class does not implement one. | |||
| return p->VisitAfter(shared_from_this(), modified); | |||
| } | |||
| Status DatasetNode::GetShardId(int32_t *shard_id) { | |||
| if (!Children().empty()) { | |||
| // Get shard id from the child node | |||
| return Children()[0]->GetShardId(shard_id); | |||
| } else { | |||
| RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node"); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -99,9 +99,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \brief Virtual function for derived classes to get the shard id of a specific node | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| virtual Status GetShardId(int32_t *shard_id) { | |||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||
| } | |||
| virtual Status GetShardId(int32_t *shard_id); | |||
| /// \brief Setter function for runtime number of workers | |||
| /// \param[in] num_workers The number of threads in this operator | |||
| @@ -126,6 +124,10 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| /// \return Status of the node visit | |||
| virtual Status AcceptAfter(NodePass *p, bool *modified); | |||
| /// \brief Method to get status from Node.Build() | |||
| /// \notes Remove me after changing return val of Build() | |||
| Status BuildStatus() { return build_status; } | |||
| protected: | |||
| std::vector<std::shared_ptr<DatasetNode>> children; | |||
| std::shared_ptr<DatasetCache> cache_; | |||
| @@ -135,6 +137,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||
| int32_t rows_per_buffer_; | |||
| int32_t connector_que_size_; | |||
| int32_t worker_connector_size_; | |||
| Status build_status; // remove me after changing return val of Build() | |||
| }; | |||
| } // namespace dataset | |||
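Since Build() still returns a list of ops rather than a Status, callers are expected to consult the temporary BuildStatus() accessor afterwards. A hedged caller-side sketch (the helper name BuildNodeOps is hypothetical; Build(), BuildStatus() and RETURN_UNEXPECTED_IF_NULL come from this codebase):

Status BuildNodeOps(const std::shared_ptr<DatasetNode> &node,
                    std::vector<std::shared_ptr<DatasetOp>> *out_ops) {
  RETURN_UNEXPECTED_IF_NULL(out_ops);
  *out_ops = node->Build();
  // build_status is set inside Build(); it stays Status::OK() unless an internal step failed.
  return node->BuildStatus();
}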
| @@ -28,7 +28,7 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for FilterNode | |||
| FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate, | |||
| FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, | |||
| std::vector<std::string> input_columns) | |||
| : predicate_(predicate), input_columns_(input_columns) { | |||
| this->children.push_back(child); | |||
| @@ -38,10 +38,7 @@ std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| std::shared_ptr<TensorOp> c_func; | |||
| c_func = std::make_shared<CFuncOp>(predicate_); | |||
| node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, c_func)); | |||
| node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_)); | |||
| return node_ops; | |||
| } | |||
| @@ -29,7 +29,7 @@ namespace dataset { | |||
| class FilterNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate, | |||
| FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, | |||
| std::vector<std::string> input_columns = {}); | |||
| /// \brief Destructor | |||
| @@ -44,7 +44,7 @@ class FilterNode : public DatasetNode { | |||
| Status ValidateParams() override; | |||
| private: | |||
| std::function<TensorRow(TensorRow)> predicate_; | |||
| std::shared_ptr<TensorOp> predicate_; | |||
| std::vector<std::string> input_columns_; | |||
| }; | |||
| @@ -64,7 +64,8 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | |||
| auto project_op = std::make_shared<ProjectOp>(project_columns_); | |||
| node_ops.push_back(project_op); | |||
| } | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(map_op); | |||
| return node_ops; | |||
| @@ -59,7 +59,8 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() { | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_)); | |||
| build_status = schema->LoadSchemaFile(schema_path_, column_names_); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| // Argument that is not exposed to user in the API. | |||
| std::set<std::string> extensions = {}; | |||
| @@ -60,7 +60,8 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() { | |||
| RETURN_EMPTY_IF_ERROR( | |||
| schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| decode_, usage_, extensions_, std::move(schema), | |||
| @@ -56,7 +56,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() { | |||
| RETURN_EMPTY_IF_ERROR( | |||
| schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, | |||
| dataset_dir_, connector_que_size_, std::move(schema), | |||
| @@ -54,7 +54,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() { | |||
| RETURN_EMPTY_IF_ERROR( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, | |||
| dataset_dir_, connector_que_size_, std::move(schema), | |||
| @@ -197,18 +197,23 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() { | |||
| std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( | |||
| num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files, | |||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); | |||
| RETURN_EMPTY_IF_ERROR(clue_op->Init()); | |||
| build_status = clue_op->Init(); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| // Inject ShuffleOp | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset | |||
| RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows)); | |||
| build_status = ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| // Add the shuffle op after this op | |||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| node_ops.push_back(shuffle_op); | |||
| } | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| @@ -111,7 +111,8 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() { | |||
| std::shared_ptr<CocoOp> op = | |||
| std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(op); | |||
| @@ -108,18 +108,23 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() { | |||
| std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, | |||
| rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, | |||
| num_shards_, shard_id_, std::move(sampler_->Build())); | |||
| RETURN_EMPTY_IF_ERROR(csv_op->Init()); | |||
| build_status = csv_op->Init(); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| // Inject ShuffleOp | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset | |||
| RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows)); | |||
| build_status = CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| // Add the shuffle op after this op | |||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| node_ops.push_back(shuffle_op); | |||
| } | |||
| @@ -30,7 +30,25 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector< | |||
| const std::vector<DataType> &column_types) | |||
| : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} | |||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | |||
| : generator_function_(generator_function), schema_(schema) {} | |||
| std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { | |||
| std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>(); | |||
| if (schema_ != nullptr) { | |||
| column_names_.clear(); | |||
| column_types_.clear(); | |||
| std::string schema_json_string = schema_->to_json(); | |||
| RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, {})); | |||
| for (int32_t i = 0; i < data_schema->NumColumns(); i++) { | |||
| ColDescriptor col = data_schema->column(i); | |||
| column_names_.push_back(col.name()); | |||
| column_types_.push_back((col.type())); | |||
| } | |||
| } | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by | |||
| @@ -43,6 +61,8 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { | |||
| // This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which | |||
| // is best delivered when the test cases for this API are ready. | |||
| Status rc = op->Init(); | |||
| build_status = rc; // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| if (rc.IsOk()) { | |||
| node_ops.push_back(op); | |||
| @@ -56,5 +76,11 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { | |||
| // no validation is needed for generator op. | |||
| Status GeneratorNode::ValidateParams() { return Status::OK(); } | |||
| Status GeneratorNode::GetShardId(int32_t *shard_id) { | |||
| RETURN_UNEXPECTED_IF_NULL(shard_id); | |||
| *shard_id = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -35,6 +35,9 @@ class GeneratorNode : public DatasetNode { | |||
| GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | |||
| const std::vector<DataType> &column_types); | |||
| /// \brief Constructor | |||
| GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema); | |||
| /// \brief Destructor | |||
| ~GeneratorNode() = default; | |||
| @@ -46,10 +49,15 @@ class GeneratorNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| /// \brief Get the shard id of the node, which is always 0 because GeneratorNode does not support sharding | |||
| /// \return Status Status::OK() if get shard id successfully | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| private: | |||
| py::function generator_function_; | |||
| std::vector<std::string> column_names_; | |||
| std::vector<DataType> column_types_; | |||
| std::shared_ptr<SchemaObj> schema_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -62,7 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() { | |||
| RETURN_EMPTY_IF_ERROR( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | |||
| recursive_, decode_, exts_, class_indexing_, std::move(schema), | |||
| @@ -79,7 +79,8 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() { | |||
| manifest_op = | |||
| std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, | |||
| class_index_, std::move(schema), std::move(sampler_->Build()), usage_); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(manifest_op); | |||
| @@ -138,7 +138,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() { | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| std::vector<std::shared_ptr<ShardOperator>> operators_; | |||
| RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_)); | |||
| build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| std::shared_ptr<MindRecordOp> mindrecord_op; | |||
| // If pass a string to MindData(), it will be treated as a pattern to search for matched files, | |||
| @@ -154,7 +155,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() { | |||
| padded_sample_, sample_bytes_); | |||
| } | |||
| RETURN_EMPTY_IF_ERROR(mindrecord_op->Init()); | |||
| build_status = mindrecord_op->Init(); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(mindrecord_op); | |||
| return node_ops; | |||
| @@ -51,7 +51,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() { | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| RETURN_EMPTY_IF_ERROR( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, | |||
| connector_que_size_, std::move(schema), std::move(sampler_->Build()))); | |||
| @@ -98,7 +98,8 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() { | |||
| std::shared_ptr<RandomDataOp> op; | |||
| op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, | |||
| std::move(data_schema), std::move(sampler_->Build())); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(op); | |||
| @@ -78,7 +78,8 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() { | |||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | |||
| num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, | |||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); | |||
| RETURN_EMPTY_IF_ERROR(text_file_op->Init()); | |||
| build_status = text_file_op->Init(); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| // Inject ShuffleOp | |||
| @@ -86,14 +87,17 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() { | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset | |||
| RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows)); | |||
| build_status = TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| // Add the shuffle op after this op | |||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| node_ops.push_back(shuffle_op); | |||
| } | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| // Add TextFileOp | |||
| node_ops.push_back(text_file_op); | |||
| @@ -118,7 +118,8 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { | |||
| std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, | |||
| shard_id_, shard_equal_rows_, std::move(sampler_->Build())); | |||
| RETURN_EMPTY_IF_ERROR(tf_reader_op->Init()); | |||
| build_status = tf_reader_op->Init(); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | |||
| // Inject ShuffleOp | |||
| @@ -127,14 +128,17 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset | |||
| RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files)); | |||
| build_status = TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| // Add the shuffle op after this op | |||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op)); | |||
| build_status = AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||
| rows_per_buffer_, &shuffle_op); | |||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||
| node_ops.push_back(shuffle_op); | |||
| } | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| // Add TFReaderOp | |||
| node_ops.push_back(tf_reader_op); | |||
| @@ -106,7 +106,8 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() { | |||
| std::shared_ptr<VOCOp> voc_op; | |||
| voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | |||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | |||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| node_ops.push_back(voc_op); | |||
| return node_ops; | |||
| @@ -27,9 +27,8 @@ namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for SyncWaitNode | |||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, | |||
| py::function callback) | |||
| : condition_name_(condition_name), num_batch_(num_batch), callback_(callback) { | |||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback) | |||
| : condition_name_(condition_name), callback_(callback) { | |||
| this->children.push_back(child); | |||
| } | |||
| @@ -38,20 +37,16 @@ std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() { | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_)); | |||
| // Right now barrier should only take num_rows_per_buffer = 1 | |||
| // This is because a larger value can lead to blocking issues | |||
| // See barrier_op.h for more details | |||
| int32_t rows_per_buffer = 1; | |||
| node_ops.push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_)); | |||
| return node_ops; | |||
| } | |||
| // Function to validate the parameters for SyncWaitNode | |||
| Status SyncWaitNode::ValidateParams() { | |||
| if (num_batch_ <= 0) { | |||
| std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_); | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SyncWaitNode::ValidateParams() { return Status::OK(); } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -31,8 +31,7 @@ namespace dataset { | |||
| class SyncWaitNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, | |||
| py::function callback); | |||
| SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback); | |||
| /// \brief Destructor | |||
| ~SyncWaitNode() = default; | |||
| @@ -47,7 +46,6 @@ class SyncWaitNode : public DatasetNode { | |||
| private: | |||
| std::string condition_name_; | |||
| int32_t num_batch_; | |||
| py::function callback_; | |||
| }; | |||
| @@ -18,72 +18,80 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for TransferNode | |||
| TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end) | |||
| : prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) { | |||
| TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, | |||
| bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) | |||
| : prefetch_size_(16), | |||
| queue_name_(std::move(queue_name)), | |||
| device_type_(std::move(device_type)), | |||
| send_epoch_end_(send_epoch_end), | |||
| total_batch_(total_batch), | |||
| create_data_info_queue_(create_data_info_queue), | |||
| device_id_(0) { | |||
| this->children.push_back(child); | |||
| } | |||
| // Validator for TransferNode | |||
| Status TransferNode::ValidateParams() { | |||
| // Check if device_type_ is in {"CPU", "GPU", "Ascend"} | |||
| RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"})); | |||
| if (total_batch_ < 0) { | |||
| std::string err_msg = "TransferNode: Total batches should be >= 0, value given: "; | |||
| MS_LOG(ERROR) << err_msg << total_batch_; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| // Function to build TransferNode | |||
| std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | |||
| // Get a uuid for queue name | |||
| queue_name_ = Services::GetUniqueID(); | |||
| // TODO(CRC): | |||
| // Get device type from ms context | |||
| device_type_ = "CPU"; | |||
| // Get device ID from children | |||
| device_id_ = 0; | |||
| RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_)); | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| if (queue_name_.empty()) { | |||
| // Get a uuid for queue name | |||
| queue_name_ = Services::GetUniqueID(); | |||
| } | |||
| if (device_type_.empty()) { | |||
| auto context = MsContext::GetInstance(); | |||
| if (context == nullptr) { | |||
| device_type_ = kCPUDevice; | |||
| } else { | |||
| device_type_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| } | |||
| } | |||
| // Get device type from ms context | |||
| // Convert device_type_ from string to DeviceType | |||
| DeviceQueueOp::DeviceType type; | |||
| if (device_type_ == "CPU") { | |||
| if (device_type_ == kCPUDevice) { | |||
| type = DeviceQueueOp::DeviceType::CPU; | |||
| } else if (device_type_ == "GPU") { | |||
| } else if (device_type_ == kGPUDevice) { | |||
| type = DeviceQueueOp::DeviceType::GPU; | |||
| } else if (device_type_ == "Ascend") { | |||
| } else if (device_type_ == kAscendDevice) { | |||
| type = DeviceQueueOp::DeviceType::Ascend; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unknown device target."; | |||
| return {}; | |||
| } | |||
| node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, | |||
| total_batch_, false)); | |||
| return node_ops; | |||
| } | |||
| // Function to get the device_id | |||
| Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) { | |||
| // Get device id according to the type of dataset | |||
| Status rc = ds->GetShardId(device_id); | |||
| if (rc != Status::OK()) { | |||
| // Get device id from the child node | |||
| if (ds->Children().size()) { | |||
| ds = ds->Children()[0]; | |||
| return TransferNode::get_distribution(ds, device_id); | |||
| } else { | |||
| std::string err_msg = "Unknown dataset type."; | |||
| MS_LOG(ERROR) << err_msg; | |||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||
| } | |||
| } | |||
| // Get device ID (shard ID) from children | |||
| device_id_ = 0; | |||
| build_status = this->GetShardId(&device_id_); // remove me after changing return val of Build() | |||
| RETURN_EMPTY_IF_ERROR(build_status); | |||
| return Status::OK(); | |||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||
| node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, | |||
| total_batch_, create_data_info_queue_)); | |||
| return node_ops; | |||
| } | |||
| } // namespace dataset | |||
| @@ -29,7 +29,8 @@ namespace dataset { | |||
| class TransferNode : public DatasetNode { | |||
| public: | |||
| /// \brief Constructor | |||
| TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end); | |||
| TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end, | |||
| int32_t total_batch, bool create_data_info_queue); | |||
| /// \brief Destructor | |||
| ~TransferNode() = default; | |||
| @@ -42,8 +43,6 @@ class TransferNode : public DatasetNode { | |||
| /// \return Status Status::OK() if all the parameters are valid | |||
| Status ValidateParams() override; | |||
| static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id); | |||
| private: | |||
| std::string queue_name_; | |||
| int32_t device_id_; | |||
| @@ -51,6 +50,7 @@ class TransferNode : public DatasetNode { | |||
| int32_t prefetch_size_; | |||
| bool send_epoch_end_; | |||
| int32_t total_batch_; | |||
| bool create_data_info_queue_; | |||
| }; | |||
| } // namespace dataset | |||
| @@ -40,21 +40,7 @@ Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *mo | |||
| } | |||
| Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | |||
| if (type_ == kOutputShapeAndType) { | |||
| nodes_to_clear_callback_.push_back(node); | |||
| } else if (type_ == kDatasetSize) { | |||
| nodes_to_remove_.push_back(node); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) { | |||
| if (type_ == kDatasetSize) nodes_to_remove_.push_back(node); | |||
| return Status::OK(); | |||
| } | |||
| Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) { | |||
| if (type_ == kDatasetSize) nodes_to_remove_.push_back(node); | |||
| nodes_to_clear_callback_.push_back(node); | |||
| return Status::OK(); | |||
| } | |||
| @@ -83,5 +69,6 @@ Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -34,6 +34,10 @@ class GetterPass : public TreePass { | |||
| enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 }; | |||
| /// \brief Constructor | |||
| explicit GetterPass(GetterType tp) : pass_(tp) {} | |||
| /// \brief Default copy constructor | |||
| explicit GetterPass(const GetterPass &) = default; | |||
| /// \brief Destructor | |||
| ~GetterPass() = default; | |||
| @@ -51,11 +55,10 @@ class GetterPass : public TreePass { | |||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); } | |||
| Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override; | |||
| Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override; | |||
| // Whether this is Run or PreRun does not matter here; however, only Accept() is defined in ConcatOp | |||
| Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | |||
| @@ -67,7 +70,7 @@ class GetterPass : public TreePass { | |||
| std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_; | |||
| std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_; | |||
| }; | |||
| // outter class needs only to own the inner class object since it automatically has access to its private variables | |||
| // outer class needs only to own the inner class object since it automatically has access to its private variables | |||
| GetterNodes pass_; | |||
| }; | |||
| } // namespace dataset | |||
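As a hedged sketch (not part of the diff), this is roughly how a dataset-size getter pass could be driven over a prepared ExecutionTree; the tree variable is assumed, and dispatch through the TreePass interface and access specifiers are glossed over here.

GetterPass pass(GetterPass::kDatasetSize);          // collect dataset-size information
bool modified = false;
Status rc = pass.RunOnTree(tree.get(), &modified);  // tree: std::unique_ptr<ExecutionTree>, assumed to be prepared
if (rc.IsError()) {
  MS_LOG(ERROR) << "GetterPass failed: " << rc;
}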
| @@ -19,7 +19,14 @@ | |||
| namespace mindspore::dataset { | |||
| Status PythonRuntimeContext::Terminate() { return TerminateImpl(); } | |||
| Status PythonRuntimeContext::Terminate() { | |||
| MS_LOG(INFO) << "Terminating a PythonRuntime"; | |||
| if (tree_consumer_ != nullptr) { | |||
| return TerminateImpl(); | |||
| } | |||
| MS_LOG(WARNING) << "TreeConsumer was not initialized"; | |||
| return Status::OK(); | |||
| } | |||
| Status PythonRuntimeContext::TerminateImpl() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | |||
| @@ -22,7 +22,14 @@ namespace mindspore::dataset { | |||
| void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { | |||
| tree_consumer_ = std::move(tree_consumer); | |||
| } | |||
| Status NativeRuntimeContext::Terminate() { return TerminateImpl(); } | |||
| Status NativeRuntimeContext::Terminate() { | |||
| MS_LOG(INFO) << "Terminating a NativeRuntime"; | |||
| if (tree_consumer_ != nullptr) { | |||
| return TerminateImpl(); | |||
| } | |||
| MS_LOG(WARNING) << "TreeConsumer was not initialized"; | |||
| return Status::OK(); | |||
| } | |||
| Status NativeRuntimeContext::TerminateImpl() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | |||
| @@ -97,6 +97,8 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||
| Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) { | |||
| // Build the DatasetOp ExecutionTree from the optimized IR tree | |||
| std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build(); | |||
| RETURN_IF_NOT_OK(ir->BuildStatus()); // remove me after changing return val of Build() | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node."); | |||
| (*op) = ops.front(); // return the first op to be added as child by the caller of this function | |||
| @@ -141,6 +143,8 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep | |||
| RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op)); | |||
| RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); | |||
| if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); | |||
| // Note: We will gradually move the pre pass, optimizer pass, and post pass | |||
| // on ExecutionTree to perform on IR tree. | |||
| // Prepare the tree | |||
| @@ -149,6 +153,11 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep | |||
| // After the tree is prepared, the col_name_id_map can safely be obtained | |||
| column_name_map_ = tree_->root()->column_name_id_map(); | |||
| // Profiling parameters init | |||
| cur_batch_num_ = 0; | |||
| cur_connector_size_ = 0; | |||
| cur_connector_capacity_ = 0; | |||
| return Status::OK(); | |||
| } | |||
| @@ -156,21 +165,55 @@ Status TreeAdapter::GetNext(TensorRow *row) { | |||
| RETURN_UNEXPECTED_IF_NULL(tree_); | |||
| RETURN_UNEXPECTED_IF_NULL(row); | |||
| row->clear(); // make sure row is empty | |||
| bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable(); | |||
| // When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree | |||
| if (cur_db_ == nullptr) { | |||
| RETURN_IF_NOT_OK(tree_->Launch()); | |||
| // Profiling | |||
| std::shared_ptr<Tracing> node; | |||
| Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node); | |||
| if (s.IsOk()) { | |||
| tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node); | |||
| } | |||
| if (tracing_ != nullptr) { | |||
| cur_connector_size_ = tree_->root()->ConnectorSize(); | |||
| cur_connector_capacity_ = tree_->root()->ConnectorCapacity(); | |||
| } | |||
| RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag | |||
| RETURN_OK_IF_TRUE(cur_db_->eoe()); // return empty tensor if 1st buf is a ctrl buf (no rows) | |||
| if (cur_db_->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows) | |||
| MS_LOG(INFO) << "End of data iteration."; | |||
| if (isProfilingEnable) { | |||
| tree_->SetEpochEnd(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached."); | |||
| if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf | |||
| RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); | |||
| RETURN_OK_IF_TRUE(cur_db_->eoe() || cur_db_->eof()); // return empty if this new buffer is a ctrl flag | |||
| if (cur_db_->eoe()) { // return empty if this new buffer is a ctrl flag | |||
| MS_LOG(INFO) << "End of data iteration."; | |||
| if (isProfilingEnable) { | |||
| tree_->SetEpochEnd(); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| if (cur_db_->eof()) { | |||
| tree_->SetFinished(); | |||
| std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; | |||
| RETURN_STATUS_UNEXPECTED(err); | |||
| } | |||
| } | |||
| RETURN_IF_NOT_OK(cur_db_->PopRow(row)); | |||
| // Record profiling info | |||
| if (tracing_ != nullptr) { | |||
| cur_batch_num_++; | |||
| tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -25,6 +25,7 @@ | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| #include "minddata/dataset/engine/perf/dataset_iterator_tracing.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| @@ -60,6 +61,9 @@ class TreeAdapter { | |||
| // Set optional optimization pass | |||
| void SetOptimize(bool value) { optimize_ = value; } | |||
| // function to override the pre-pass | |||
| void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; } | |||
| // Optional optimizations status | |||
| bool OptimizationEnabled() const { return optimize_; } | |||
| @@ -82,9 +86,14 @@ class TreeAdapter { | |||
| std::unique_ptr<DataBuffer> cur_db_; | |||
| std::unordered_map<std::string, int32_t> column_name_map_; | |||
| std::unique_ptr<ExecutionTree> tree_; | |||
| int32_t num_epochs_; | |||
| bool optimize_; // Flag to enable optional optimization pass | |||
| bool optimize_; // Flag to enable optional optimization pass | |||
| std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data | |||
| int32_t cur_batch_num_; // current batch number, used for profiling | |||
| int32_t cur_connector_size_; // current connector size of root op, used for profiling | |||
| int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling | |||
| std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare() | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -145,9 +145,16 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||
| /// \brief Function to transfer data through a device. | |||
| /// \notes If device is Ascend, features of data will be transferred one by one. The limitation | |||
| /// of data transmission per time is 256M. | |||
| /// \param[in] queue_name Channel name (default="", create new unique name). | |||
| /// \param[in] device_type Type of device (default="", get from MSContext). | |||
| /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs). | |||
| /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true). | |||
| /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data). | |||
| /// \param[in] create_data_info_queue Whether to create a queue which stores types and shapes | |||
| /// of data or not (default=false). | |||
| /// \return Returns true if no error encountered else false. | |||
| bool DeviceQueue(bool send_epoch_end = true); | |||
| bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1, | |||
| bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false); | |||
| /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline | |||
| /// \note Usage restrictions: | |||
| @@ -371,21 +378,34 @@ class SchemaObj { | |||
| /// \brief SchemaObj init function | |||
| /// \return bool true if schema init success | |||
| bool init(); | |||
| Status init(); | |||
| /// \brief Add new column to the schema with unknown shape of rank 1 | |||
| /// \param[in] name name of the column. | |||
| /// \param[in] de_type data type of the column(TypeId). | |||
| /// \return bool true if schema init success | |||
| Status add_column(std::string name, TypeId de_type); | |||
| /// \brief Add new column to the schema with unknown shape of rank 1 | |||
| /// \param[in] name name of the column. | |||
| /// \param[in] de_type data type of the column(std::string). | |||
| /// \param[in] shape shape of the column. | |||
| /// \return bool true if schema init success | |||
| Status add_column(std::string name, std::string de_type); | |||
| /// \brief Add new column to the schema | |||
| /// \param[in] name name of the column. | |||
| /// \param[in] de_type data type of the column(TypeId). | |||
| /// \param[in] shape shape of the column. | |||
| /// \return bool true if schema init success | |||
| bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape); | |||
| Status add_column(std::string name, TypeId de_type, std::vector<int32_t> shape); | |||
| /// \brief Add new column to the schema | |||
| /// \param[in] name name of the column. | |||
| /// \param[in] de_type data type of the column(std::string). | |||
| /// \param[in] shape shape of the column. | |||
| /// \return bool true if schema init success | |||
| bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape); | |||
| Status add_column(std::string name, std::string de_type, std::vector<int32_t> shape); | |||
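Since init() and add_column() now return Status rather than bool, callers can propagate failures with the usual macros; a minimal sketch, assuming the SchemaObj has already been constructed elsewhere:

Status BuildSchema(std::shared_ptr<SchemaObj> schema) {
  RETURN_IF_NOT_OK(schema->init());                             // init() now reports failures through Status
  RETURN_IF_NOT_OK(schema->add_column("image", "uint8", {2}));  // string-typed overload with shape
  RETURN_IF_NOT_OK(schema->add_column("label", "int32", {1}));
  schema->set_num_rows(3);
  return Status::OK();
}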
| /// \brief Get a JSON string of the schema | |||
| /// \return JSON string of the schema | |||
| @@ -395,25 +415,27 @@ class SchemaObj { | |||
| std::string to_string() { return to_json(); } | |||
| /// \brief set a new value to dataset_type | |||
| inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; } | |||
| inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); } | |||
| /// \brief set a new value to num_rows | |||
| inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; } | |||
| /// \brief get the current num_rows | |||
| inline int32_t get_num_rows() { return num_rows_; } | |||
| inline int32_t get_num_rows() const { return num_rows_; } | |||
| Status FromJSONString(const std::string &json_string); | |||
| private: | |||
| /// \brief Parse the columns and add it to columns | |||
| /// \param[in] columns dataset attribute information, decoded from the schema file. | |||
| /// Supports both nlohmann::json::value_t::array and nlohmann::json::value_t::object. | |||
| /// \return JSON string of the schema | |||
| bool parse_column(nlohmann::json columns); | |||
| Status parse_column(nlohmann::json columns); | |||
| /// \brief Get schema file from json file | |||
| /// \param[in] json_obj parsed JSON object. | |||
| /// \return bool true if json dump success | |||
| bool from_json(nlohmann::json json_obj); | |||
| Status from_json(nlohmann::json json_obj); | |||
| int32_t num_rows_; | |||
| std::string dataset_type_; | |||
| @@ -61,6 +61,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||
| class DistributedSamplerObj; | |||
| class PKSamplerObj; | |||
| class PreBuiltSamplerObj; | |||
| class RandomSamplerObj; | |||
| class SequentialSamplerObj; | |||
| class SubsetRandomSamplerObj; | |||
| @@ -171,6 +172,31 @@ class PKSamplerObj : public SamplerObj { | |||
| int64_t num_samples_; | |||
| }; | |||
| class PreBuiltSamplerObj : public SamplerObj { | |||
| public: | |||
| #ifndef ENABLE_ANDROID | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); | |||
| explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); | |||
| #endif | |||
| ~PreBuiltSamplerObj() = default; | |||
| std::shared_ptr<SamplerRT> Build() override; | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||
| #endif | |||
| bool ValidateParams() override; | |||
| private: | |||
| std::shared_ptr<SamplerRT> sp_; | |||
| #ifndef ENABLE_ANDROID | |||
| std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_; | |||
| #endif | |||
| }; | |||
| class RandomSamplerObj : public SamplerObj { | |||
| public: | |||
| RandomSamplerObj(bool replacement, int64_t num_samples); | |||
| @@ -70,6 +70,7 @@ namespace transforms { | |||
| class ComposeOperation; | |||
| class DuplicateOperation; | |||
| class OneHotOperation; | |||
| class PreBuiltOperation; | |||
| class RandomApplyOperation; | |||
| class RandomChoiceOperation; | |||
| class TypeCastOperation; | |||
| @@ -164,6 +165,20 @@ class OneHotOperation : public TensorOperation { | |||
| float num_classes_; | |||
| }; | |||
| class PreBuiltOperation : public TensorOperation { | |||
| public: | |||
| explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op); | |||
| ~PreBuiltOperation() = default; | |||
| std::shared_ptr<TensorOp> Build() override; | |||
| Status ValidateParams() override; | |||
| private: | |||
| std::shared_ptr<TensorOp> op_; | |||
| }; | |||
| class RandomApplyOperation : public TensorOperation { | |||
| public: | |||
| explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob); | |||
| @@ -192,7 +207,6 @@ class RandomChoiceOperation : public TensorOperation { | |||
| private: | |||
| std::vector<std::shared_ptr<TensorOperation>> transforms_; | |||
| }; | |||
| class TypeCastOperation : public TensorOperation { | |||
| public: | |||
| explicit TypeCastOperation(std::string data_type); | |||
| @@ -71,6 +71,15 @@ namespace dataset { | |||
| return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ | |||
| } while (false) | |||
| #define RETURN_SECOND_IF_ERROR(_s, _r) \ | |||
| do { \ | |||
| Status __rc = (_s); \ | |||
| if (__rc.IsError()) { \ | |||
| MS_LOG(ERROR) << __rc; \ | |||
| return _r; \ | |||
| } \ | |||
| } while (false) | |||
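A hedged illustration of the new macro at a call site: when the first argument evaluates to an error Status, it is logged and the second argument is returned from the enclosing non-Status function. LoadManifest/TryLoadManifest below are hypothetical names used only for the example.

Status LoadManifest(const std::string &path);  // hypothetical Status-returning helper

bool TryLoadManifest(const std::string &path) {
  RETURN_SECOND_IF_ERROR(LoadManifest(path), false);  // logs the Status and returns false on error
  return true;
}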
| enum class StatusCode : char { | |||
| kOK = 0, | |||
| kOutOfMemory = 1, | |||
| @@ -138,7 +138,9 @@ Status Task::Join(WaitFlag blocking) { | |||
| while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { | |||
| // We can't tell which conditional_variable this thread is waiting on. So we may need | |||
| // to interrupt everything one more time. | |||
| MS_LOG(INFO) << "Some threads not responding. Interrupt again"; | |||
| std::stringstream ss; | |||
| ss << get_id(); | |||
| MS_LOG(ERROR) << MyName() << " Thread ID " << ss.str() << " is not responding. Interrupt again"; | |||
| interrupt_svc->InterruptAll(); | |||
| } | |||
| } else { | |||
| @@ -21,7 +21,8 @@ import numpy | |||
| import mindspore._c_dataengine as cde | |||
| __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers', | |||
| 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load'] | |||
| 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load', | |||
| 'get_callback_timeout'] | |||
| INT32_MAX = 2147483647 | |||
| UINT32_MAX = 4294967295 | |||
| @@ -65,5 +65,7 @@ def mstypelist_to_detypelist(type_list): | |||
| for index, _ in enumerate(type_list): | |||
| if type_list[index] is not None: | |||
| type_list[index] = mstype_to_detype(type_list[index]) | |||
| else: | |||
| type_list[index] = cde.DataType("") | |||
| return type_list | |||
| @@ -15,17 +15,13 @@ | |||
| """Built-in iterators. | |||
| """ | |||
| from abc import abstractmethod | |||
| import copy | |||
| import weakref | |||
| import numpy as np | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore._c_dataengine import DEPipeline | |||
| from mindspore._c_dataengine import OpName | |||
| import mindspore._c_dataengine as cde | |||
| from mindspore import log as logger | |||
| from . import datasets as de | |||
| _ITERATOR_CLEANUP = False | |||
| @@ -57,29 +53,6 @@ def _cleanup(): | |||
| itr.release() | |||
| def alter_tree(node): | |||
| """Traversing the Python dataset tree/graph to perform some alteration to some specific nodes.""" | |||
| if not node.children: | |||
| return _alter_node(node) | |||
| converted_children = [] | |||
| for input_op in node.children: | |||
| converted_children.append(alter_tree(input_op)) | |||
| node.children = converted_children | |||
| return _alter_node(node) | |||
| def _alter_node(node): | |||
| """DEPRECATED""" | |||
| # Please check ccsrc/dataset/engine/opt for tree transformation. | |||
| if isinstance(node, de.MapDataset): | |||
| if node.python_multiprocessing: | |||
| # Bootstrap can only be performed on a copy of the original dataset node. | |||
| # Bootstrap on original dataset node will make all iterators share the same process pool | |||
| node.iterator_bootstrap() | |||
| return node | |||
| class Iterator: | |||
| """ | |||
| General Iterator over a dataset. | |||
| @@ -89,185 +62,62 @@ class Iterator: | |||
| """ | |||
| def __init__(self, dataset, num_epochs=-1, output_numpy=False): | |||
| self.num_epochs = num_epochs | |||
| self.output_numpy = output_numpy | |||
| ITERATORS_LIST.append(weakref.ref(self)) | |||
| _unset_iterator_cleanup() | |||
| self._col_names = None | |||
| # create a copy of tree and work on it. | |||
| self.dataset = copy.deepcopy(dataset) | |||
| self.ori_dataset = dataset | |||
| self.parent_subtree = [] | |||
| # The dataset passed into the iterator is not the root of the tree. | |||
| # Trim the tree by saving the parent subtree into self.parent_subtree and | |||
| # restore it after launching our C++ pipeline. | |||
| if self.dataset.parent: | |||
| logger.info("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.") | |||
| self.parent_subtree = self.dataset.parent | |||
| self.dataset.parent = [] | |||
| self.dataset = alter_tree(self.dataset) | |||
| if not self.__is_tree(): | |||
| raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers).") | |||
| self.depipeline = DEPipeline() | |||
| # for manifest temporary use | |||
| self.__batch_node(self.dataset, 0) | |||
| root = self.__convert_node_postorder(self.dataset) | |||
| self.depipeline.AssignRootNode(root) | |||
| self.depipeline.PrepareTree(self.num_epochs) | |||
| self.ir_tree, self.dataset = dataset.create_ir_tree() | |||
| self._runtime_context = cde.PythonRuntimeContext() | |||
| self._runtime_context.Init() | |||
| consumer = cde.PythonIteratorConsumer(num_epochs) | |||
| consumer.Init(self.ir_tree) | |||
| self._runtime_context.AssignConsumer(consumer) | |||
| self._iterator = self._runtime_context.GetConsumer() | |||
| self._transform_tensor = lambda t: t.as_array() | |||
| if not output_numpy: | |||
| self._transform_tensor = lambda t: Tensor(t.as_array()) | |||
| self._index = 0 | |||
| # todo remove next when ContextManager is done | |||
| ITERATORS_LIST.append(weakref.ref(self)) | |||
| _unset_iterator_cleanup() | |||
| ####### | |||
| def __iter__(self): | |||
| return self | |||
| def stop(self): | |||
| """ | |||
| Manually terminate Python iterator instead of relying on out of scope destruction. | |||
| """ | |||
| logger.info("Terminating Python iterator. This will also terminate C++ pipeline.") | |||
| if hasattr(self, 'depipeline') and self.depipeline: | |||
| del self.depipeline | |||
| def __is_tree_node(self, node): | |||
| """Check if a node is tree node.""" | |||
| if not node.children: | |||
| if len(node.parent) > 1: | |||
| return False | |||
| if len(node.parent) > 1: | |||
| return False | |||
| for input_node in node.children: | |||
| cls = self.__is_tree_node(input_node) | |||
| if not cls: | |||
| return False | |||
| return True | |||
| def __is_tree(self): | |||
| return self.__is_tree_node(self.dataset) | |||
| @staticmethod | |||
| def __get_dataset_type(dataset): | |||
| """Get the dataset type.""" | |||
| op_type = None | |||
| if isinstance(dataset, de.ShuffleDataset): | |||
| op_type = OpName.SHUFFLE | |||
| elif isinstance(dataset, de.MindDataset): | |||
| op_type = OpName.MINDRECORD | |||
| elif isinstance(dataset, de.BatchDataset): | |||
| op_type = OpName.BATCH | |||
| elif isinstance(dataset, de.BucketBatchByLengthDataset): | |||
| op_type = OpName.BUCKETBATCH | |||
| elif isinstance(dataset, de.SyncWaitDataset): | |||
| op_type = OpName.BARRIER | |||
| elif isinstance(dataset, de.ZipDataset): | |||
| op_type = OpName.ZIP | |||
| elif isinstance(dataset, de.ConcatDataset): | |||
| op_type = OpName.CONCAT | |||
| elif isinstance(dataset, de.MapDataset): | |||
| op_type = OpName.MAP | |||
| elif isinstance(dataset, de.FilterDataset): | |||
| op_type = OpName.FILTER | |||
| elif isinstance(dataset, de.RepeatDataset): | |||
| op_type = OpName.REPEAT | |||
| elif isinstance(dataset, de.SkipDataset): | |||
| op_type = OpName.SKIP | |||
| elif isinstance(dataset, de.TakeDataset): | |||
| op_type = OpName.TAKE | |||
| elif isinstance(dataset, de.ImageFolderDataset): | |||
| op_type = OpName.IMAGEFOLDER | |||
| elif isinstance(dataset, de.GeneratorDataset): | |||
| op_type = OpName.GENERATOR | |||
| elif isinstance(dataset, de.TransferDataset): | |||
| op_type = OpName.DEVICEQUEUE | |||
| elif isinstance(dataset, de.RenameDataset): | |||
| op_type = OpName.RENAME | |||
| elif isinstance(dataset, de.TFRecordDataset): | |||
| op_type = OpName.TFREADER | |||
| elif isinstance(dataset, de.ProjectDataset): | |||
| op_type = OpName.PROJECT | |||
| elif isinstance(dataset, de.MnistDataset): | |||
| op_type = OpName.MNIST | |||
| elif isinstance(dataset, de.ManifestDataset): | |||
| op_type = OpName.MANIFEST | |||
| elif isinstance(dataset, de.VOCDataset): | |||
| op_type = OpName.VOC | |||
| elif isinstance(dataset, de.CocoDataset): | |||
| op_type = OpName.COCO | |||
| elif isinstance(dataset, de.Cifar10Dataset): | |||
| op_type = OpName.CIFAR10 | |||
| elif isinstance(dataset, de.Cifar100Dataset): | |||
| op_type = OpName.CIFAR100 | |||
| elif isinstance(dataset, de.CelebADataset): | |||
| op_type = OpName.CELEBA | |||
| elif isinstance(dataset, de.RandomDataset): | |||
| op_type = OpName.RANDOMDATA | |||
| elif isinstance(dataset, de.TextFileDataset): | |||
| op_type = OpName.TEXTFILE | |||
| elif isinstance(dataset, de.BuildVocabDataset): | |||
| op_type = OpName.BUILDVOCAB | |||
| elif isinstance(dataset, de.BuildSentencePieceVocabDataset): | |||
| op_type = OpName.SENTENCEPIECEVOCAB | |||
| elif isinstance(dataset, de.CLUEDataset): | |||
| op_type = OpName.CLUE | |||
| elif isinstance(dataset, de.CSVDataset): | |||
| op_type = OpName.CSV | |||
| else: | |||
| raise ValueError("Unsupported DatasetOp.") | |||
| return op_type | |||
| # Convert Python node into C node and add to C layer execution tree in postorder traversal. | |||
| def __convert_node_postorder(self, node): | |||
| self.check_node_type(node) | |||
| op_type = self.__get_dataset_type(node) | |||
| c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args()) | |||
| for py_child in node.children: | |||
| c_child = self.__convert_node_postorder(py_child) | |||
| self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"]) | |||
| return c_nodes["top"] | |||
| def __batch_node(self, dataset, level): | |||
| """Recursively get batch node in the dataset tree.""" | |||
| if isinstance(dataset, de.BatchDataset): | |||
| return | |||
| for input_op in dataset.children: | |||
| self.__batch_node(input_op, level + 1) | |||
| @staticmethod | |||
| def __print_local(dataset, level): | |||
| """Recursively print the name and address of nodes in the dataset tree.""" | |||
| name = dataset.__class__.__name__ | |||
| ptr = hex(id(dataset)) | |||
| for _ in range(level): | |||
| logger.info("\t", end='') | |||
| if not dataset.children: | |||
| logger.info("-%s (%s)", name, ptr) | |||
| else: | |||
| logger.info("+%s (%s)", name, ptr) | |||
| for input_op in dataset.children: | |||
| Iterator.__print_local(input_op, level + 1) | |||
| def print(self): | |||
| """Print the dataset tree""" | |||
| self.__print_local(self.dataset, 0) | |||
| if hasattr(self, '_runtime_context') and self._runtime_context: | |||
| if hasattr(self, '_iterator') and self._iterator: | |||
| self._runtime_context.Terminate() | |||
| del self._iterator | |||
| del self._runtime_context | |||
| del self.dataset | |||
| def release(self): | |||
| if hasattr(self, 'depipeline') and self.depipeline: | |||
| del self.depipeline | |||
| self.stop() | |||
| def __del__(self): | |||
| self.release() | |||
| @abstractmethod | |||
| def get_next(self): | |||
| def _get_next(self): | |||
| raise RuntimeError("Calling base class Iterator's get_next is invalid.") | |||
| def __next__(self): | |||
| if not self.depipeline: | |||
| if not self._runtime_context: | |||
| logger.warning("Iterator does not have a running C++ pipeline." + | |||
| "It might because Iterator stop() had been called, or C++ pipeline crashed silently.") | |||
| raise RuntimeError("Iterator does not have a running C++ pipeline.") | |||
| data = self.get_next() | |||
| data = self._get_next() | |||
| if not data: | |||
| if self._index == 0: | |||
| logger.warning("No records available.") | |||
| @@ -277,100 +127,56 @@ class Iterator: | |||
| self._index += 1 | |||
| return data | |||
| @abstractmethod | |||
| def check_node_type(self, node): | |||
| pass | |||
| def get_output_shapes(self): | |||
| return [t for t in self.depipeline.GetOutputShapes()] | |||
| def get_output_types(self): | |||
| return [t for t in self.depipeline.GetOutputTypes()] | |||
| def get_dataset_size(self): | |||
| return self.depipeline.GetDatasetSize() | |||
| def get_batch_size(self): | |||
| return self.depipeline.GetBatchSize() | |||
| def get_repeat_count(self): | |||
| return self.depipeline.GetRepeatCount() | |||
| def num_classes(self): | |||
| return self.depipeline.GetNumClasses() | |||
| def get_col_names(self): | |||
| return self.depipeline.GetColumnNames() | |||
| def __deepcopy__(self, memo): | |||
| return self | |||
| def _getters(self): | |||
| """ | |||
| Get pipeline information. | |||
| """ | |||
| getter = cde.TreeGetters() | |||
| getter.Init(self.ir_tree) | |||
| self._runtime_context.AssignConsumer(getter) | |||
| self._col_names = getter.GetColumnNames() | |||
| class SaveOp(Iterator): | |||
| """ | |||
| The derived class of Iterator for saving data. | |||
| """ | |||
| def __init__(self, dataset, num_epochs=-1): | |||
| super().__init__(dataset, num_epochs) | |||
| self.depipeline.LaunchTreeExec() | |||
| def get_next(self): | |||
| pass | |||
| def check_node_type(self, node): | |||
| if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)): | |||
| logger.warning("Used shuffle, repeat, batch before save operator.") | |||
| def save(self, file_names, file_type): | |||
| return self.depipeline.SaveDataset(file_names, file_type) | |||
| def get_col_names(self): | |||
| """ | |||
| Get names of the columns in the dataset | |||
| """ | |||
| if self._col_names is None: | |||
| self._getters() | |||
| return self._col_names | |||
| class DictIterator(Iterator): | |||
| """ | |||
| The derived class of Iterator with dict type. | |||
| """ | |||
| def __init__(self, dataset, num_epochs=-1, output_numpy=False): | |||
| super().__init__(dataset, num_epochs, output_numpy) | |||
| self.depipeline.LaunchTreeExec() | |||
| def check_node_type(self, node): | |||
| pass | |||
| def __iter__(self): | |||
| return self | |||
| def get_next(self): | |||
| def _get_next(self): | |||
| """ | |||
| Returns the next record in the dataset as a dictionary | |||
| Returns: | |||
| Dict, the next record in the dataset. | |||
| """ | |||
| if self.output_numpy: | |||
| return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()} | |||
| return {k: Tensor(v.as_array()) for k, v in self.depipeline.GetNextAsMap().items()} | |||
| return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()} | |||
| class TupleIterator(Iterator): | |||
| """ | |||
| The derived class of Iterator with list type. | |||
| """ | |||
| def check_node_type(self, node): | |||
| pass | |||
| def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False): | |||
| if columns is not None: | |||
| if not isinstance(columns, list): | |||
| columns = [columns] | |||
| # todo: move next to IR | |||
| dataset = dataset.project(columns) | |||
| super().__init__(dataset, num_epochs, output_numpy) | |||
| self.depipeline.LaunchTreeExec() | |||
| def __iter__(self): | |||
| return self | |||
| def get_next(self): | |||
| def _get_next(self): | |||
| """ | |||
| Returns the next record in the dataset as a list | |||
| @@ -378,15 +184,14 @@ class TupleIterator(Iterator): | |||
| List, the next record in the dataset. | |||
| """ | |||
| if self.output_numpy: | |||
| return [t.as_array() for t in self.depipeline.GetNextAsList()] | |||
| return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()] | |||
| return [self._transform_tensor(t) for t in self._iterator.GetNextAsList()] | |||
| class DummyIterator: | |||
| """ | |||
| A DummyIterator only works when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED" | |||
| """ | |||
| def __init__(self, dataset, mode): | |||
| self.mode = mode | |||
| self.shapes = dataset.output_shapes() | |||
| @@ -283,9 +283,12 @@ def create_node(node): | |||
| node.get('shard_id'), sampler) | |||
| elif dataset_op == 'TFRecordDataset': | |||
| shuffle = node.get('shuffle') | |||
| if shuffle is not None and isinstance(shuffle, str): | |||
| shuffle = de.Shuffle(shuffle) | |||
| pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'), | |||
| node.get('num_samples'), node.get('num_parallel_workers'), | |||
| de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id')) | |||
| shuffle, node.get('num_shards'), node.get('shard_id')) | |||
| elif dataset_op == 'ManifestDataset': | |||
| sampler = construct_sampler(node.get('sampler')) | |||
| @@ -293,14 +293,38 @@ def check_save(method): | |||
| return new_method | |||
| def check_iterator(method): | |||
| def check_tuple_iterator(method): | |||
| """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| [columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_bool = ['output_numpy'] | |||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | |||
| if num_epochs is not None: | |||
| type_check(num_epochs, (int,), "num_epochs") | |||
| check_value(num_epochs, [-1, INT32_MAX], "num_epochs") | |||
| if columns is not None: | |||
| check_columns(columns, "column_names") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_dict_iterator(method): | |||
| """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_bool = ['output_numpy'] | |||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | |||
| if num_epochs is not None: | |||
| type_check(num_epochs, (int,), "num_epochs") | |||
| check_value(num_epochs, [-1, INT32_MAX], "num_epochs") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -523,6 +547,8 @@ def check_batch(method): | |||
| sig = ins.signature(batch_size) | |||
| if len(sig.parameters) != 1: | |||
| raise ValueError("callable batch_size should take one parameter (BatchInfo).") | |||
| else: | |||
| check_pos_int32(int(batch_size), "batch_size") | |||
| if num_parallel_workers is not None: | |||
| check_num_parallel_workers(num_parallel_workers) | |||
| @@ -807,6 +833,21 @@ def check_project(method): | |||
| return new_method | |||
| def check_schema(method): | |||
| """check the input arguments of Schema.__init__.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [schema_file], _ = parse_user_args(method, *args, **kwargs) | |||
| if schema_file is not None: | |||
| type_check(schema_file, (str,), "schema_file") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_add_column(method): | |||
| """check the input arguments of add_column.""" | |||
| @@ -1261,3 +1302,23 @@ def check_cache_option(cache): | |||
| """Sanity check for cache parameter""" | |||
| if cache is not None: | |||
| type_check(cache, (cache_client.DatasetCache,), "cache") | |||
| def check_to_device_send(method): | |||
| """A wrapper that wraps a parameter checker around the check_to_device_send.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [num_epochs], _ = parse_user_args(method, *args, **kwargs) | |||
| if num_epochs is not None: | |||
| type_check(num_epochs, (int,), "num_epochs") | |||
| check_value(num_epochs, [-1, INT32_MAX], "num_epochs") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def replace_none(value, default): | |||
| return value if value is not None else default | |||
| @@ -18,13 +18,13 @@ use to_bytes and to_str to encode and decode strings into a specified format. | |||
| """ | |||
| from enum import IntEnum | |||
| import copy | |||
| import numpy as np | |||
| import mindspore._c_dataengine as cde | |||
| from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \ | |||
| check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model | |||
| __all__ = [ | |||
| "Vocab", "SentencePieceVocab", "to_str", "to_bytes" | |||
| ] | |||
| @@ -39,8 +39,7 @@ class Vocab(cde.Vocab): | |||
| @classmethod | |||
| @check_from_dataset | |||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, | |||
| special_first=True): | |||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, special_first=True): | |||
| """ | |||
| Build a vocab from a dataset. | |||
| @@ -69,21 +68,7 @@ class Vocab(cde.Vocab): | |||
| Returns: | |||
| Vocab, Vocab object built from dataset. | |||
| """ | |||
| vocab = Vocab() | |||
| if columns is None: | |||
| columns = [] | |||
| if not isinstance(columns, list): | |||
| columns = [columns] | |||
| if freq_range is None: | |||
| freq_range = (None, None) | |||
| if special_tokens is None: | |||
| special_tokens = [] | |||
| root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) | |||
| for d in root.create_dict_iterator(num_epochs=1): | |||
| if d is not None: | |||
| raise ValueError("from_dataset should receive data other than None.") | |||
| return vocab | |||
| return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first) | |||
| @classmethod | |||
| @check_from_list | |||
| @@ -143,6 +128,7 @@ class SentencePieceVocab(cde.SentencePieceVocab): | |||
| """ | |||
| SentencePiece object that is used to segment words | |||
| """ | |||
| @classmethod | |||
| @check_from_dataset_sentencepiece | |||
| def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params): | |||
| @@ -164,13 +150,8 @@ class SentencePieceVocab(cde.SentencePieceVocab): | |||
| SentencePiece, SentencePiece object from dataset. | |||
| """ | |||
| vocab = SentencePieceVocab() | |||
| root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage, | |||
| model_type, params) | |||
| for d in root.create_dict_iterator(num_epochs=1): | |||
| if d is None: | |||
| raise ValueError("from_dataset should receive data other than None.") | |||
| return vocab | |||
| return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage, | |||
| DE_C_INTER_SENTENCEPIECE_MODE[model_type], params) | |||
| @classmethod | |||
| @check_from_file_sentencepiece | |||
| @@ -270,6 +251,7 @@ class SentencePieceModel(IntEnum): | |||
| CHAR = 2 | |||
| WORD = 3 | |||
| DE_C_INTER_SENTENCEPIECE_MODE = { | |||
| SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM, | |||
| SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE, | |||
| @@ -432,7 +432,7 @@ def check_from_dataset_sentencepiece(method): | |||
| [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs) | |||
| if col_names is not None: | |||
| type_check(col_names, (list,), "col_names") | |||
| type_check_list(col_names, (str,), "col_names") | |||
| if vocab_size is not None: | |||
| check_uint32(vocab_size, "vocab_size") | |||
| @@ -146,6 +146,7 @@ if (BUILD_MINDDATA STREQUAL "full") | |||
| list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES | |||
| "${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc" | |||
| "${MINDDATA_DIR}/engine/ir/cache/pre_built_dataset_cache.cc" | |||
| ) | |||
| list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | |||
| @@ -123,6 +123,7 @@ def connect_network_with_dataset(network, dataset_helper): | |||
| network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | |||
| return network | |||
| class DatasetHelper: | |||
| """ | |||
| DatasetHelper is a class to process the MindData dataset and it provides the information of dataset. | |||
| @@ -197,7 +198,6 @@ class DatasetHelper: | |||
| def get_data_info(self): | |||
| return self.iter.get_data_info() | |||
| class _DatasetIter: | |||
| """Base iter for dataset helper""" | |||
| @@ -331,7 +331,6 @@ class _DatasetIterPSLite(_DatasetIter): | |||
| class _DatasetIterNormal: | |||
| """Iter for normal(non sink) mode, feed the data from host.""" | |||
| def __init__(self, dataset, epoch_num=-1): | |||
| self.dataset = dataset | |||
| self.device_num = _get_device_num() | |||
| @@ -61,15 +61,15 @@ class MindData: | |||
| def send(self, num_epochs=-1): | |||
| pass | |||
| def get_data_info(self): | |||
| pass | |||
| def stop_send(self): | |||
| pass | |||
| def continue_send(self): | |||
| pass | |||
| def get_data_info(self): | |||
| pass | |||
| def __len__(self): | |||
| return self._size | |||
| @@ -177,8 +177,8 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) { | |||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||
| iter->GetNextRow(&row); | |||
| } | |||
| // 5 batches of size 2 | |||
| EXPECT_EQ(i, 5); | |||
| // With 2 boundaries, 3 buckets are created | |||
| EXPECT_EQ(i, 3); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| @@ -132,6 +132,6 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { | |||
| // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not | |||
| EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); | |||
| EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); | |||
| EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos); | |||
| EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); | |||
| EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | |||
| } | |||
| @@ -63,7 +63,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) { | |||
| const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}}; | |||
| EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | |||
| std::vector<size_t> row_sizes = {2, 2, 0, 0}; | |||
| std::vector<size_t> row_sizes = {2, 2, 0}; | |||
| TensorRow row; | |||
| for (size_t sz : row_sizes) { | |||
| @@ -75,7 +75,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) { | |||
| rc = tree_adapter.GetNext(&row); | |||
| EXPECT_TRUE(rc.IsError()); | |||
| const std::string err_msg = rc.ToString(); | |||
| EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); | |||
| EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos); | |||
| } | |||
| TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | |||
| @@ -97,7 +97,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | |||
| const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap(); | |||
| EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | |||
| std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0}; | |||
| std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0}; | |||
| TensorRow row; | |||
| for (size_t sz : row_sizes) { | |||
| @@ -107,7 +107,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | |||
| } | |||
| rc = tree_adapter.GetNext(&row); | |||
| const std::string err_msg = rc.ToString(); | |||
| EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); | |||
| EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos); | |||
| } | |||
| TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | |||
| @@ -135,7 +135,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | |||
| const std::unordered_map<std::string, int32_t> map = {{"label", 0}}; | |||
| EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | |||
| std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0}; | |||
| std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0}; | |||
| TensorRow row; | |||
| for (size_t sz : row_sizes) { | |||
| @@ -145,5 +145,5 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | |||
| } | |||
| rc = tree_adapter.GetNext(&row); | |||
| const std::string err_msg = rc.ToString(); | |||
| EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); | |||
| EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos); | |||
| } | |||
| @@ -451,6 +451,10 @@ def test_batch_exception_13(): | |||
| def test_batch_exception_14(): | |||
| """ | |||
| Test per_batch_map and input column name | |||
| """ | |||
| logger.info("test_batch_exception_14") | |||
| batch_size = 2 | |||
| input_columns = ["num"] | |||
| data1 = ds.TFRecordDataset(DATA_DIR) | |||
| @@ -460,6 +464,22 @@ def test_batch_exception_14(): | |||
| assert "per_batch_map and input_columns need to be passed in together." in str(e) | |||
| def test_batch_exception_15(): | |||
| """ | |||
| Test batch_size = int32 max value + 1 | |||
| """ | |||
| logger.info("test_batch_exception_15") | |||
| batch_size = 2147483647 + 1 | |||
| input_columns = ["num"] | |||
| data1 = ds.TFRecordDataset(DATA_DIR) | |||
| err_msg = "" | |||
| try: | |||
| _ = data1.batch(batch_size=batch_size, input_columns=input_columns) | |||
| except ValueError as e: | |||
| err_msg = str(e) | |||
| assert "batch_size is not within the required interval of (1 to 2147483647)" in err_msg | |||
| if __name__ == '__main__': | |||
| test_batch_01() | |||
| test_batch_02() | |||
| @@ -486,4 +506,5 @@ if __name__ == '__main__': | |||
| test_batch_exception_12() | |||
| test_batch_exception_13() | |||
| test_batch_exception_14() | |||
| test_batch_exception_15() | |||
| logger.info('\n') | |||
| @@ -12,7 +12,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import os | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| @@ -354,6 +355,18 @@ def test_clue_to_device(): | |||
| data.send() | |||
| def test_clue_invalid_files(): | |||
| """ | |||
| Test CLUE with invalid files | |||
| """ | |||
| AFQMC_DIR = '../data/dataset/testCLUE/afqmc' | |||
| afqmc_train_json = os.path.join(AFQMC_DIR) | |||
| with pytest.raises(ValueError) as info: | |||
| _ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False) | |||
| assert "The following patterns did not match any files" in str(info.value) | |||
| assert AFQMC_DIR in str(info.value) | |||
| if __name__ == "__main__": | |||
| test_clue() | |||
| test_clue_num_shards() | |||
| @@ -366,3 +379,4 @@ if __name__ == "__main__": | |||
| test_clue_tnews() | |||
| test_clue_wsc() | |||
| test_clue_to_device() | |||
| test_clue_invalid_files() | |||
| @@ -195,30 +195,42 @@ def test_csv_dataset_size(): | |||
| assert data.get_dataset_size() == 5 | |||
| def test_csv_dataset_exception(): | |||
| def test_csv_dataset_type_error(): | |||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", "", "", ""], | |||
| column_defaults=["", 0, "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| with pytest.raises(Exception) as err: | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| pass | |||
| assert "failed to parse file" in str(err.value) | |||
| assert "type does not match" in str(err.value) | |||
| def test_csv_dataset_type_error(): | |||
| def test_csv_dataset_exception(): | |||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | |||
| data = ds.CSVDataset( | |||
| TEST_FILE, | |||
| column_defaults=["", 0, "", ""], | |||
| column_defaults=["", "", "", ""], | |||
| column_names=['col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| with pytest.raises(Exception) as err: | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| pass | |||
| assert "type does not match" in str(err.value) | |||
| assert "failed to parse file" in str(err.value) | |||
| def test_csv_dataset_duplicate_columns(): | |||
| data = ds.CSVDataset( | |||
| DATA_FILE, | |||
| column_defaults=["1", "2", "3", "4"], | |||
| column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'], | |||
| shuffle=False) | |||
| with pytest.raises(RuntimeError) as info: | |||
| _ = data.create_dict_iterator(num_epochs=1, output_numpy=True) | |||
| assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value) | |||
| assert "column_names" in str(info.value) | |||
| if __name__ == "__main__": | |||
| @@ -234,5 +246,6 @@ if __name__ == "__main__": | |||
| test_csv_dataset_header() | |||
| test_csv_dataset_number() | |||
| test_csv_dataset_size() | |||
| test_csv_dataset_exception() | |||
| test_csv_dataset_type_error() | |||
| test_csv_dataset_exception() | |||
| test_csv_dataset_duplicate_columns() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================== | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.vision.c_transforms as vision | |||
| IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train" | |||
| IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | |||
| @@ -21,9 +22,18 @@ IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-000 | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | |||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | |||
| MNIST_DATA_DIR = "../data/dataset/testMnistData" | |||
| MIND_CV_FILE_NAME = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord" | |||
| SCHEMA_FILE = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||
| MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" | |||
| CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data" | |||
| CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data" | |||
| VOC_DATA_DIR = "../data/dataset/testVOC2012" | |||
| COCO_DATA_DIR = "../data/dataset/testCOCO/train/" | |||
| ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" | |||
| CELEBA_DATA_DIR = "../data/dataset/testCelebAData/" | |||
| CLUE_FILE = '../data/dataset/testCLUE/afqmc/train.json' | |||
| CSV_FILE = '../data/dataset/testCSV/1.csv' | |||
| TEXT_DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" | |||
| def test_imagenet_rawdata_dataset_size(): | |||
| @@ -50,8 +60,15 @@ def test_imagenet_tf_file_dataset_size(): | |||
| ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0) | |||
| assert ds_shard_2_0.get_dataset_size() == 6 | |||
| # FIXME: dataset_size == 6 looks wrong, but it seems intended to match the current code. | |||
| # The correct answer should be 12/3 = 4; the code issue should be addressed. | |||
| ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0) | |||
| assert ds_shard_3_0.get_dataset_size() == 4 | |||
| assert ds_shard_3_0.get_dataset_size() == 6 | |||
| count = 0 | |||
| for _ in ds_shard_3_0.create_dict_iterator(): | |||
| count += 1 | |||
| assert ds_shard_3_0.get_dataset_size() == count | |||
| def test_mnist_dataset_size(): | |||
| @@ -76,6 +93,14 @@ def test_mnist_dataset_size(): | |||
| assert ds_shard_3_0.get_dataset_size() == 3334 | |||
| def test_mind_dataset_size(): | |||
| dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0") | |||
| assert dataset.get_dataset_size() == 20 | |||
| dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0) | |||
| assert dataset_shard_2_0.get_dataset_size() == 10 | |||
| def test_manifest_dataset_size(): | |||
| ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE) | |||
| assert ds_total.get_dataset_size() == 4 | |||
| @@ -95,10 +120,11 @@ def test_cifar10_dataset_size(): | |||
| assert ds_total.get_dataset_size() == 10000 | |||
| # test get_dataset_size with usage flag | |||
| train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size() | |||
| assert train_size == 0 | |||
| train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size() | |||
| assert train_size == 10000 | |||
| test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size() | |||
| assert test_size == 0 | |||
| all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size() | |||
| assert all_size == 10000 | |||
| @@ -120,8 +146,6 @@ def test_cifar100_dataset_size(): | |||
| assert ds_total.get_dataset_size() == 10000 | |||
| # test get_dataset_size with usage flag | |||
| train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size() | |||
| assert train_size == 0 | |||
| test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size() | |||
| assert test_size == 10000 | |||
| all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size() | |||
| @@ -137,10 +161,97 @@ def test_cifar100_dataset_size(): | |||
| assert ds_shard_3_0.get_dataset_size() == 3334 | |||
| def test_voc_dataset_size(): | |||
| dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True) | |||
| assert dataset.get_dataset_size() == 10 | |||
| dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True, | |||
| num_shards=2, shard_id=0) | |||
| assert dataset_shard_2_0.get_dataset_size() == 5 | |||
| def test_coco_dataset_size(): | |||
| dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", | |||
| decode=True, shuffle=False) | |||
| assert dataset.get_dataset_size() == 6 | |||
| dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True, | |||
| shuffle=False, num_shards=2, shard_id=0) | |||
| assert dataset_shard_2_0.get_dataset_size() == 3 | |||
| def test_celeba_dataset_size(): | |||
| dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) | |||
| assert dataset.get_dataset_size() == 4 | |||
| dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0) | |||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||
| def test_clue_dataset_size(): | |||
| dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False) | |||
| assert dataset.get_dataset_size() == 3 | |||
| dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0) | |||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||
| def test_csv_dataset_size(): | |||
| dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'], | |||
| shuffle=False) | |||
| assert dataset.get_dataset_size() == 3 | |||
| dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'], | |||
| shuffle=False, num_shards=2, shard_id=0) | |||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||
| def test_text_file_dataset_size(): | |||
| dataset = ds.TextFileDataset(TEXT_DATA_FILE) | |||
| assert dataset.get_dataset_size() == 3 | |||
| dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0) | |||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||
| def test_padded_dataset_size(): | |||
| dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}]) | |||
| assert dataset.get_dataset_size() == 2 | |||
| def test_pipeline_get_dataset_size(): | |||
| dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False) | |||
| assert dataset.get_dataset_size() == 12 | |||
| dataset = dataset.shuffle(buffer_size=3) | |||
| assert dataset.get_dataset_size() == 12 | |||
| decode_op = vision.Decode() | |||
| resize_op = vision.RandomResize(10) | |||
| dataset = dataset.map([decode_op, resize_op], input_columns=["image"]) | |||
| assert dataset.get_dataset_size() == 12 | |||
| dataset = dataset.batch(batch_size=3) | |||
| assert dataset.get_dataset_size() == 4 | |||
| dataset = dataset.repeat(count=2) | |||
| assert dataset.get_dataset_size() == 8 | |||
| if __name__ == '__main__': | |||
| test_imagenet_rawdata_dataset_size() | |||
| test_imagenet_tf_file_dataset_size() | |||
| test_mnist_dataset_size() | |||
| test_mind_dataset_size() | |||
| test_manifest_dataset_size() | |||
| test_cifar10_dataset_size() | |||
| test_cifar100_dataset_size() | |||
| test_voc_dataset_size() | |||
| test_coco_dataset_size() | |||
| test_celeba_dataset_size() | |||
| test_clue_dataset_size() | |||
| test_csv_dataset_size() | |||
| test_text_file_dataset_size() | |||
| test_padded_dataset_size() | |||
| test_pipeline_get_dataset_size() | |||
| @@ -521,7 +521,7 @@ def test_chained_sampler_04(): | |||
| # Verify dataset size | |||
| data1_size = data1.get_dataset_size() | |||
| logger.info("dataset size is: {}".format(data1_size)) | |||
| assert data1_size == 24 | |||
| assert data1_size == 6 | |||
| # Verify number of iterations | |||
| num_iter = 0 | |||
| @@ -182,6 +182,15 @@ def test_voc_exception(): | |||
| pass | |||
| def test_voc_num_classes(): | |||
| data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) | |||
| assert data1.num_classes() is None | |||
| class_index = {'car': 0, 'cat': 1, 'train': 5} | |||
| data2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True) | |||
| assert data2.num_classes() is None | |||
| if __name__ == '__main__': | |||
| test_voc_segmentation() | |||
| test_voc_detection() | |||
| @@ -191,3 +200,4 @@ if __name__ == '__main__': | |||
| test_case_1() | |||
| test_case_2() | |||
| test_voc_exception() | |||
| test_voc_num_classes() | |||
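
The new test asserts that num_classes() returns None for the Detection task, with or without class_indexing. If a caller needs a class count in that situation, one hypothetical fallback (not part of this change) is to derive it from the class_indexing mapping they already supply:

    class_index = {'car': 0, 'cat': 1, 'train': 5}
    num_classes = len(class_index)            # 3 distinct labels
    max_label_id = max(class_index.values())  # 5; ids need not be contiguous
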
| @@ -107,7 +107,7 @@ def test_decode_op(): | |||
| # Expect an AttributeError since iter1 has been stopped. | |||
| with pytest.raises(AttributeError) as info: | |||
| iter1.__next__() | |||
| assert "object has no attribute 'depipeline'" in str(info.value) | |||
| assert "object has no attribute '_runtime_context'" in str(info.value) | |||
| with pytest.raises(RuntimeError) as info: | |||
| iter2.__next__() | |||
| @@ -205,7 +205,7 @@ def test_generator_dict_3(): | |||
| # Expect an AttributeError since iter1 has been stopped. | |||
| with pytest.raises(AttributeError) as info: | |||
| iter1.__next__() | |||
| assert "object has no attribute 'depipeline'" in str(info.value) | |||
| assert "object has no attribute '_runtime_context'" in str(info.value) | |||
| def test_generator_dict_4(): | |||
| @@ -396,7 +396,7 @@ def test_generator_tuple_3(): | |||
| # Expect an AttributeError since iter1 has been stopped. | |||
| with pytest.raises(AttributeError) as info: | |||
| iter1.__next__() | |||
| assert "object has no attribute 'depipeline'" in str(info.value) | |||
| assert "object has no attribute '_runtime_context'" in str(info.value) | |||
| def test_generator_tuple_4(): | |||
| @@ -546,7 +546,7 @@ def test_generator_tuple_repeat_repeat_2(): | |||
| # Expect an AttributeError since iter1 has been stopped. | |||
| with pytest.raises(AttributeError) as info: | |||
| iter1.__next__() | |||
| assert "object has no attribute 'depipeline'" in str(info.value) | |||
| assert "object has no attribute '_runtime_context'" in str(info.value) | |||
| def test_generator_tuple_repeat_repeat_3(): | |||
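
The assertion text in the four hunks above changes because a stopped iterator no longer fails on a missing 'depipeline' attribute but on a missing '_runtime_context', presumably released when the iterator is stopped. A framework-agnostic sketch of the pattern being tested (FakeIterator is a stand-in, not MindSpore code):

    import pytest

    class FakeIterator:
        def __init__(self):
            self._runtime_context = object()   # released when the iterator stops
        def stop(self):
            del self._runtime_context
        def __next__(self):
            # touching the released attribute raises AttributeError
            return self._runtime_context

    it = FakeIterator()
    it.stop()
    with pytest.raises(AttributeError) as info:
        it.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)
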
| @@ -74,9 +74,11 @@ def test_case2(): | |||
| def test_case3(): | |||
| data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10) | |||
| data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(5) | |||
| data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2) | |||
| data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename( | |||
| ["col_sint64"], ["a1"]) | |||
| data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(5).rename( | |||
| ["col_sint64"], ["a2"]) | |||
| data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).rename(["col_sint64"], ["a3"]) | |||
| data4 = ds.zip((data1, data2, data3)) | |||
| @@ -84,8 +86,9 @@ def test_case3(): | |||
| def test_case4(): | |||
| data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10) | |||
| data2 = ds.TFRecordDataset(FILES) | |||
| data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename( | |||
| ["col_sint64"], ["a1"]) | |||
| data2 = ds.TFRecordDataset(FILES, columns_list=["col_sint64"]).rename(["col_sint64"], ["a2"]) | |||
| assert data2.get_dataset_size() == 12 | |||
| data2 = data2.batch(2) | |||
| assert data2.get_dataset_size() == 6 | |||
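
test_case3 and test_case4 now restrict each TFRecord branch to col_sint64 and rename it before zipping; zip expects the column names of its inputs to be distinct, so each branch maps the shared column to a1/a2/a3. A minimal sketch of the same pattern on an in-memory source (NumpySlicesDataset here is only illustrative):

    import mindspore.dataset as ds

    # Both branches share the column name "col"; rename before zip so the
    # zipped dataset's column names do not collide.
    a = ds.NumpySlicesDataset([1, 2, 3], column_names=["col"]).rename(["col"], ["a1"])
    b = ds.NumpySlicesDataset([4, 5, 6], column_names=["col"]).rename(["col"], ["a2"])
    zipped = ds.zip((a, b))   # resulting columns: a1, a2
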