From: @cathwong
Tags: v1.1.0
@@ -2,9 +2,11 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 if (ENABLE_PYTHON)
   add_library(APItoPython OBJECT
-    python/de_pipeline.cc
     python/pybind_register.cc
-    python/bindings.cc
+    python/pybind_conversion.cc
+    python/bindings/dataset/include/datasets_bindings.cc
+    python/bindings/dataset/include/iterator_bindings.cc
+    python/bindings/dataset/include/schema_bindings.cc
     python/bindings/dataset/engine/cache/bindings.cc
     python/bindings/dataset/core/bindings.cc
     python/bindings/dataset/callback/bindings.cc
@@ -115,7 +115,8 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
 #ifndef ENABLE_ANDROID
 // Function to return a transferred Node that transfers data through a device.
-bool Dataset::DeviceQueue(bool send_epoch_end) {
+bool Dataset::DeviceQueue(std::string queue_name, std::string device_type, int32_t num_epochs, bool send_epoch_end,
+                          int32_t total_batches, bool create_data_info_queue) {
   Status rc;
   // Build and launch tree
@@ -126,11 +127,12 @@ bool Dataset::DeviceQueue(bool send_epoch_end) {
     return false;
   }
-  // Add TransferNode IR on top of dataset d
-  auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), send_epoch_end);
+  // Add TransferNode IR on top of dataset
+  auto ds = std::make_shared<TransferNode>(shared_from_this()->IRNode(), queue_name, device_type, send_epoch_end,
+                                           total_batches, create_data_info_queue);
   // Get ToDevice consumer
-  auto consumer = std::make_unique<ToDevice>(send_epoch_end, -1);
+  auto consumer = std::make_unique<ToDevice>(num_epochs);
   ToDevice *consumer_ = consumer.get();
   rc = consumer->Init(ds);
   if (rc.IsError()) {
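For context, the widened DeviceQueue() signature now carries the transfer parameters the old overload left implicit. A minimal caller sketch follows; the queue name, device string, and batch count are illustrative placeholders, not values taken from this change, and it assumes the public datasets.h header added elsewhere in this PR.

#include <memory>
#include <string>
#include "minddata/dataset/include/datasets.h"

using mindspore::dataset::Dataset;

// Sketch: push an already-built pipeline to a device queue with the new API.
bool SendToDevice(std::shared_ptr<Dataset> ds) {
  const std::string queue_name = "queue_0";   // illustrative
  const std::string device_type = "Ascend";   // illustrative
  const int32_t num_epochs = 1;
  const bool send_epoch_end = true;
  const int32_t total_batches = 0;            // illustrative value
  const bool create_data_info_queue = false;
  return ds->DeviceQueue(queue_name, device_type, num_epochs, send_epoch_end,
                         total_batches, create_data_info_queue);
}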
@@ -199,127 +201,55 @@ Dataset::Dataset() { tree_getters_ = std::make_shared<TreeGetters>(); }
 int64_t Dataset::GetDatasetSize() {
   int64_t dataset_size;
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
-    return -1;
-  }
-  rc = tree_getters_->Init(this->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
-    return -1;
-  }
-  rc = tree_getters_->GetDatasetSize(&dataset_size);
-  return rc.IsError() ? -1 : dataset_size;
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetDatasetSize(&dataset_size), -1);
+  return dataset_size;
 }
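The getters in this hunk all switch to a RETURN_SECOND_IF_ERROR macro. Its definition is not part of this diff; a plausible minimal sketch (an assumption for illustration only, not the actual MindSpore definition) is:

// Hypothetical sketch: evaluate an expression returning Status and, on error,
// return the caller-supplied fallback value from the enclosing function.
// The fallback is returned as written so braced empties such as {} still work.
#define RETURN_SECOND_IF_ERROR(_status_expr, _fallback) \
  do {                                                  \
    Status _rc = (_status_expr);                        \
    if (_rc.IsError()) {                                \
      MS_LOG(ERROR) << "Operation failed.";             \
      return _fallback;                                 \
    }                                                   \
  } while (false)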
 std::vector<DataType> Dataset::GetOutputTypes() {
   std::vector<DataType> types;
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
-    return types;
-  }
-  rc = tree_getters_->Init(this->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
-    return types;
-  }
-  rc = tree_getters_->GetOutputTypes(&types);
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
-    types.clear();
-    return types;
-  }
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputTypes(&types), {});
   return types;
 }
 std::vector<TensorShape> Dataset::GetOutputShapes() {
   std::vector<TensorShape> shapes;
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
-    return shapes;
-  }
-  rc = tree_getters_->Init(this->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
-    return shapes;
-  }
-  rc = tree_getters_->GetOutputShapes(&shapes);
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
-    shapes.clear();
-    return shapes;
-  }
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputShapes(&shapes), {});
   return shapes;
 }
 int64_t Dataset::GetNumClasses() {
   int64_t num_classes;
-  auto ds = shared_from_this();
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
-    return -1;
-  }
-  rc = tree_getters_->Init(ds->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
-    return -1;
-  }
-  rc = tree_getters_->GetNumClasses(&num_classes);
-  return rc.IsError() ? -1 : num_classes;
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetNumClasses(&num_classes), -1);
+  return num_classes;
 }
 std::vector<std::string> Dataset::GetColumnNames() {
   std::vector<std::string> col_names;
-  auto ds = shared_from_this();
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetColumnNames: Initializing RuntimeContext failed.";
-    return std::vector<std::string>();
-  }
-  rc = tree_getters_->Init(ds->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetColumnNames: Initializing TreeGetters failed.";
-    return std::vector<std::string>();
-  }
-  rc = tree_getters_->GetColumnNames(&col_names);
-  return rc.IsError() ? std::vector<std::string>() : col_names;
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetColumnNames(&col_names), {});
+  return col_names;
 }
 std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
   std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
-  auto ds = shared_from_this();
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
-    return output_class_indexing;
-  }
-  rc = tree_getters_->Init(ds->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
-    return output_class_indexing;
-  }
-  rc = tree_getters_->GetClassIndexing(&output_class_indexing);
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
-    output_class_indexing.clear();
-    return output_class_indexing;
-  }
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetClassIndexing(&output_class_indexing), {});
   return output_class_indexing;
 }
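With this refactor every getter either succeeds or falls back to a sentinel (-1 or an empty container) instead of logging at each call site. A hedged caller sketch; the column handling and output are purely illustrative:

#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"

using namespace mindspore::dataset;

// Sketch: query pipeline metadata and treat the documented sentinels as failure.
void PrintPipelineInfo(std::shared_ptr<Dataset> ds) {
  int64_t size = ds->GetDatasetSize();
  if (size == -1) {
    std::cerr << "GetDatasetSize failed" << std::endl;
    return;
  }
  std::cout << "rows: " << size << std::endl;
  std::vector<std::string> columns = ds->GetColumnNames();  // empty on failure
  for (const auto &name : columns) {
    std::cout << "column: " << name << std::endl;
  }
}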
@@ -501,9 +431,13 @@ BucketBatchByLengthDataset::BucketBatchByLengthDataset(
   std::function<TensorRow(TensorRow)> element_length_function,
   const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,
   bool drop_remainder) {
-  auto ds = std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries,
-                                                      bucket_batch_sizes, element_length_function, pad_info,
-                                                      pad_to_bucket_boundary, drop_remainder);
+  std::shared_ptr<TensorOp> c_func = nullptr;
+  if (element_length_function != nullptr) {
+    c_func = std::make_shared<CFuncOp>(element_length_function);
+  }
+  auto ds =
+    std::make_shared<BucketBatchByLengthNode>(input->IRNode(), column_names, bucket_boundaries, bucket_batch_sizes,
+                                              c_func, pad_info, pad_to_bucket_boundary, drop_remainder);
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
@@ -522,7 +456,9 @@ ConcatDataset::ConcatDataset(const std::vector<std::shared_ptr<Dataset>> &datase
 FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<TensorRow(TensorRow)> predicate,
                              std::vector<std::string> input_columns) {
-  auto ds = std::make_shared<FilterNode>(input->IRNode(), predicate, input_columns);
+  std::shared_ptr<TensorOp> c_func = nullptr;
+  if (predicate) c_func = std::make_shared<CFuncOp>(predicate);
+  auto ds = std::make_shared<FilterNode>(input->IRNode(), c_func, input_columns);
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
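Both constructors above now wrap the caller's std::function into a CFuncOp before the IR node is built, so a null function never reaches the node. A hedged sketch of constructing a FilterDataset directly; the predicate body is a placeholder because building the boolean result tensor depends on Tensor helpers not shown in this diff, and the column name is illustrative.

#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"

using namespace mindspore::dataset;

// Sketch: a C++ predicate with the TensorRow(TensorRow) signature expected above.
TensorRow KeepEverything(TensorRow row) {
  // A real predicate returns a TensorRow holding a single boolean tensor;
  // this placeholder simply forwards the input row.
  return row;
}

std::shared_ptr<Dataset> BuildFilteredPipeline(std::shared_ptr<Dataset> input) {
  // The FilterDataset constructor shown in this hunk wraps the predicate in a CFuncOp.
  return std::make_shared<FilterDataset>(input, KeepEverything,
                                         std::vector<std::string>{"image"});
}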
@@ -604,40 +540,20 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
 #endif
 int64_t Dataset::GetBatchSize() {
   int64_t batch_size;
-  auto ds = shared_from_this();
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
-    return -1;
-  }
-  rc = tree_getters_->Init(ds->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
-    return -1;
-  }
-  rc = tree_getters_->GetBatchSize(&batch_size);
-  return rc.IsError() ? -1 : batch_size;
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetBatchSize(&batch_size), -1);
+  return batch_size;
 }
 int64_t Dataset::GetRepeatCount() {
   int64_t repeat_count;
-  auto ds = shared_from_this();
-  Status rc;
   std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
-  rc = runtime_context->Init();
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
-    return -1;
-  }
-  rc = tree_getters_->Init(ds->IRNode());
-  if (rc.IsError()) {
-    MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
-    return -1;
-  }
-  rc = tree_getters_->GetRepeatCount(&repeat_count);
-  return rc.IsError() ? 0 : repeat_count;
+  RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
+  RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
+  RETURN_SECOND_IF_ERROR(tree_getters_->GetRepeatCount(&repeat_count), 0);
+  return repeat_count;
 }
 std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
@@ -720,62 +636,65 @@ std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remai
 SchemaObj::SchemaObj(const std::string &schema_file) : schema_file_(schema_file), num_rows_(0), dataset_type_("") {}
 // SchemaObj init function
-bool SchemaObj::init() {
-  if (schema_file_ != "") {
+Status SchemaObj::init() {
+  if (!schema_file_.empty()) {
     Path schema_file(schema_file_);
-    if (!schema_file.Exists()) {
-      MS_LOG(ERROR) << "The file " << schema_file << " does not exist or permission denied!";
-      return false;
-    }
+    CHECK_FAIL_RETURN_UNEXPECTED(schema_file.Exists(),
+                                 "The file " + schema_file_ + " does not exist or permission denied!");
     nlohmann::json js;
     try {
       std::ifstream in(schema_file_);
       in >> js;
-      if (js.find("columns") == js.end()) {
-        MS_LOG(ERROR) << "\"columns\" node is required in the schema json file.";
-        return false;
-      }
+      CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
+                                   "\"columns\" node is required in the schema json file.");
     } catch (const std::exception &err) {
-      MS_LOG(ERROR) << "Schema file failed to load";
-      return false;
+      RETURN_STATUS_SYNTAX_ERROR("Schema file failed to load");
     }
     return from_json(js);
   }
-  return true;
+  return Status::OK();
 }
+// Function to add a column to schema with a mstype de_type and known shape
+Status SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
+  DataType data_type = dataset::MSTypeToDEType(de_type);
+  return add_column(name, data_type.ToString(), shape);
+}
-// Function to add a column to schema with a mstype de_type
-bool SchemaObj::add_column(std::string name, TypeId de_type, std::vector<int32_t> shape) {
+// Function to add a column to schema with a string de_type and known shape
+Status SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
+  DataType data_type(de_type);
+  CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
   nlohmann::json new_column;
   new_column["name"] = name;
-  // if de_type is mstype
-  DataType data_type = dataset::MSTypeToDEType(de_type);
   new_column["type"] = data_type.ToString();
-  if (shape.size() > 0) {
-    new_column["shape"] = shape;
-    new_column["rank"] = shape.size();
-  } else {
-    new_column["rank"] = 1;
-  }
+  new_column["shape"] = shape;
+  new_column["rank"] = shape.size();
   columns_.push_back(new_column);
-  return true;
+  return Status::OK();
 }
+// Function to add a column to schema with a mstype de_type and without shape
+Status SchemaObj::add_column(std::string name, TypeId de_type) {
+  DataType data_type = dataset::MSTypeToDEType(de_type);
+  return add_column(name, data_type.ToString());
+}
-// Function to add a column to schema with a string de_type
-bool SchemaObj::add_column(std::string name, std::string de_type, std::vector<int32_t> shape) {
+// Function to add a column to schema with a string de_type and without shape
+Status SchemaObj::add_column(std::string name, std::string de_type) {
+  DataType data_type(de_type);
+  CHECK_FAIL_RETURN_UNEXPECTED(data_type != DataType::DE_UNKNOWN, "Type is unknown.");
   nlohmann::json new_column;
   new_column["name"] = name;
-  DataType data_type(de_type);
   new_column["type"] = data_type.ToString();
-  if (shape.size() > 0) {
-    new_column["shape"] = shape;
-    new_column["rank"] = shape.size();
-  } else {
-    new_column["rank"] = 1;
-  }
+  new_column["rank"] = 1;
   columns_.push_back(new_column);
-  return true;
+  return Status::OK();
 }
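The add_column() overloads now return Status and validate the type string up front, so schema construction composes with the same RETURN_IF_NOT_OK flow used elsewhere in this file. A hedged usage sketch; the column names, types, and shapes are illustrative:

#include <memory>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/util/status.h"

using namespace mindspore::dataset;

// Sketch: build a schema programmatically with the Status-returning overloads.
Status BuildExampleSchema(const std::shared_ptr<SchemaObj> &schema) {
  // String-typed overload with an explicit shape (rank is derived from the shape).
  RETURN_IF_NOT_OK(schema->add_column("image", "uint8", {28, 28, 1}));
  // Shape-less overload: rank defaults to 1, as in the code above.
  RETURN_IF_NOT_OK(schema->add_column("label", "int32"));
  return Status::OK();
}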
 std::string SchemaObj::to_json() {
@@ -792,7 +711,7 @@ std::string SchemaObj::to_json() {
   return json_file.dump(2);
 }
-bool SchemaObj::parse_column(nlohmann::json columns) {
+Status SchemaObj::parse_column(nlohmann::json columns) {
   std::string name, de_type;
   std::vector<int32_t> shape;
@@ -802,15 +721,13 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
     for (auto column : columns) {
       auto key_name = column.find("name");
       if (key_name == column.end()) {
-        MS_LOG(ERROR) << "Column's name is missing";
-        return false;
+        RETURN_STATUS_SYNTAX_ERROR("Column's name is missing");
       }
       name = *key_name;
       auto key_type = column.find("type");
       if (key_type == column.end()) {
-        MS_LOG(ERROR) << "Column's type is missing";
-        return false;
+        RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
       }
       de_type = *key_type;
@@ -819,17 +736,14 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
       if (key_shape != column.end()) {
         shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
       }
-      if (!add_column(name, de_type, shape)) {
-        return false;
-      }
+      RETURN_IF_NOT_OK(add_column(name, de_type, shape));
     }
   } else if (columns.type() == nlohmann::json::value_t::object) {
     for (const auto &it_child : columns.items()) {
       name = it_child.key();
       auto key_type = it_child.value().find("type");
       if (key_type == it_child.value().end()) {
-        MS_LOG(ERROR) << "Column's type is missing";
-        return false;
+        RETURN_STATUS_SYNTAX_ERROR("Column's type is missing");
      }
       de_type = *key_type;
@@ -839,43 +753,45 @@ bool SchemaObj::parse_column(nlohmann::json columns) {
        shape.insert(shape.end(), (*key_shape).begin(), (*key_shape).end());
       }
-      if (!add_column(name, de_type, shape)) {
-        return false;
-      }
+      RETURN_IF_NOT_OK(add_column(name, de_type, shape));
     }
   } else {
-    MS_LOG(ERROR) << "columns must be dict or list, columns contain name, type, shape(optional).";
-    return false;
+    RETURN_STATUS_SYNTAX_ERROR("columns must be dict or list, columns contain name, type, shape(optional).");
   }
-  return true;
+  return Status::OK();
 }
-bool SchemaObj::from_json(nlohmann::json json_obj) {
+Status SchemaObj::from_json(nlohmann::json json_obj) {
   for (const auto &it_child : json_obj.items()) {
     if (it_child.key() == "datasetType") {
       dataset_type_ = it_child.value();
     } else if (it_child.key() == "numRows") {
       num_rows_ = it_child.value();
     } else if (it_child.key() == "columns") {
-      if (!parse_column(it_child.value())) {
-        MS_LOG(ERROR) << "parse columns failed";
-        return false;
-      }
+      RETURN_IF_NOT_OK(parse_column(it_child.value()));
     } else {
-      MS_LOG(ERROR) << "Unknown field " << it_child.key();
-      return false;
+      RETURN_STATUS_SYNTAX_ERROR("Unknown field " + it_child.key());
    }
  }
   if (columns_.empty()) {
-    MS_LOG(ERROR) << "Columns are missing.";
-    return false;
+    RETURN_STATUS_SYNTAX_ERROR("Columns are missing.");
   }
-  if (num_rows_ <= 0) {
-    MS_LOG(ERROR) << "numRows must be greater than 0";
-    return false;
+  if (num_rows_ < 0) {
+    RETURN_STATUS_SYNTAX_ERROR("numRows must be greater than or equal to 0");
   }
-  return true;
+  return Status::OK();
 }
+Status SchemaObj::FromJSONString(const std::string &json_string) {
+  try {
+    nlohmann::json js = nlohmann::json::parse(json_string);
+    CHECK_FAIL_RETURN_UNEXPECTED(js.find("columns") != js.end(),
+                                 "\"columns\" node is required in the schema JSON string.");
+    RETURN_IF_NOT_OK(from_json(js));
+  } catch (const std::exception &err) {
+    RETURN_STATUS_SYNTAX_ERROR("Failed to parse JSON string");
+  }
+  return Status::OK();
+}
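FromJSONString() mirrors init() but takes an in-memory string, which lets a schema defined elsewhere be handed over without a file. A hedged sketch with an illustrative schema; per the check above, only the "columns" node is mandatory:

#include <memory>
#include <string>
#include "minddata/dataset/include/datasets.h"

using namespace mindspore::dataset;

// Sketch: populate a SchemaObj from a JSON string instead of a schema file.
Status LoadSchemaFromString(const std::shared_ptr<SchemaObj> &schema) {
  const std::string json_string = R"({
    "columns": {
      "label": { "type": "int32", "shape": [1] }
    }
  })";
  return schema->FromJSONString(json_string);
}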
 // OTHER FUNCTIONS
| @@ -1,136 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/stl.h" | |||||
| #include "pybind11/stl_bind.h" | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | |||||
| #include "minddata/dataset/api/python/de_pipeline.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| PYBIND_REGISTER( | |||||
| DEPipeline, 0, ([](const py::module *m) { | |||||
| (void)py::class_<DEPipeline>(*m, "DEPipeline") | |||||
| .def(py::init<>()) | |||||
| .def( | |||||
| "AddNodeToTree", | |||||
| [](DEPipeline &de, const OpName &op_name, const py::dict &args) { | |||||
| py::dict out; | |||||
| THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); | |||||
| return out; | |||||
| }, | |||||
| py::return_value_policy::reference) | |||||
| .def_static("AddChildToParentNode", | |||||
| [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { | |||||
| THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); | |||||
| }) | |||||
| .def("AssignRootNode", | |||||
| [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) | |||||
| .def("SetBatchParameters", | |||||
| [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) | |||||
| .def("PrepareTree", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.PrepareTree(num_epochs)); }) | |||||
| .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) | |||||
| .def("GetColumnNames", | |||||
| [](DEPipeline &de) { | |||||
| py::list out; | |||||
| THROW_IF_ERROR(de.GetColumnNames(&out)); | |||||
| return out; | |||||
| }) | |||||
| .def("GetNextAsMap", | |||||
| [](DEPipeline &de) { | |||||
| py::dict out; | |||||
| THROW_IF_ERROR(de.GetNextAsMap(&out)); | |||||
| return out; | |||||
| }) | |||||
| .def("GetNextAsList", | |||||
| [](DEPipeline &de) { | |||||
| py::list out; | |||||
| THROW_IF_ERROR(de.GetNextAsList(&out)); | |||||
| return out; | |||||
| }) | |||||
| .def("GetOutputShapes", | |||||
| [](DEPipeline &de) { | |||||
| py::list out; | |||||
| THROW_IF_ERROR(de.GetOutputShapes(&out)); | |||||
| return out; | |||||
| }) | |||||
| .def("GetOutputTypes", | |||||
| [](DEPipeline &de) { | |||||
| py::list out; | |||||
| THROW_IF_ERROR(de.GetOutputTypes(&out)); | |||||
| return out; | |||||
| }) | |||||
| .def("GetDataInfo", | |||||
| [](DEPipeline &de) { | |||||
| py::list types, shapes; | |||||
| THROW_IF_ERROR(de.GetDataInfo(&types, &shapes)); | |||||
| return py::make_tuple(types, shapes); | |||||
| }) | |||||
| .def("GetDatasetSize", &DEPipeline::GetDatasetSize) | |||||
| .def("GetBatchSize", &DEPipeline::GetBatchSize) | |||||
| .def("GetNumClasses", &DEPipeline::GetNumClasses) | |||||
| .def("GetRepeatCount", &DEPipeline::GetRepeatCount) | |||||
| .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); }) | |||||
| .def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); }) | |||||
| .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) { | |||||
| THROW_IF_ERROR(de.SaveDataset(file_names, file_type)); | |||||
| return true; | |||||
| }); | |||||
| })); | |||||
| PYBIND_REGISTER(OpName, 0, ([](const py::module *m) { | |||||
| (void)py::enum_<OpName>(*m, "OpName", py::arithmetic()) | |||||
| .value("SHUFFLE", OpName::kShuffle) | |||||
| .value("BATCH", OpName::kBatch) | |||||
| .value("BUCKETBATCH", OpName::kBucketBatch) | |||||
| .value("BARRIER", OpName::kBarrier) | |||||
| .value("MINDRECORD", OpName::kMindrecord) | |||||
| .value("CACHE", OpName::kCache) | |||||
| .value("REPEAT", OpName::kRepeat) | |||||
| .value("SKIP", OpName::kSkip) | |||||
| .value("TAKE", OpName::kTake) | |||||
| .value("ZIP", OpName::kZip) | |||||
| .value("CONCAT", OpName::kConcat) | |||||
| .value("MAP", OpName::kMap) | |||||
| .value("FILTER", OpName::kFilter) | |||||
| .value("DEVICEQUEUE", OpName::kDeviceQueue) | |||||
| .value("GENERATOR", OpName::kGenerator) | |||||
| .export_values() | |||||
| .value("RENAME", OpName::kRename) | |||||
| .value("TFREADER", OpName::kTfReader) | |||||
| .value("PROJECT", OpName::kProject) | |||||
| .value("IMAGEFOLDER", OpName::kImageFolder) | |||||
| .value("MNIST", OpName::kMnist) | |||||
| .value("MANIFEST", OpName::kManifest) | |||||
| .value("VOC", OpName::kVoc) | |||||
| .value("COCO", OpName::kCoco) | |||||
| .value("CIFAR10", OpName::kCifar10) | |||||
| .value("CIFAR100", OpName::kCifar100) | |||||
| .value("RANDOMDATA", OpName::kRandomData) | |||||
| .value("BUILDVOCAB", OpName::kBuildVocab) | |||||
| .value("SENTENCEPIECEVOCAB", OpName::kSentencePieceVocab) | |||||
| .value("CELEBA", OpName::kCelebA) | |||||
| .value("TEXTFILE", OpName::kTextFile) | |||||
| .value("EPOCHCTRL", OpName::kEpochCtrl) | |||||
| .value("CSV", OpName::kCsv) | |||||
| .value("CLUE", OpName::kClue); | |||||
| })); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
@@ -19,8 +19,10 @@
 #include "minddata/dataset/api/python/pybind_register.h"
 #include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/core/client.h"  // DE client
+#include "minddata/dataset/util/status.h"
+#include "pybind11/numpy.h"
 #include "minddata/dataset/core/constants.h"
-#include "minddata/dataset/api/python/de_pipeline.h"
 namespace mindspore {
 namespace dataset {
| @@ -0,0 +1,551 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/stl.h" | |||||
| #include "pybind11/stl_bind.h" | |||||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | |||||
| #include "minddata/dataset/callback/py_ds_callback.h" | |||||
| #include "minddata/dataset/core/constants.h" | |||||
| #include "minddata/dataset/core/global_context.h" | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| // IR non-leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/batch_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/concat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/filter_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/map_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/project_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/rename_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/repeat_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/skip_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/take_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/transfer_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/zip_node.h" | |||||
| // IR non-leaf nodes - for android | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/build_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/sync_wait_node.h" | |||||
| #endif | |||||
| #include "minddata/dataset/core/config_manager.h" | |||||
| #include "minddata/dataset/core/data_type.h" | |||||
| #include "minddata/dataset/util/path.h" | |||||
| #include "minddata/dataset/util/random.h" | |||||
| #include "minddata/dataset/util/services.h" | |||||
| // IR leaf nodes | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/clue_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/coco_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||||
| // IR leaf nodes disabled for android | |||||
| #ifndef ENABLE_ANDROID | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| PYBIND_REGISTER(DatasetNode, 1, ([](const py::module *m) { | |||||
| (void)py::class_<DatasetNode, std::shared_ptr<DatasetNode>>(*m, "Dataset") | |||||
| .def("SetNumWorkers", | |||||
| [](std::shared_ptr<DatasetNode> self, std::optional<int32_t> num_workers) { | |||||
| return num_workers ? self->SetNumWorkers(*num_workers) : self; | |||||
| }) | |||||
| .def( | |||||
| "Zip", | |||||
| [](std::shared_ptr<DatasetNode> self, py::list datasets) { | |||||
| auto zip = std::make_shared<ZipNode>(std::move(toDatasetNode(self, datasets))); | |||||
| THROW_IF_ERROR(zip->ValidateParams()); | |||||
| return zip; | |||||
| }, | |||||
| py::arg("datasets")); | |||||
| })); | |||||
| // PYBIND FOR LEAF NODES | |||||
| // (In alphabetical order) | |||||
| PYBIND_REGISTER( | |||||
| CelebANode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<CelebANode, DatasetNode, std::shared_ptr<CelebANode>>(*m, "CelebANode", "to create a CelebANode") | |||||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, bool decode, | |||||
| std::optional<py::list> extensions, std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto celebA = std::make_shared<CelebANode>(dataset_dir, usage, toSamplerObj(sampler), decode, | |||||
| toStringSet(extensions), toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(celebA->ValidateParams()); | |||||
| return celebA; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(Cifar10Node, 2, ([](const py::module *m) { | |||||
| (void)py::class_<Cifar10Node, DatasetNode, std::shared_ptr<Cifar10Node>>(*m, "Cifar10Node", | |||||
| "to create a Cifar10Node") | |||||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto cifar10 = std::make_shared<Cifar10Node>(dataset_dir, usage, toSamplerObj(sampler), | |||||
| toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(cifar10->ValidateParams()); | |||||
| return cifar10; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(Cifar100Node, 2, ([](const py::module *m) { | |||||
| (void)py::class_<Cifar100Node, DatasetNode, std::shared_ptr<Cifar100Node>>(*m, "Cifar100Node", | |||||
| "to create a Cifar100Node") | |||||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto cifar100 = std::make_shared<Cifar100Node>(dataset_dir, usage, toSamplerObj(sampler), | |||||
| toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(cifar100->ValidateParams()); | |||||
| return cifar100; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER( | |||||
| CLUENode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<CLUENode, DatasetNode, std::shared_ptr<CLUENode>>(*m, "CLUENode", "to create a CLUENode") | |||||
| .def(py::init([](py::list files, std::string task, std::string usage, int64_t num_samples, int32_t shuffle, | |||||
| int32_t num_shards, int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| std::shared_ptr<CLUENode> clue_node = | |||||
| std::make_shared<dataset::CLUENode>(toStringVector(files), task, usage, num_samples, toShuffleMode(shuffle), | |||||
| num_shards, shard_id, toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(clue_node->ValidateParams()); | |||||
| return clue_node; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER( | |||||
| CocoNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode", "to create a CocoNode") | |||||
| .def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task, bool decode, | |||||
| std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>( | |||||
| dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(coco->ValidateParams()); | |||||
| return coco; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(CSVNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<CSVNode, DatasetNode, std::shared_ptr<CSVNode>>(*m, "CSVNode", "to create a CSVNode") | |||||
| .def(py::init([](std::vector<std::string> csv_files, char field_delim, py::list column_defaults, | |||||
| std::vector<std::string> column_names, int64_t num_samples, int32_t shuffle, | |||||
| int32_t num_shards, int32_t shard_id, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto csv = std::make_shared<CSVNode>(csv_files, field_delim, toCSVBase(column_defaults), | |||||
| column_names, num_samples, toShuffleMode(shuffle), | |||||
| num_shards, shard_id, toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(csv->ValidateParams()); | |||||
| return csv; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<GeneratorNode, DatasetNode, std::shared_ptr<GeneratorNode>>( | |||||
| *m, "GeneratorNode", "to create a GeneratorNode") | |||||
| .def(py::init([](py::function generator_function, const std::vector<std::string> &column_names, | |||||
| const std::vector<DataType> &column_types) { | |||||
| auto gen = std::make_shared<GeneratorNode>(generator_function, column_names, column_types); | |||||
| THROW_IF_ERROR(gen->ValidateParams()); | |||||
| return gen; | |||||
| })) | |||||
| .def(py::init([](py::function generator_function, const std::shared_ptr<SchemaObj> schema) { | |||||
| auto gen = std::make_shared<GeneratorNode>(generator_function, schema); | |||||
| THROW_IF_ERROR(gen->ValidateParams()); | |||||
| return gen; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>( | |||||
| *m, "ImageFolderNode", "to create an ImageFolderNode") | |||||
| .def(py::init([](std::string dataset_dir, bool decode, std::optional<py::handle> sampler, | |||||
| std::optional<py::list> extensions, std::optional<py::dict> class_indexing, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| bool recursive = true; | |||||
| auto imagefolder = std::make_shared<ImageFolderNode>( | |||||
| dataset_dir, decode, toSamplerObj(sampler), recursive, toStringSet(extensions), | |||||
| toStringMap(class_indexing), toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(imagefolder->ValidateParams()); | |||||
| return imagefolder; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(ManifestNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<ManifestNode, DatasetNode, std::shared_ptr<ManifestNode>>(*m, "ManifestNode", | |||||
| "to create a ManifestNode") | |||||
| .def(py::init([](std::string dataset_file, std::string usage, std::optional<py::handle> sampler, | |||||
| std::optional<py::dict> class_indexing, bool decode, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto manifest = std::make_shared<ManifestNode>(dataset_file, usage, toSamplerObj(sampler), | |||||
| toStringMap(class_indexing), decode, | |||||
| toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(manifest->ValidateParams()); | |||||
| return manifest; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode", | |||||
| "to create a MindDataNode") | |||||
| .def(py::init([](std::string dataset_file, std::optional<py::list> columns_list, | |||||
| std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) { | |||||
| nlohmann::json padded_sample_json; | |||||
| std::map<std::string, std::string> sample_bytes; | |||||
| THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes)); | |||||
| auto minddata = | |||||
| std::make_shared<MindDataNode>(dataset_file, toStringVector(columns_list), | |||||
| toSamplerObj(sampler, true), padded_sample_json, num_padded); | |||||
| minddata->SetSampleBytes(&sample_bytes); | |||||
| THROW_IF_ERROR(minddata->ValidateParams()); | |||||
| return minddata; | |||||
| })) | |||||
| .def(py::init([](py::list dataset_file, std::optional<py::list> columns_list, | |||||
| std::optional<py::handle> sampler, py::dict padded_sample, int64_t num_padded) { | |||||
| nlohmann::json padded_sample_json; | |||||
| std::map<std::string, std::string> sample_bytes; | |||||
| THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes)); | |||||
| auto minddata = | |||||
| std::make_shared<MindDataNode>(toStringVector(dataset_file), toStringVector(columns_list), | |||||
| toSamplerObj(sampler, true), padded_sample_json, num_padded); | |||||
| minddata->SetSampleBytes(&sample_bytes); | |||||
| THROW_IF_ERROR(minddata->ValidateParams()); | |||||
| return minddata; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(MnistNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<MnistNode, DatasetNode, std::shared_ptr<MnistNode>>(*m, "MnistNode", | |||||
| "to create an MnistNode") | |||||
| .def(py::init([](std::string dataset_dir, std::string usage, std::optional<py::handle> sampler, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto mnist = std::make_shared<MnistNode>(dataset_dir, usage, toSamplerObj(sampler), | |||||
| toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(mnist->ValidateParams()); | |||||
| return mnist; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER( | |||||
| RandomNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<RandomNode, DatasetNode, std::shared_ptr<RandomNode>>(*m, "RandomNode", "to create a RandomNode") | |||||
| .def(py::init([](int32_t total_rows, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto random_node = | |||||
| std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(random_node->ValidateParams()); | |||||
| return random_node; | |||||
| })) | |||||
| .def(py::init([](int32_t total_rows, std::string schema, std::optional<py::list> columns_list, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| auto random_node = | |||||
| std::make_shared<RandomNode>(total_rows, schema, toStringVector(columns_list), toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(random_node->ValidateParams()); | |||||
| return random_node; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(TextFileNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<TextFileNode, DatasetNode, std::shared_ptr<TextFileNode>>(*m, "TextFileNode", | |||||
| "to create a TextFileNode") | |||||
| .def(py::init([](py::list dataset_files, int32_t num_samples, int32_t shuffle, int32_t num_shards, | |||||
| int32_t shard_id, std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| std::shared_ptr<TextFileNode> textfile_node = std::make_shared<TextFileNode>( | |||||
| toStringVector(dataset_files), num_samples, toShuffleMode(shuffle), num_shards, shard_id, | |||||
| toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(textfile_node->ValidateParams()); | |||||
| return textfile_node; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER( | |||||
| TFRecordNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<TFRecordNode, DatasetNode, std::shared_ptr<TFRecordNode>>(*m, "TFRecordNode", | |||||
| "to create a TFRecordNode") | |||||
| .def(py::init([](py::list dataset_files, std::shared_ptr<SchemaObj> schema, std::optional<py::list> columns_list, | |||||
| std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards, | |||||
| std::optional<int32_t> shard_id, bool shard_equal_rows, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| if (!num_samples) { | |||||
| num_samples = 0; | |||||
| } | |||||
| std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( | |||||
| toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle), | |||||
| *num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(tfrecord->ValidateParams()); | |||||
| return tfrecord; | |||||
| })) | |||||
| .def(py::init([](py::list dataset_files, std::string schema, std::optional<py::list> columns_list, | |||||
| std::optional<int64_t> num_samples, int32_t shuffle, std::optional<int32_t> num_shards, | |||||
| std::optional<int32_t> shard_id, bool shard_equal_rows, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| if (!num_samples) { | |||||
| num_samples = 0; | |||||
| } | |||||
| std::shared_ptr<TFRecordNode> tfrecord = std::make_shared<TFRecordNode>( | |||||
| toStringVector(dataset_files), schema, toStringVector(columns_list), *num_samples, toShuffleMode(shuffle), | |||||
| *num_shards, *shard_id, shard_equal_rows, toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(tfrecord->ValidateParams()); | |||||
| return tfrecord; | |||||
| })); | |||||
| })); | |||||
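The num_samples handling in both TFRecordNode constructors above has to assign to the optional itself (as in the corrected lines) rather than write through it while it is still empty, because dereferencing a disengaged std::optional is undefined behaviour. A small self-contained illustration of the distinction:

#include <cstdint>
#include <optional>

// Assigning to an empty std::optional engages it with the given value;
// dereferencing it while empty is undefined behaviour.
int64_t NormalizeNumSamples(std::optional<int64_t> num_samples) {
  if (!num_samples) {
    num_samples = 0;     // OK: the optional now holds 0
    // *num_samples = 0; // would be UB on an empty optional
  }
  return *num_samples;
}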
| PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode") | |||||
| .def( | |||||
| py::init([](std::string dataset_dir, std::string task, std::string usage, | |||||
| std::optional<py::dict> class_indexing, bool decode, | |||||
| std::optional<py::handle> sampler, std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| std::shared_ptr<VOCNode> voc = | |||||
| std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode, | |||||
| toSamplerObj(sampler), toDatasetCache(std::move(cc))); | |||||
| THROW_IF_ERROR(voc->ValidateParams()); | |||||
| return voc; | |||||
| })); | |||||
| })); | |||||
| // PYBIND FOR NON-LEAF NODES | |||||
| // (In alphabetical order) | |||||
| PYBIND_REGISTER(BatchNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<BatchNode, DatasetNode, std::shared_ptr<BatchNode>>(*m, "BatchNode", | |||||
| "to create a BatchNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t batch_size, bool drop_remainder, | |||||
| bool pad, py::list in_col_names, py::list out_col_names, py::list col_order, | |||||
| py::object size_obj, py::object map_obj, py::dict pad_info) { | |||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info; | |||||
| if (pad) { | |||||
| THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info)); | |||||
| } | |||||
| py::function size_func = | |||||
| py::isinstance<py::function>(size_obj) ? size_obj.cast<py::function>() : py::function(); | |||||
| py::function map_func = | |||||
| py::isinstance<py::function>(map_obj) ? map_obj.cast<py::function>() : py::function(); | |||||
| auto batch = std::make_shared<BatchNode>( | |||||
| self, batch_size, drop_remainder, pad, toStringVector(in_col_names), | |||||
| toStringVector(out_col_names), toStringVector(col_order), size_func, map_func, c_pad_info); | |||||
| THROW_IF_ERROR(batch->ValidateParams()); | |||||
| return batch; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(BucketBatchByLengthNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<BucketBatchByLengthNode, DatasetNode, std::shared_ptr<BucketBatchByLengthNode>>( | |||||
| *m, "BucketBatchByLengthNode", "to create a BucketBatchByLengthNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> dataset, py::list column_names, | |||||
| std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes, | |||||
| py::object element_length_function, py::dict pad_info, bool pad_to_bucket_boundary, | |||||
| bool drop_remainder) { | |||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> c_pad_info; | |||||
| THROW_IF_ERROR(toPadInfo(pad_info, &c_pad_info)); | |||||
| auto bucket_batch = std::make_shared<BucketBatchByLengthNode>( | |||||
| dataset, toStringVector(column_names), bucket_boundaries, bucket_batch_sizes, | |||||
| toPyFuncOp(std::move(element_length_function), DataType::DE_INT32), c_pad_info, | |||||
| pad_to_bucket_boundary, drop_remainder); | |||||
| THROW_IF_ERROR(bucket_batch->ValidateParams()); | |||||
| return bucket_batch; | |||||
| }), | |||||
| py::arg("dataset"), py::arg("column_names"), py::arg("bucket_boundaries"), | |||||
| py::arg("bucket_batch_sizes"), py::arg("element_length_function") = py::none(), | |||||
| py::arg("pad_info"), py::arg("pad_to_bucket_boundary"), py::arg("drop_remainder")); | |||||
| })); | |||||
| PYBIND_REGISTER(BuildSentenceVocabNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<BuildSentenceVocabNode, DatasetNode, std::shared_ptr<BuildSentenceVocabNode>>( | |||||
| *m, "BuildSentenceVocabNode", "to create a BuildSentenceVocabNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<SentencePieceVocab> vocab, | |||||
| const std::vector<std::string> &col_names, uint32_t vocab_size, | |||||
| float character_coverage, SentencePieceModel model_type, | |||||
| const std::unordered_map<std::string, std::string> ¶ms) { | |||||
| auto build_sentence_vocab = std::make_shared<BuildSentenceVocabNode>( | |||||
| self, vocab, col_names, vocab_size, character_coverage, model_type, params); | |||||
| THROW_IF_ERROR(build_sentence_vocab->ValidateParams()); | |||||
| return build_sentence_vocab; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(BuildVocabNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<BuildVocabNode, DatasetNode, std::shared_ptr<BuildVocabNode>>( | |||||
| *m, "BuildVocabNode", "to create a BuildVocabNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::shared_ptr<Vocab> vocab, py::list columns, | |||||
| py::tuple freq_range, int64_t top_k, py::list special_tokens, bool special_first) { | |||||
| auto build_vocab = | |||||
| std::make_shared<BuildVocabNode>(self, vocab, toStringVector(columns), toIntPair(freq_range), | |||||
| top_k, toStringVector(special_tokens), special_first); | |||||
| THROW_IF_ERROR(build_vocab->ValidateParams()); | |||||
| return build_vocab; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(ConcatNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<ConcatNode, DatasetNode, std::shared_ptr<ConcatNode>>(*m, "ConcatNode", | |||||
| "to create a ConcatNode") | |||||
| .def( | |||||
| py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets, std::optional<py::handle> sampler, | |||||
| py::list children_flag_and_nums, py::list children_start_end_index) { | |||||
| auto concat = std::make_shared<ConcatNode>(datasets, toSamplerObj(sampler), | |||||
| toPairVector(children_flag_and_nums), | |||||
| toPairVector(children_start_end_index)); | |||||
| THROW_IF_ERROR(concat->ValidateParams()); | |||||
| return concat; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(FilterNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<FilterNode, DatasetNode, std::shared_ptr<FilterNode>>(*m, "FilterNode", | |||||
| "to create a FilterNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, py::object predicate, | |||||
| std::vector<std::string> input_columns) { | |||||
| auto filter = | |||||
| std::make_shared<FilterNode>(self, toPyFuncOp(predicate, DataType::DE_BOOL), input_columns); | |||||
| THROW_IF_ERROR(filter->ValidateParams()); | |||||
| return filter; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(MapNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<MapNode, DatasetNode, std::shared_ptr<MapNode>>(*m, "MapNode", "to create a MapNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> operations, | |||||
| std::optional<py::list> input_columns, std::optional<py::list> output_columns, | |||||
| std::optional<py::list> project_columns, | |||||
| std::optional<std::shared_ptr<CacheClient>> cc, | |||||
| std::vector<std::shared_ptr<PyDSCallback>> py_callbacks) { | |||||
| auto map = std::make_shared<MapNode>( | |||||
| self, std::move(toTensorOperations(operations)), toStringVector(input_columns), | |||||
| toStringVector(output_columns), toStringVector(project_columns), toDatasetCache(std::move(cc)), | |||||
| std::vector<std::shared_ptr<DSCallback>>(py_callbacks.begin(), py_callbacks.end())); | |||||
| THROW_IF_ERROR(map->ValidateParams()); | |||||
| return map; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(ProjectNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<ProjectNode, DatasetNode, std::shared_ptr<ProjectNode>>(*m, "ProjectNode", | |||||
| "to create a ProjectNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, py::list columns) { | |||||
| auto project = std::make_shared<ProjectNode>(self, toStringVector(columns)); | |||||
| THROW_IF_ERROR(project->ValidateParams()); | |||||
| return project; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(RenameNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<RenameNode, DatasetNode, std::shared_ptr<RenameNode>>(*m, "RenameNode", | |||||
| "to create a RenameNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::optional<py::list> input_columns, | |||||
| std::optional<py::list> output_columns) { | |||||
| auto rename = std::make_shared<RenameNode>(self, toStringVector(input_columns), | |||||
| toStringVector(output_columns)); | |||||
| THROW_IF_ERROR(rename->ValidateParams()); | |||||
| return rename; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(RepeatNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<RepeatNode, DatasetNode, std::shared_ptr<RepeatNode>>(*m, "RepeatNode", | |||||
| "to create a RepeatNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> input, int32_t count) { | |||||
| auto repeat = std::make_shared<RepeatNode>(input, count); | |||||
| THROW_IF_ERROR(repeat->ValidateParams()); | |||||
| return repeat; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(ShuffleNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<ShuffleNode, DatasetNode, std::shared_ptr<ShuffleNode>>(*m, "ShuffleNode", | |||||
| "to create a ShuffleNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t shuffle_size, bool reset_every_epoch) { | |||||
| auto shuffle = std::make_shared<ShuffleNode>(self, shuffle_size, reset_every_epoch); | |||||
| THROW_IF_ERROR(shuffle->ValidateParams()); | |||||
| return shuffle; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(SkipNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<SkipNode, DatasetNode, std::shared_ptr<SkipNode>>(*m, "SkipNode", | |||||
| "to create a SkipNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) { | |||||
| auto skip = std::make_shared<SkipNode>(self, count); | |||||
| THROW_IF_ERROR(skip->ValidateParams()); | |||||
| return skip; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(SyncWaitNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<SyncWaitNode, DatasetNode, std::shared_ptr<SyncWaitNode>>(*m, "SyncWaitNode", | |||||
| "to create a SyncWaitNode") | |||||
| .def( | |||||
| py::init([](std::shared_ptr<DatasetNode> self, std::string condition_name, py::object callback) { | |||||
| py::function callback_func = | |||||
| py::isinstance<py::function>(callback) ? callback.cast<py::function>() : py::function(); | |||||
| auto sync_wait = std::make_shared<SyncWaitNode>(self, condition_name, callback_func); | |||||
| THROW_IF_ERROR(sync_wait->ValidateParams()); | |||||
| return sync_wait; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(TakeNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<TakeNode, DatasetNode, std::shared_ptr<TakeNode>>(*m, "TakeNode", | |||||
| "to create a TakeNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, int32_t count) { | |||||
| auto take = std::make_shared<TakeNode>(self, count); | |||||
| THROW_IF_ERROR(take->ValidateParams()); | |||||
| return take; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(TransferNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<TransferNode, DatasetNode, std::shared_ptr<TransferNode>>(*m, "TransferNode", | |||||
| "to create a TransferNode") | |||||
| .def(py::init([](std::shared_ptr<DatasetNode> self, std::string queue_name, std::string device_type, | |||||
| bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) { | |||||
| auto transfer = std::make_shared<TransferNode>(self, queue_name, device_type, send_epoch_end, | |||||
| total_batch, create_data_info_queue); | |||||
| THROW_IF_ERROR(transfer->ValidateParams()); | |||||
| return transfer; | |||||
| })); | |||||
| })); | |||||
| PYBIND_REGISTER(ZipNode, 2, ([](const py::module *m) { | |||||
| (void)py::class_<ZipNode, DatasetNode, std::shared_ptr<ZipNode>>(*m, "ZipNode", "to create a ZipNode") | |||||
| .def(py::init([](std::vector<std::shared_ptr<DatasetNode>> datasets) { | |||||
| auto zip = std::make_shared<ZipNode>(datasets); | |||||
| THROW_IF_ERROR(zip->ValidateParams()); | |||||
| return zip; | |||||
| })); | |||||
| })); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,168 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | |||||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||||
| #include "minddata/dataset/engine/python_runtime_context.h" | |||||
| #include "minddata/dataset/engine/consumers/python_tree_consumer.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| PYBIND_REGISTER(TreeConsumer, 0, ([](const py::module *m) { | |||||
| (void)py::class_<TreeConsumer, std::shared_ptr<TreeConsumer>>(*m, "TreeConsumer"); | |||||
| })); | |||||
| PYBIND_REGISTER(PythonIteratorConsumer, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PythonIteratorConsumer, TreeConsumer, std::shared_ptr<PythonIteratorConsumer>>( | |||||
| *m, "PythonIteratorConsumer") | |||||
| .def(py::init<int32_t>()) | |||||
| .def("Init", [](PythonIteratorConsumer &self, | |||||
| std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||||
| .def("GetNextAsMap", | |||||
| [](PythonIteratorConsumer &self) { | |||||
| py::dict output; | |||||
| THROW_IF_ERROR(self.GetNextAsDict(&output)); | |||||
| return output; | |||||
| }) | |||||
| .def("GetNextAsList", [](PythonIteratorConsumer &self) { | |||||
| py::list output; | |||||
| THROW_IF_ERROR(self.GetNextAsList(&output)); | |||||
| return output; | |||||
| }); | |||||
| })); | |||||
| PYBIND_REGISTER(TreeGetters, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PythonTreeGetters, TreeConsumer, std::shared_ptr<PythonTreeGetters>>(*m, | |||||
| "TreeGetters") | |||||
| .def(py::init<>()) | |||||
| .def("Init", | |||||
| [](PythonTreeGetters &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||||
| .def("GetOutputShapes", | |||||
| [](PythonTreeGetters &self) { | |||||
| std::vector<TensorShape> shapes; | |||||
| THROW_IF_ERROR(self.GetOutputShapes(&shapes)); | |||||
| return shapesToListOfShape(shapes); | |||||
| }) | |||||
| .def("GetOutputTypes", | |||||
| [](PythonTreeGetters &self) { | |||||
| std::vector<DataType> types; | |||||
| THROW_IF_ERROR(self.GetOutputTypes(&types)); | |||||
| return typesToListOfType(types); | |||||
| }) | |||||
| .def("GetNumClasses", | |||||
| [](PythonTreeGetters &self) { | |||||
| int64_t num_classes; | |||||
| THROW_IF_ERROR(self.GetNumClasses(&num_classes)); | |||||
| return num_classes; | |||||
| }) | |||||
| .def("GetRepeatCount", | |||||
| [](PythonTreeGetters &self) { | |||||
| int64_t repeat_count; | |||||
| THROW_IF_ERROR(self.GetRepeatCount(&repeat_count)); | |||||
| return repeat_count; | |||||
| }) | |||||
| .def("GetBatchSize", | |||||
| [](PythonTreeGetters &self) { | |||||
| int64_t batch_size; | |||||
| THROW_IF_ERROR(self.GetBatchSize(&batch_size)); | |||||
| return batch_size; | |||||
| }) | |||||
| .def("GetColumnNames", | |||||
| [](PythonTreeGetters &self) { | |||||
| std::vector<std::string> col_names; | |||||
| THROW_IF_ERROR(self.GetColumnNames(&col_names)); | |||||
| return col_names; | |||||
| }) | |||||
| .def("GetClassIndexing", | |||||
| [](PythonTreeGetters &self) { | |||||
| std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing; | |||||
| THROW_IF_ERROR(self.GetClassIndexing(&output_class_indexing)); | |||||
| return output_class_indexing; | |||||
| }) | |||||
| .def("GetDatasetSize", | |||||
| [](PythonTreeGetters &self) { | |||||
| int64_t dataset_size; | |||||
| THROW_IF_ERROR(self.GetDatasetSize(&dataset_size)); | |||||
| return dataset_size; | |||||
| }) | |||||
| .def("__deepcopy__", [](py::object &tree_getter, py::dict memo) { return tree_getter; }); | |||||
| })); | |||||
| PYBIND_REGISTER(PythonRuntimeContext, 2, ([](const py::module *m) { | |||||
| (void)py::class_<PythonRuntimeContext, std::shared_ptr<PythonRuntimeContext>>(*m, | |||||
| "PythonRuntimeContext") | |||||
| .def(py::init<>()) | |||||
| .def("Init", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Init()); }) | |||||
| .def("AssignConsumer", &PythonRuntimeContext::AssignConsumer) | |||||
| .def("Terminate", [](PythonRuntimeContext &self) { THROW_IF_ERROR(self.Terminate()); }) | |||||
| .def("GetConsumer", &PythonRuntimeContext::GetPythonConsumer, py::return_value_policy::reference) | |||||
| .def("__deepcopy__", [](py::object &runtime_context, py::dict memo) { return runtime_context; }); | |||||
| })); | |||||
| PYBIND_REGISTER(PythonBuildVocabConsumer, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PythonBuildVocabConsumer, TreeConsumer, std::shared_ptr<PythonBuildVocabConsumer>>( | |||||
| *m, "PythonBuildVocabConsumer") | |||||
| .def(py::init<>()) | |||||
| .def("Init", [](PythonBuildVocabConsumer &self, | |||||
| std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||||
| .def("Start", [](PythonBuildVocabConsumer &self) { THROW_IF_ERROR(self.Start()); }); | |||||
| })); | |||||
| PYBIND_REGISTER(ToDevice, 1, ([](const py::module *m) { | |||||
| (void)py::class_<ToDevice, TreeConsumer, std::shared_ptr<ToDevice>>(*m, "ToDevice") | |||||
| .def(py::init<int32_t>()) | |||||
| .def("Init", [](ToDevice &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||||
| .def("Send", [](ToDevice &self) { THROW_IF_ERROR(self.Send()); }) | |||||
| .def("ContinueSend", [](ToDevice &self) { THROW_IF_ERROR(self.Continue()); }) | |||||
| .def("StopSend", [](ToDevice &self) { THROW_IF_ERROR(self.Stop()); }) | |||||
| .def("GetDataInfo", | |||||
| [](ToDevice &self) { | |||||
| std::vector<DataType> types_c; | |||||
| std::vector<TensorShape> shapes_c; | |||||
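| // Release the GIL while the (potentially blocking) C++ call fetches the row's type/shape info from the device queue op. | |||||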
| { | |||||
| py::gil_scoped_release rel; | |||||
| THROW_IF_ERROR(self.GetDataInfo(&types_c, &shapes_c)); | |||||
| } | |||||
| py::list types, shapes; | |||||
| for (auto el : types_c) { | |||||
| types.append(el.AsNumpyType()); | |||||
| } | |||||
| for (auto el : shapes_c) { | |||||
| py::list shape = el.AsPyList(); | |||||
| shapes.append(shape); | |||||
| } | |||||
| return py::make_tuple(types, shapes); | |||||
| }) | |||||
| .def("__deepcopy__", [](py::object &to_device, py::dict memo) { return to_device; }); | |||||
| })); | |||||
| PYBIND_REGISTER(PythonSaveToDisk, 1, ([](const py::module *m) { | |||||
| (void)py::class_<PythonSaveToDisk, TreeConsumer, std::shared_ptr<PythonSaveToDisk>>( | |||||
| *m, "PythonSaveToDisk") | |||||
| .def(py::init([](std::string &dataset_path, int32_t numFiles, std::string &datasetType) { | |||||
| auto save = std::make_shared<PythonSaveToDisk>(dataset_path, numFiles, datasetType); | |||||
| THROW_IF_ERROR(save->ValidateParams()); | |||||
| return save; | |||||
| })) | |||||
| .def("Init", | |||||
| [](PythonSaveToDisk &self, std::shared_ptr<DatasetNode> d) { THROW_IF_ERROR(self.Init(d)); }) | |||||
| .def("Save", [](PythonSaveToDisk &self) { THROW_IF_ERROR(self.Save()); }); | |||||
| })); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/stl.h" | |||||
| #include "pybind11/stl_bind.h" | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | |||||
| #include "minddata/dataset/core/global_context.h" | |||||
| #include "minddata/dataset/core/constants.h" | |||||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| PYBIND_REGISTER( | |||||
| SchemaObj, 0, ([](const py::module *m) { | |||||
| (void)py::class_<SchemaObj, std::shared_ptr<SchemaObj>>(*m, "SchemaObj", "to create a SchemaObj") | |||||
| .def(py::init([](std::string schema_file) { | |||||
| auto schema = std::make_shared<SchemaObj>(schema_file); | |||||
| THROW_IF_ERROR(schema->init()); | |||||
| return schema; | |||||
| })) | |||||
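| // add_column is overloaded for (name, TypeId) and (name, string type name), each with or without an explicit shape | |||||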
| .def("add_column", [](SchemaObj &self, std::string name, TypeId de_type, | |||||
| std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); }) | |||||
| .def("add_column", [](SchemaObj &self, std::string name, std::string de_type, | |||||
| std::vector<int32_t> shape) { THROW_IF_ERROR(self.add_column(name, de_type, shape)); }) | |||||
| .def("add_column", | |||||
| [](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | |||||
| .def("add_column", [](SchemaObj &self, std::string name, | |||||
| std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | |||||
| .def("to_json", &SchemaObj::to_json) | |||||
| .def("to_string", &SchemaObj::to_string) | |||||
| .def("from_string", | |||||
| [](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.FromJSONString(json_string)); }) | |||||
| .def("set_dataset_type", [](SchemaObj &self, std::string dataset_type) { self.set_dataset_type(dataset_type); }) | |||||
| .def("set_num_rows", [](SchemaObj &self, int32_t num_rows) { self.set_num_rows(num_rows); }) | |||||
| .def("get_num_rows", &SchemaObj::get_num_rows) | |||||
| .def("__deepcopy__", [](py::object &schema, py::dict memo) { return schema; }); | |||||
| })); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -17,7 +17,6 @@ | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | #include "minddata/dataset/api/python/pybind_register.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/api/python/de_pipeline.h" | |||||
| #include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h" | #include "mindspore/ccsrc/minddata/dataset/kernels/data/compose_op.h" | ||||
| #include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h" | #include "mindspore/ccsrc/minddata/dataset/kernels/data/no_op.h" | ||||
| @@ -1,265 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <stack> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "minddata/dataset/core/client.h" // DE client | |||||
| #include "minddata/dataset/engine/dataset_iterator.h" | |||||
| #include "minddata/dataset/util/status.h" | |||||
| #include "pybind11/numpy.h" | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/stl.h" | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| using json = nlohmann::json; | |||||
| using DsOpPtr = std::shared_ptr<DatasetOp>; | |||||
| class CacheClient; | |||||
| // enum for the dataset operator names | |||||
| enum OpName { | |||||
| kShuffle, | |||||
| kMindrecord, | |||||
| kBatch, | |||||
| kBucketBatch, | |||||
| kBarrier, | |||||
| kCache, | |||||
| kRepeat, | |||||
| kSkip, | |||||
| kTake, | |||||
| kZip, | |||||
| kConcat, | |||||
| kMap, | |||||
| kFilter, | |||||
| kDeviceQueue, | |||||
| kGenerator, | |||||
| kRename, | |||||
| kTfReader, | |||||
| kProject, | |||||
| kImageFolder, | |||||
| kMnist, | |||||
| kManifest, | |||||
| kVoc, | |||||
| kCoco, | |||||
| kCifar10, | |||||
| kCifar100, | |||||
| kCelebA, | |||||
| kRandomData, | |||||
| kTextFile, | |||||
| kBuildVocab, | |||||
| kClue, | |||||
| kEpochCtrl, | |||||
| kSentencePieceVocab, | |||||
| kCsv | |||||
| }; | |||||
| // The C++ binder class that we expose to the python script. | |||||
| class DEPipeline { | |||||
| public: | |||||
| DEPipeline(); | |||||
| ~DEPipeline(); | |||||
| // Function to add a Node to the Execution Tree. | |||||
| Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); | |||||
| // Function to add a child and parent relationship. | |||||
| static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); | |||||
| // Function to assign the node as root. | |||||
| Status AssignRootNode(const DsOpPtr &dataset_op); | |||||
| // Function to get the column names in the last node in the tree in order | |||||
| Status GetColumnNames(py::list *output); | |||||
| // Function to prepare the tree for execution | |||||
| Status PrepareTree(const int32_t num_epochs); | |||||
| // Function to launch the tree execution. | |||||
| Status LaunchTreeExec(); | |||||
| // Get a row of data as dictionary of column name to the value. | |||||
| Status GetNextAsMap(py::dict *output); | |||||
| // Get a row of data as list. | |||||
| Status GetNextAsList(py::list *output); | |||||
| Status GetOutputShapes(py::list *output); | |||||
| Status GetOutputTypes(py::list *output); | |||||
| Status GetDataInfo(py::list *types, py::list *shapes); | |||||
| Status SaveDataset(const std::vector<std::string> &file_names, const std::string &file_type); | |||||
| int GetDatasetSize() const; | |||||
| int GetBatchSize() const; | |||||
| int GetRepeatCount() const; | |||||
| Status ParseShuffleOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseMindRecordOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| template <typename T, typename S> | |||||
| Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | |||||
| std::unique_ptr<S> *s, bool need_convert = false); | |||||
| Status FetchMetaFromTensorRow(const std::unordered_map<std::string, int32_t> &column_name_id_map, | |||||
| const TensorRow &row, json *schema, std::vector<std::string> *index_fields); | |||||
| Status FetchDataFromTensorRow(const TensorRow &row, | |||||
| const std::unordered_map<std::string, int32_t> &column_name_id_map, json *row_raw_data, | |||||
| std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> *row_bin_data); | |||||
| Status BuildMindrecordSamplerChain(const py::handle &handle, | |||||
| std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators, | |||||
| int num_padded); | |||||
| Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||||
| std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); | |||||
| Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseManifestOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseCifar100Op(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseRandomDataOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| void PrintTree(); | |||||
| int32_t GetNumClasses() const; | |||||
| Status ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status SetBatchParameters(const py::dict &args); | |||||
| Status ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseTextFileOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status StopSend(); | |||||
| Status ContinueSend(); | |||||
| Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, | |||||
| std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| Status ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom); | |||||
| private: | |||||
| // Execution tree that links the dataset operators. | |||||
| std::shared_ptr<ExecutionTree> tree_; | |||||
| std::unique_ptr<DatasetIterator> iterator_; | |||||
| static Status ParsePadInfo(py::handle value, PadInfo *pad_info); | |||||
| /// \brief Helper function to inject a cache operator over top of the current operation being built. | |||||
| /// \param[in] cache_client The client to use for caching | |||||
| /// \param[in] num_workers The number of workers to use in the cache op | |||||
| /// \param[in] input_op The operator to build the cache on top of | |||||
| /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be | |||||
| /// the cache operator | |||||
| /// \return Status return code | |||||
| Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op, | |||||
| std::shared_ptr<DatasetOp> *cache_op); | |||||
| /// \brief Helper function to inject a shuffle operator over top of the current operation being built. | |||||
| /// \param[in] shuffle_size The size to use in the shuffle buffer | |||||
| /// \param[in] input_op The operator to build shuffle on top of | |||||
| /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be | |||||
| /// the shuffle operator | |||||
| /// \return Status return code | |||||
| Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op, | |||||
| std::shared_ptr<DatasetOp> *shuffle_op); | |||||
| /// \brief Helper function to compute the shuffle size | |||||
| /// \param[in] num_files The number of files in the dataset | |||||
| /// \param[in] num_devices The number of devices in the dataset | |||||
| /// \param[in] num_rows The number of rows in the dataset | |||||
| /// \param[in] total_rows An upper bound on the total rows in the dataset | |||||
| /// \param[out] shuffle_size The resultant computed shuffle size | |||||
| /// \return Status return code | |||||
| Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||||
| int64_t *shuffle_size); | |||||
| int batch_size_; | |||||
| int repeat_num_; | |||||
| int num_rows_; | |||||
| int num_classes_; | |||||
| int temp_batch_size_; | |||||
| bool temp_drop_remainder_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_DE_PIPELINE_H_ | |||||
| @@ -0,0 +1,265 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/dataset/api/python/pybind_conversion.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| float toFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); } | |||||
| int toInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); } | |||||
| int64_t toInt64(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); } | |||||
| bool toBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); } | |||||
| std::string toString(const py::handle &handle) { return py::reinterpret_borrow<py::str>(handle); } | |||||
| std::set<std::string> toStringSet(const std::optional<py::list> list) { | |||||
| std::set<std::string> set; | |||||
| if (list) { | |||||
| for (auto l : *list) { | |||||
| if (!l.is_none()) { | |||||
| (void)set.insert(py::str(l)); | |||||
| } | |||||
| } | |||||
| } | |||||
| return set; | |||||
| } | |||||
| std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict) { | |||||
| std::map<std::string, int32_t> map; | |||||
| if (dict) { | |||||
| for (auto p : *dict) { | |||||
| (void)map.emplace(toString(p.first), toInt(p.second)); | |||||
| } | |||||
| } | |||||
| return map; | |||||
| } | |||||
| std::vector<std::string> toStringVector(const std::optional<py::list> list) { | |||||
| std::vector<std::string> vector; | |||||
| if (list) { | |||||
| for (auto l : *list) { | |||||
| if (l.is_none()) | |||||
| vector.emplace_back(""); | |||||
| else | |||||
| vector.push_back(py::str(l)); | |||||
| } | |||||
| } | |||||
| return vector; | |||||
| } | |||||
| std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple) { | |||||
| std::pair<int64_t, int64_t> pair; | |||||
| if (tuple) { | |||||
| pair = std::make_pair(toInt64((*tuple)[0]), toInt64((*tuple)[1])); | |||||
| } | |||||
| return pair; | |||||
| } | |||||
| std::vector<std::pair<int, int>> toPairVector(const py::list list) { | |||||
| std::vector<std::pair<int, int>> vector; | |||||
| if (list) { | |||||
| for (auto data : list) { | |||||
| auto l = data.cast<py::tuple>(); | |||||
| if (l[1].is_none()) | |||||
| vector.emplace_back(toInt64(l[0]), 0); | |||||
| else | |||||
| vector.emplace_back(toInt64(l[0]), toInt64(l[1])); | |||||
| } | |||||
| } | |||||
| return vector; | |||||
| } | |||||
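| // Convert a python list of operations: existing TensorOps are used directly, python callables are wrapped in a PyFuncOp, and each entry is exposed as a PreBuiltOperation. | |||||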
| std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations) { | |||||
| std::vector<std::shared_ptr<TensorOperation>> vector; | |||||
| if (operations) { | |||||
| for (auto op : *operations) { | |||||
| std::shared_ptr<TensorOp> tensor_op; | |||||
| if (py::isinstance<TensorOp>(op)) { | |||||
| tensor_op = op.cast<std::shared_ptr<TensorOp>>(); | |||||
| } else if (py::isinstance<py::function>(op)) { | |||||
| tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>()); | |||||
| } else { | |||||
| THROW_IF_ERROR( | |||||
| []() { RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); }()); | |||||
| } | |||||
| vector.push_back(std::make_shared<transforms::PreBuiltOperation>(tensor_op)); | |||||
| } | |||||
| } | |||||
| return vector; | |||||
| } | |||||
| std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets) { | |||||
| std::vector<std::shared_ptr<DatasetNode>> vector; | |||||
| vector.push_back(self); | |||||
| if (datasets) { | |||||
| for (auto ds : *datasets) { | |||||
| if (py::isinstance<DatasetNode>(ds)) { | |||||
| vector.push_back(ds.cast<std::shared_ptr<DatasetNode>>()); | |||||
| } else { | |||||
| THROW_IF_ERROR( | |||||
| []() { RETURN_STATUS_UNEXPECTED("Error: datasets is not recognised (not a DatasetNode instance)."); }()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return vector; | |||||
| } | |||||
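| // Build a SamplerObj from a python sampler by calling its create() (or create_for_minddataset() for MindRecord) factory and wrapping the result in a PreBuiltSamplerObj. | |||||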
| std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset) { | |||||
| if (py_sampler) { | |||||
| std::shared_ptr<SamplerObj> sampler_obj; | |||||
| if (!isMindDataset) { | |||||
| // Common Sampler | |||||
| std::shared_ptr<SamplerRT> sampler; | |||||
| auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create"); | |||||
| sampler = create().cast<std::shared_ptr<SamplerRT>>(); | |||||
| sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler)); | |||||
| } else { | |||||
| // Mindrecord Sampler | |||||
| std::shared_ptr<mindrecord::ShardOperator> sampler; | |||||
| auto create = py::reinterpret_borrow<py::object>(py_sampler.value()).attr("create_for_minddataset"); | |||||
| sampler = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||||
| sampler_obj = std::make_shared<PreBuiltSamplerObj>(std::move(sampler)); | |||||
| } | |||||
| return sampler_obj; | |||||
| } else { | |||||
| THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: sampler input is not SamplerRT."); }()); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| // Here we take in a python object that holds a reference to a C++ CacheClient object | |||||
| std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc) { | |||||
| if (cc) { | |||||
| std::shared_ptr<DatasetCache> built_cache; | |||||
| // Wrap the existing CacheClient in a pre-built dataset cache | |||||
| built_cache = std::make_shared<PreBuiltDatasetCache>(std::move(cc.value())); | |||||
| return built_cache; | |||||
| } else { | |||||
| // No cache client was provided, so caching is not enabled; nothing to build. | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
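| // Map the integer shuffle flag passed from python onto ShuffleMode: 0 -> kFalse, 1 -> kFiles, 2 -> kGlobal. | |||||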
| ShuffleMode toShuffleMode(const int32_t shuffle) { | |||||
| if (shuffle == 0) return ShuffleMode::kFalse; | |||||
| if (shuffle == 1) return ShuffleMode::kFiles; | |||||
| if (shuffle == 2) return ShuffleMode::kGlobal; | |||||
| return ShuffleMode(); | |||||
| } | |||||
| std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases) { | |||||
| std::vector<std::shared_ptr<CsvBase>> vector; | |||||
| if (csv_bases) { | |||||
| for (auto base : *csv_bases) { | |||||
| if (py::isinstance<py::int_>(base)) { | |||||
| vector.push_back(std::make_shared<CsvRecord<int>>(CsvType::INT, toInt(base))); | |||||
| } else if (py::isinstance<py::float_>(base)) { | |||||
| vector.push_back(std::make_shared<CsvRecord<float>>(CsvType::FLOAT, toFloat(base))); | |||||
| } else if (py::isinstance<py::str>(base)) { | |||||
| vector.push_back(std::make_shared<CsvRecord<std::string>>(CsvType::STRING, toString(base))); | |||||
| } else { | |||||
| THROW_IF_ERROR([]() { RETURN_STATUS_UNEXPECTED("Error: each default value must be int, float, or string"); }()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return vector; | |||||
| } | |||||
| Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json, | |||||
| std::map<std::string, std::string> *sample_bytes) { | |||||
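| // Walk the padded_sample dict: bytes values are kept in sample_bytes, while None/int/float/str values are converted to json directly. | |||||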
| for (const py::handle &key : padded_sample) { | |||||
| if (py::isinstance<py::bytes>(padded_sample[key])) { | |||||
| (*sample_bytes)[py::str(key).cast<std::string>()] = padded_sample[key].cast<std::string>(); | |||||
| // The bytes content is kept in sample_bytes; an unused placeholder key is still created in the json so that ValidateParams passes | |||||
| (*padded_sample_json)[py::str(key).cast<std::string>()] = nlohmann::json::object(); | |||||
| } else { | |||||
| nlohmann::json obj_json; | |||||
| if (padded_sample[key].is_none()) { | |||||
| obj_json = nullptr; | |||||
| } else if (py::isinstance<py::int_>(padded_sample[key])) { | |||||
| obj_json = padded_sample[key].cast<int64_t>(); | |||||
| } else if (py::isinstance<py::float_>(padded_sample[key])) { | |||||
| obj_json = padded_sample[key].cast<double>(); | |||||
| } else if (py::isinstance<py::str>(padded_sample[key])) { | |||||
| obj_json = padded_sample[key].cast<std::string>(); // also catch py::bytes | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Python object convert to json failed: " << py::cast<std::string>(padded_sample[key]); | |||||
| RETURN_STATUS_SYNTAX_ERROR("Python object convert to json failed"); | |||||
| } | |||||
| (*padded_sample_json)[py::str(key).cast<std::string>()] = obj_json; | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info) { | |||||
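| // Each dict entry maps a column name to a (pad_shape, pad_value) tuple; string and numeric pad values become scalar Tensors, and a None entry is stored with an empty shape and a null pad value. | |||||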
| for (auto p : value) { | |||||
| if (!p.second.is_none()) { | |||||
| auto tp = py::reinterpret_borrow<py::tuple>(p.second); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list, int), (list, float) or (list, str)"); | |||||
| TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); | |||||
| std::shared_ptr<Tensor> pad_val = nullptr; | |||||
| if (py::isinstance<py::str>(tp[1])) { | |||||
| std::string pad_val_string = tp[1].is_none() ? "" : toString(tp[1]); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| Tensor::CreateFromVector(std::vector<std::string>{pad_val_string}, TensorShape::CreateScalar(), &pad_val), | |||||
| "Cannot create pad_value Tensor"); | |||||
| } else { | |||||
| float pad_val_float = tp[1].is_none() ? 0 : toFloat(tp[1]); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| Tensor::CreateEmpty(TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32), &pad_val), | |||||
| "Cannot create pad_value Tensor"); | |||||
| pad_val->SetItemAt<float>({}, pad_val_float); | |||||
| } | |||||
| (void)pad_info->insert({toString(p.first), {shape, pad_val}}); | |||||
| } else { // tuple is None | |||||
| (void)pad_info->insert({toString(p.first), {TensorShape({}), nullptr}}); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type) { | |||||
| std::shared_ptr<TensorOp> py_func; | |||||
| if (!func.is_none()) { | |||||
| py::function py_function = func.cast<py::function>(); | |||||
| py_func = std::make_shared<PyFuncOp>(py_function, data_type); | |||||
| } else { | |||||
| py_func = nullptr; | |||||
| } | |||||
| return py_func; | |||||
| } | |||||
| py::list shapesToListOfShape(std::vector<TensorShape> shapes) { | |||||
| py::list shape_list; | |||||
| for (const auto &shape : shapes) { | |||||
| shape_list.append(shape.AsVector()); | |||||
| } | |||||
| return shape_list; | |||||
| } | |||||
| py::list typesToListOfType(std::vector<DataType> types) { | |||||
| py::list type_list; | |||||
| for (const auto &type : types) { | |||||
| type_list.append(type.AsNumpyType()); | |||||
| } | |||||
| return type_list; | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,85 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/stl.h" | |||||
| #include "pybind11/stl_bind.h" | |||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/include/samplers.h" | |||||
| #include "minddata/dataset/include/transforms.h" | |||||
| #include "minddata/dataset/api/python/pybind_register.h" | |||||
| #include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h" | |||||
| #include "minddata/dataset/engine/ir/datasetops/source/csv_node.h" | |||||
| #include "minddata/dataset/kernels/py_func_op.h" | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| float toFloat(const py::handle &handle); | |||||
| int toInt(const py::handle &handle); | |||||
| int64_t toInt64(const py::handle &handle); | |||||
| bool toBool(const py::handle &handle); | |||||
| std::string toString(const py::handle &handle); | |||||
| std::set<std::string> toStringSet(const std::optional<py::list> list); | |||||
| std::map<std::string, int32_t> toStringMap(const std::optional<py::dict> dict); | |||||
| std::vector<std::string> toStringVector(const std::optional<py::list> list); | |||||
| std::pair<int64_t, int64_t> toIntPair(const std::optional<py::tuple> tuple); | |||||
| std::vector<std::pair<int, int>> toPairVector(const py::list list); | |||||
| std::vector<std::shared_ptr<TensorOperation>> toTensorOperations(std::optional<py::list> operations); | |||||
| std::vector<std::shared_ptr<DatasetNode>> toDatasetNode(std::shared_ptr<DatasetNode> self, py::list datasets); | |||||
| std::shared_ptr<SamplerObj> toSamplerObj(std::optional<py::handle> py_sampler, bool isMindDataset = false); | |||||
| std::shared_ptr<DatasetCache> toDatasetCache(std::optional<std::shared_ptr<CacheClient>> cc); | |||||
| ShuffleMode toShuffleMode(const int32_t shuffle); | |||||
| std::vector<std::shared_ptr<CsvBase>> toCSVBase(py::list csv_bases); | |||||
| std::shared_ptr<TensorOp> toPyFuncOp(py::object func, DataType::Type data_type); | |||||
| Status ToJson(const py::handle &padded_sample, nlohmann::json *padded_sample_json, | |||||
| std::map<std::string, std::string> *sample_bytes); | |||||
| Status toPadInfo(py::dict value, std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> *pad_info); | |||||
| py::list shapesToListOfShape(std::vector<TensorShape> shapes); | |||||
| py::list typesToListOfType(std::vector<DataType> types); | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_ | |||||
| @@ -190,6 +190,23 @@ std::shared_ptr<SamplerRT> PKSamplerObj::Build() { | |||||
| return sampler; | return sampler; | ||||
| } | } | ||||
| #ifndef ENABLE_ANDROID | |||||
| // PreBuiltOperation | |||||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler) | |||||
| : sp_(std::move(sampler)), sp_minddataset_(nullptr) {} | |||||
| PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler) | |||||
| : sp_(nullptr), sp_minddataset_(std::move(sampler)) {} | |||||
| #endif | |||||
| bool PreBuiltSamplerObj::ValidateParams() { return true; } | |||||
| std::shared_ptr<SamplerRT> PreBuiltSamplerObj::Build() { return sp_; } | |||||
| #ifndef ENABLE_ANDROID | |||||
| std::shared_ptr<mindrecord::ShardOperator> PreBuiltSamplerObj::BuildForMindDataset() { return sp_minddataset_; } | |||||
| #endif | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() { | ||||
| // runtime mindrecord sampler object | // runtime mindrecord sampler object | ||||
| @@ -222,6 +222,13 @@ Status OneHotOperation::ValidateParams() { | |||||
| std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } | std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } | ||||
| // PreBuiltOperation | |||||
| PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {} | |||||
| Status PreBuiltOperation::ValidateParams() { return Status::OK(); } | |||||
| std::shared_ptr<TensorOp> PreBuiltOperation::Build() { return op_; } | |||||
| // RandomApplyOperation | // RandomApplyOperation | ||||
| RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob) | RandomApplyOperation::RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob) | ||||
| : transforms_(transforms), prob_(prob) {} | : transforms_(transforms), prob_(prob) {} | ||||
| @@ -18,6 +18,7 @@ set(SRC_FILES_LIST | |||||
| dataset_iterator.cc | dataset_iterator.cc | ||||
| tree_adapter.cc | tree_adapter.cc | ||||
| runtime_context.cc | runtime_context.cc | ||||
| python_runtime_context.cc | |||||
| consumers/tree_consumer.cc | consumers/tree_consumer.cc | ||||
| ) | ) | ||||
| if (ENABLE_PYTHON) | if (ENABLE_PYTHON) | ||||
| @@ -32,15 +32,37 @@ Status PythonIteratorConsumer::GetNextAsList(py::list *out) { | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) { | Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) { | ||||
| std::unordered_map<std::string, TensorPtr> row; | |||||
| std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> vec; | |||||
| Status s; | |||||
| { | { | ||||
| py::gil_scoped_release gil_release; | py::gil_scoped_release gil_release; | ||||
| RETURN_IF_NOT_OK(GetNextAsMap(&row)); | |||||
| s = GetNextAsOrderedPair(&vec); | |||||
| } | } | ||||
| for (auto el : row) { | |||||
| (*out)[common::SafeCStr(el.first)] = el.second; | |||||
| RETURN_IF_NOT_OK(s); | |||||
| // Generate Python dict, python dict maintains its insertion order | |||||
| for (const auto &pair : vec) { | |||||
| (*out)[common::SafeCStr(pair.first)] = pair.second; | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status PythonBuildVocabConsumer::Start() { | |||||
| py::gil_scoped_release gil_release; | |||||
| return BuildVocabConsumer::Start(); | |||||
| } | |||||
| Status PythonSaveToDisk::Save() { | |||||
| py::gil_scoped_release gil_release; | |||||
| return SaveToDisk::Save(); | |||||
| } | |||||
| PythonSaveToDisk::PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType) | |||||
| : SaveToDisk(datasetPath, numFiles, datasetType) {} | |||||
| Status PythonTreeGetters::GetRow(TensorRow *r) { | |||||
| py::gil_scoped_release gil_release; | |||||
| return TreeGetters::GetRow(r); | |||||
| } | |||||
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
| @@ -44,5 +44,21 @@ class PythonIteratorConsumer : public IteratorConsumer { | |||||
| /// \return Status error code | /// \return Status error code | ||||
| Status GetNextAsDict(py::dict *out); | Status GetNextAsDict(py::dict *out); | ||||
| }; | }; | ||||
| class PythonBuildVocabConsumer : public BuildVocabConsumer { | |||||
| public: | |||||
| Status Start() override; | |||||
| }; | |||||
| class PythonSaveToDisk : public SaveToDisk { | |||||
| public: | |||||
| PythonSaveToDisk(const std::string &datasetPath, int32_t numFiles, const std::string &datasetType); | |||||
| Status Save() override; | |||||
| }; | |||||
| class PythonTreeGetters : public TreeGetters { | |||||
| public: | |||||
| Status GetRow(TensorRow *r) override; | |||||
| }; | |||||
| } // namespace mindspore::dataset | } // namespace mindspore::dataset | ||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_ | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/consumers/tree_consumer.h" | #include "minddata/dataset/engine/consumers/tree_consumer.h" | ||||
| #include "minddata/dataset/engine/tree_adapter.h" | #include "minddata/dataset/engine/tree_adapter.h" | ||||
| #include "minddata/dataset/engine/opt/pre/getter_pass.h" | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| #include "minddata/mindrecord/include/shard_header.h" | #include "minddata/mindrecord/include/shard_header.h" | ||||
| @@ -35,7 +36,7 @@ namespace mindspore::dataset { | |||||
| TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } | ||||
| Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); } | Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d)); } | ||||
| Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); } | |||||
| Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->ServiceStop(); } | |||||
| // IteratorConsumer | // IteratorConsumer | ||||
| Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) { | Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) { | ||||
| @@ -73,6 +74,38 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status IteratorConsumer::GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(vec != nullptr && vec->empty(), "vec is null or non-empty."); | |||||
| TensorRow curr_row; | |||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&curr_row)); | |||||
| RETURN_OK_IF_TRUE(curr_row.empty()); | |||||
| size_t num_cols = curr_row.size();  // curr_row is non-empty here, so num_cols > 0. | |||||
| // order the column names according to their ids | |||||
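| // e.g. a column name/id map of {"label": 1, "image": 0} yields | |||||
| // column_order_ = [("image", 0), ("label", 1)], so pairs are emitted in column-id order. | |||||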
| if (column_order_.empty()) { | |||||
| const int32_t invalid_col_id = -1; | |||||
| column_order_.resize(num_cols, {std::string(), invalid_col_id}); | |||||
| for (const auto &itr : tree_adapter_->GetColumnNameMap()) { | |||||
| int32_t ind = itr.second; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(ind < num_cols && ind >= 0, "column id out of bounds."); | |||||
| column_order_[ind] = std::make_pair(itr.first, ind); | |||||
| } | |||||
| // error check: make sure the ids in col_name_id_map are continuous and start from 0 | |||||
| for (const auto &col : column_order_) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(col.second != invalid_col_id, "column ids are not continuous."); | |||||
| } | |||||
| } | |||||
| vec->reserve(num_cols); | |||||
| std::transform(column_order_.begin(), column_order_.end(), std::back_inserter(*vec), | |||||
| [curr_row](const auto &col) { return std::make_pair(col.first, curr_row[col.second]); }); | |||||
| return Status::OK(); | |||||
| } | |||||
| // ToDevice | // ToDevice | ||||
| Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); } | Status ToDevice::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), num_epochs_); } | ||||
| @@ -81,7 +114,6 @@ Status ToDevice::Send() { | |||||
| RETURN_IF_NOT_OK(tree_adapter_->Launch()); | RETURN_IF_NOT_OK(tree_adapter_->Launch()); | ||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | ||||
| RETURN_IF_NOT_OK(root->GetNextBuffer(&db)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -101,9 +133,36 @@ Status ToDevice::Stop() { | |||||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); | DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp"); | CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp"); | ||||
| op->StopSend(); | op->StopSend(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status ToDevice::GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes) { | |||||
| // tree_.root() must be DeviceQueueOp | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "GetDataInfo only supported by DeviceQueueOp"); | |||||
| DATA_INFO data_info; | |||||
| RETURN_IF_NOT_OK(op->GetDataInfo(&data_info)); | |||||
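| // data_info holds (type, shape) pairs; split them into the two output vectors. | |||||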
| for (auto el : data_info) { | |||||
| types->push_back(el.first); | |||||
| shapes->push_back(el.second); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status ToDevice::Terminate() { | |||||
| #ifdef ENABLE_TDTQUE | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||||
| DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(root.get()); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp"); | |||||
| op->StopWaiting(); | |||||
| #endif | |||||
| return TreeConsumer::Terminate(); | |||||
| } | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| // SaveToDisk | // SaveToDisk | ||||
| Status SaveToDisk::ValidateParams() { | Status SaveToDisk::ValidateParams() { | ||||
| @@ -282,50 +341,50 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row, | |||||
| if (column_type == DataType::DE_INT8) { | if (column_type == DataType::DE_INT8) { | ||||
| std::unique_ptr<int32_t> data; | std::unique_ptr<int32_t> data; | ||||
| std::unique_ptr<int8_t> dummy; | std::unique_ptr<int8_t> dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_INT16) { | } else if (column_type == DataType::DE_INT16) { | ||||
| std::unique_ptr<int32_t> data; | std::unique_ptr<int32_t> data; | ||||
| std::unique_ptr<int16_t> dummy; | std::unique_ptr<int16_t> dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_UINT16) { | } else if (column_type == DataType::DE_UINT16) { | ||||
| std::unique_ptr<int32_t> data; | std::unique_ptr<int32_t> data; | ||||
| std::unique_ptr<uint16_t> dummy; | std::unique_ptr<uint16_t> dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_UINT8) { | } else if (column_type == DataType::DE_UINT8) { | ||||
| std::unique_ptr<uint8_t> data, dummy; | std::unique_ptr<uint8_t> data, dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_INT32) { | } else if (column_type == DataType::DE_INT32) { | ||||
| std::unique_ptr<int32_t> data, dummy; | std::unique_ptr<int32_t> data, dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_UINT32) { | } else if (column_type == DataType::DE_UINT32) { | ||||
| std::unique_ptr<int64_t> data; | std::unique_ptr<int64_t> data; | ||||
| std::unique_ptr<uint32_t> dummy; | std::unique_ptr<uint32_t> dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy, true); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_INT64) { | } else if (column_type == DataType::DE_INT64) { | ||||
| std::unique_ptr<int64_t> data, dummy; | std::unique_ptr<int64_t> data, dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_FLOAT32) { | } else if (column_type == DataType::DE_FLOAT32) { | ||||
| std::unique_ptr<float> data, dummy; | std::unique_ptr<float> data, dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_FLOAT64) { | } else if (column_type == DataType::DE_FLOAT64) { | ||||
| std::unique_ptr<double> data, dummy; | std::unique_ptr<double> data, dummy; | ||||
| s = TransfromTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| s = TransformTensor(tensor->GetBuffer(), tensor->shape(), tensor->Size(), &data, &data_ptr, &dummy); | |||||
| RETURN_IF_NOT_OK(s); | RETURN_IF_NOT_OK(s); | ||||
| if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | if (data != nullptr) (*row_raw_data)[column_name] = std::move(*data); | ||||
| } else if (column_type == DataType::DE_STRING) { | } else if (column_type == DataType::DE_STRING) { | ||||
| @@ -346,7 +405,7 @@ Status SaveToDisk::FetchDataFromTensorRow(const TensorRow &row, | |||||
| } | } | ||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||||
| Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | ||||
| std::unique_ptr<S> *s, bool need_convert) { | std::unique_ptr<S> *s, bool need_convert) { | ||||
| if (nullptr == src) { | if (nullptr == src) { | ||||
| @@ -379,47 +438,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape & | |||||
| } | } | ||||
| #endif | #endif | ||||
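The need_convert = true calls above widen narrow integer columns (DE_INT8, DE_INT16, DE_UINT16, and DE_UINT32 to a larger signed type) before the values are stored in row_raw_data. A standalone sketch of that element-wise widening using only the standard library (WidenInt8 is an illustrative helper, not from this patch):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Widen a raw int8 buffer into int32 values, element by element, the same
// idea as the need_convert = true calls for DE_INT8 / DE_INT16 / DE_UINT16.
std::vector<int32_t> WidenInt8(const unsigned char *src, int64_t num_of_elements) {
  const auto *typed = reinterpret_cast<const int8_t *>(src);
  std::vector<int32_t> out(typed, typed + num_of_elements);  // implicit widening copy
  return out;
}

int main() {
  int8_t raw[] = {-3, 0, 7};
  auto widened = WidenInt8(reinterpret_cast<unsigned char *>(raw), 3);
  for (auto v : widened) std::cout << v << ' ';  // -3 0 7
}
```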
| TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(false) { | |||||
| tree_adapter_ = std::make_unique<TreeAdapter>(); | |||||
| } | |||||
| TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); } | |||||
| Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) { | Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) { | ||||
| if (init_flag_) { | |||||
| return Status::OK(); | |||||
| } | |||||
| Status s = tree_adapter_->Compile(std::move(d), 1); | |||||
| if (!s.IsError()) { | |||||
| init_flag_ = true; | |||||
| } | |||||
| return s; | |||||
| } | |||||
| bool TreeGetters::isInitialized() { return init_flag_; } | |||||
| Status TreeGetters::GetRow(TensorRow *row) { | |||||
| if (row_flag_ == false) { | |||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(row)); | |||||
| row_flag_ = true; | |||||
| } | |||||
| root_ = std::move(d); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeGetters::GetRow(TensorRow *row) { return tree_adapter_->GetNext(row); } | |||||
| Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | ||||
| if (dataset_size_ == -1) { | if (dataset_size_ == -1) { | ||||
| RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kDatasetSize))); | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||||
| RETURN_UNEXPECTED_IF_NULL(root); | |||||
| RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); | RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size)); | ||||
| dataset_size_ = *dataset_size; | |||||
| if (*dataset_size == -1) { | |||||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||||
| int64_t num_rows = 0; | |||||
| TensorRow row = row_; | |||||
| while (row.size() != 0) { | |||||
| num_rows++; | |||||
| RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row)); | |||||
| if (*dataset_size == -1) { // run through the tree and get everything | |||||
| TensorRow row; | |||||
| RETURN_IF_NOT_OK(GetRow(&row)); | |||||
| int64_t row_cnt = 0; | |||||
| while (!row.empty()) { | |||||
| ++row_cnt; | |||||
| RETURN_IF_NOT_OK(GetRow(&row)); | |||||
| } | } | ||||
| dataset_size_ = num_rows; | |||||
| *dataset_size = row_cnt; | |||||
| } | } | ||||
| dataset_size_ = *dataset_size; // save the previous result | |||||
| } | } | ||||
| *dataset_size = dataset_size_; | *dataset_size = dataset_size_; | ||||
| @@ -427,68 +471,88 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { | |||||
| } | } | ||||
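The reworked GetDatasetSize falls back to pulling rows from the compiled tree and counting them when the root operator reports -1, then caches the answer so later calls are free. A standalone sketch of that count-once-and-cache behaviour (SizeGetter and Row are illustrative stand-ins, not types from this patch):

```cpp
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Illustrative stand-in for "pull the next row"; an empty row means end of data.
using Row = std::vector<int>;

class SizeGetter {
 public:
  explicit SizeGetter(std::vector<Row> rows) : rows_(std::move(rows)) {}

  // Count rows only once and cache the answer, like dataset_size_ above.
  int64_t GetDatasetSize() {
    if (!cached_size_) {
      int64_t row_cnt = 0;
      Row row = GetNext();
      while (!row.empty()) {   // run through the tree and count everything
        ++row_cnt;
        row = GetNext();
      }
      cached_size_ = row_cnt;  // save the result for later calls
    }
    return *cached_size_;
  }

 private:
  Row GetNext() { return pos_ < rows_.size() ? rows_[pos_++] : Row{}; }
  std::vector<Row> rows_;
  size_t pos_ = 0;
  std::optional<int64_t> cached_size_;
};

int main() {
  SizeGetter getter({{1}, {2}, {3}});
  std::cout << getter.GetDatasetSize() << '\n';  // 3 (counted)
  std::cout << getter.GetDatasetSize() << '\n';  // 3 (cached)
}
```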
| Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) { | Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) { | ||||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||||
| for (auto ts : row_) { | |||||
| DataType dt = ts->type(); | |||||
| types->push_back(dt); | |||||
| } | |||||
| RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType))); | |||||
| if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_)); | |||||
| std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types), | |||||
| [](const TensorPtr &t) { return t->type(); }); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) { | Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) { | ||||
| RETURN_IF_NOT_OK(GetRow(&row_)); | |||||
| for (auto ts : row_) { | |||||
| TensorShape t = ts->shape(); | |||||
| shapes->push_back(t); | |||||
| } | |||||
| RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType))); | |||||
| if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_)); | |||||
| std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes), | |||||
| [](const TensorPtr &t) { return t->shape(); }); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
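Both getters now lazily fetch a single first_row_ and project one attribute per tensor with std::transform instead of hand-written loops. A standalone sketch of the projection, using an illustrative FakeTensor in place of the real Tensor type:

```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

struct FakeTensor {      // illustrative stand-in for a Tensor
  std::string type;
  std::vector<int> shape;
};

int main() {
  std::vector<FakeTensor> first_row = {{"int32", {2, 3}}, {"float32", {4}}};

  // Project one attribute per tensor, as GetOutputTypes / GetOutputShapes do.
  std::vector<std::string> types;
  std::transform(first_row.begin(), first_row.end(), std::back_inserter(types),
                 [](const FakeTensor &t) { return t.type; });

  for (const auto &t : types) std::cout << t << ' ';  // int32 float32
}
```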
| Status TreeGetters::GetBatchSize(int64_t *batch_size) { | Status TreeGetters::GetBatchSize(int64_t *batch_size) { | ||||
| RETURN_IF_NOT_OK(InternalInit()); | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||||
| RETURN_UNEXPECTED_IF_NULL(root); | |||||
| *batch_size = root->GetTreeBatchSize(); | *batch_size = root->GetTreeBatchSize(); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size."); | CHECK_FAIL_RETURN_UNEXPECTED(*batch_size != -1, "Error in finding the batch size."); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeGetters::GetRepeatCount(int64_t *repeat_count) { | Status TreeGetters::GetRepeatCount(int64_t *repeat_count) { | ||||
| RETURN_IF_NOT_OK(InternalInit()); | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||||
| RETURN_UNEXPECTED_IF_NULL(root); | |||||
| *repeat_count = root->GetTreeRepeatCount(); | *repeat_count = root->GetTreeRepeatCount(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeGetters::GetNumClasses(int64_t *num_classes) { | Status TreeGetters::GetNumClasses(int64_t *num_classes) { | ||||
| RETURN_IF_NOT_OK(InternalInit()); | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||||
| RETURN_UNEXPECTED_IF_NULL(root); | |||||
| RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); | RETURN_IF_NOT_OK(root->GetNumClasses(num_classes)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeGetters::GetColumnNames(std::vector<std::string> *output) { | Status TreeGetters::GetColumnNames(std::vector<std::string> *output) { | ||||
| RETURN_IF_NOT_OK(InternalInit()); | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| RETURN_UNEXPECTED_IF_NULL(root); | |||||
| std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map(); | std::unordered_map<std::string, int32_t> column_name_id_map = root->column_name_id_map(); | ||||
| if (column_name_id_map.empty()) RETURN_STATUS_UNEXPECTED("GetColumnNames: column_name_id map was empty."); | |||||
| std::vector<std::pair<std::string, int32_t>> column_name_id_vector(column_name_id_map.begin(), | |||||
| column_name_id_map.end()); | |||||
| std::sort(column_name_id_vector.begin(), column_name_id_vector.end(), | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map.empty(), "GetColumnNames: column_name_id map is empty."); | |||||
| std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(), column_name_id_map.end()); | |||||
| std::sort(col_name_id_vec.begin(), col_name_id_vec.end(), | |||||
| [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) { | [](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) { | ||||
| return a.second < b.second; | return a.second < b.second; | ||||
| }); | }); | ||||
| for (auto item : column_name_id_vector) { | |||||
| (*output).push_back(item.first); | |||||
| } | |||||
| std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(*output), | |||||
| [](const std::pair<std::string, int32_t> &p) { return p.first; }); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
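GetColumnNames recovers the user-visible column order by copying the unordered name-to-id map into a vector, sorting by id, and then stripping the ids. The same steps in a standalone sketch (the column names here are made up for illustration):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Unordered name -> column id map, as produced by column_name_id_map().
  std::unordered_map<std::string, int32_t> column_name_id_map = {
      {"label", 1}, {"image", 0}, {"bbox", 2}};

  // Copy to a vector, sort by id, then strip the ids, as GetColumnNames does.
  std::vector<std::pair<std::string, int32_t>> col_name_id_vec(column_name_id_map.begin(),
                                                               column_name_id_map.end());
  std::sort(col_name_id_vec.begin(), col_name_id_vec.end(),
            [](const auto &a, const auto &b) { return a.second < b.second; });

  std::vector<std::string> output;
  std::transform(col_name_id_vec.begin(), col_name_id_vec.end(), std::back_inserter(output),
                 [](const auto &p) { return p.first; });

  for (const auto &name : output) std::cout << name << ' ';  // image label bbox
}
```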
| Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) { | Status TreeGetters::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) { | ||||
| RETURN_IF_NOT_OK(InternalInit()); | |||||
| std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot()); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr."); | |||||
| RETURN_UNEXPECTED_IF_NULL(root); | |||||
| RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing)); | RETURN_IF_NOT_OK(root->GetClassIndexing(output_class_indexing)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status TreeGetters::InternalInit(int8_t type) { | |||||
| if (init_flag_) return Status::OK(); | |||||
| tree_adapter_->SetPrePassOverride([&type](OptPass pre) { | |||||
| pre.push_back(std::make_unique<GetterPass>(static_cast<GetterPass::GetterType>(type))); | |||||
| return pre; | |||||
| }); | |||||
| Status s = tree_adapter_->Compile(std::move(root_), 1); | |||||
| if (!s.IsError()) init_flag_ = true; | |||||
| return s; | |||||
| } | |||||
| Status TreeGetters::InternalInit() { | |||||
| if (init_flag_) return Status::OK(); | |||||
| Status s = tree_adapter_->Compile(std::move(root_), 1); | |||||
| if (!s.IsError()) init_flag_ = true; | |||||
| return s; | |||||
| } | |||||
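The typed InternalInit(int8_t) overload injects a GetterPass into the optimizer's pre-pass list through SetPrePassOverride before compiling, so only the work needed for that getter runs. A standalone sketch of the override-a-pass-list idea (MiniCompiler and the pass names are illustrative, not from this patch):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Illustrative sketch of the "pre-pass override": a compiler keeps a default
// pass list, and a caller may install a callback that rewrites it (here,
// appending a getter-specific pass) just before compilation runs.
using OptPass = std::vector<std::string>;

class MiniCompiler {
 public:
  void SetPrePassOverride(std::function<OptPass(OptPass)> fn) { override_ = std::move(fn); }

  void Compile() {
    OptPass passes = {"input-validation", "cache-transform"};
    if (override_) passes = override_(std::move(passes));
    for (const auto &p : passes) std::cout << "running pass: " << p << '\n';
  }

 private:
  std::function<OptPass(OptPass)> override_;
};

int main() {
  MiniCompiler compiler;
  compiler.SetPrePassOverride([](OptPass pre) {
    pre.push_back("getter-pass(kDatasetSize)");  // extra pass injected by the getter
    return pre;
  });
  compiler.Compile();
}
```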
| Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); } | Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); } | ||||
| Status BuildVocabConsumer::Start() { | Status BuildVocabConsumer::Start() { | ||||
| @@ -41,7 +41,7 @@ class TreeConsumer { | |||||
| /// \return Status error code. | /// \return Status error code. | ||||
| virtual Status Init(std::shared_ptr<DatasetNode> d); | virtual Status Init(std::shared_ptr<DatasetNode> d); | ||||
| Status Terminate(); | |||||
| virtual Status Terminate(); | |||||
| protected: | protected: | ||||
| /// The class owns the tree_adapter that handles execution tree operations. | /// The class owns the tree_adapter that handles execution tree operations. | ||||
| @@ -72,6 +72,11 @@ class IteratorConsumer : public TreeConsumer { | |||||
| /// \return Status error code | /// \return Status error code | ||||
| Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out); | Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out); | ||||
| /// Returns the next row as a vector of (column name, Tensor) pairs, ordered by column id | |||||
| /// \param[out] vec vector of string-Tensor pairs | |||||
| /// \return Status error code | |||||
| Status GetNextAsOrderedPair(std::vector<std::pair<std::string, std::shared_ptr<Tensor>>> *vec); | |||||
| protected: | protected: | ||||
| /// Method to return the name of the consumer | /// Method to return the name of the consumer | ||||
| /// \return string | /// \return string | ||||
| @@ -79,6 +84,7 @@ class IteratorConsumer : public TreeConsumer { | |||||
| private: | private: | ||||
| int32_t num_epochs_; | int32_t num_epochs_; | ||||
| std::vector<std::pair<std::string, int32_t>> column_order_; // key: column name, val: column id | |||||
| }; | }; | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -101,7 +107,7 @@ class SaveToDisk : public TreeConsumer { | |||||
| /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows | /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows | ||||
| /// would be written to disk) | /// would be written to disk) | ||||
| /// \return Status error code | /// \return Status error code | ||||
| Status Save(); | |||||
| virtual Status Save(); | |||||
| protected: | protected: | ||||
| /// Method to return the name of the consumer | /// Method to return the name of the consumer | ||||
| @@ -110,7 +116,7 @@ class SaveToDisk : public TreeConsumer { | |||||
| private: | private: | ||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| Status TransfromTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||||
| Status TransformTensor(const unsigned char *src, const TensorShape &shape, const int64_t num_of_elements, | |||||
| std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | std::unique_ptr<T> *data, std::unique_ptr<std::vector<uint8_t>> *data_ptr, | ||||
| std::unique_ptr<S> *s, bool need_convert = false); | std::unique_ptr<S> *s, bool need_convert = false); | ||||
| @@ -131,24 +137,29 @@ class SaveToDisk : public TreeConsumer { | |||||
| /// Consumer that iterates over the dataset and send it to a device | /// Consumer that iterates over the dataset and send it to a device | ||||
| class ToDevice : public TreeConsumer { | class ToDevice : public TreeConsumer { | ||||
| public: | public: | ||||
| explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1) | |||||
| : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} | |||||
| explicit ToDevice(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {} | |||||
| ~ToDevice() = default; | ~ToDevice() = default; | ||||
| Status Init(std::shared_ptr<DatasetNode> d) override; | Status Init(std::shared_ptr<DatasetNode> d) override; | ||||
| Status Terminate() override; | |||||
| /// Send the data to device | /// Send the data to device | ||||
| /// \return Status error code | /// \return Status error code | ||||
| Status Send(); | |||||
| virtual Status Send(); | |||||
| /// Stop to send data to device | /// Stop to send data to device | ||||
| /// \return Status error code | /// \return Status error code | ||||
| Status Stop(); | |||||
| virtual Status Stop(); | |||||
| /// Continue to send data to device | /// Continue to send data to device | ||||
| /// \return Status error code | /// \return Status error code | ||||
| Status Continue(); | |||||
| virtual Status Continue(); | |||||
| /// Get data info from TDT | |||||
| /// \return Status error code | |||||
| virtual Status GetDataInfo(std::vector<DataType> *types, std::vector<TensorShape> *shapes); | |||||
| protected: | protected: | ||||
| /// Method to return the name of the consumer | /// Method to return the name of the consumer | ||||
| @@ -156,8 +167,6 @@ class ToDevice : public TreeConsumer { | |||||
| std::string Name() override { return "ToDevice"; } | std::string Name() override { return "ToDevice"; } | ||||
| private: | private: | ||||
| std::string device_type_; | |||||
| bool send_epoch_end_; | |||||
| int32_t num_epochs_; | int32_t num_epochs_; | ||||
| }; | }; | ||||
| @@ -167,6 +176,7 @@ class TreeGetters : public TreeConsumer { | |||||
| TreeGetters(); | TreeGetters(); | ||||
| ~TreeGetters() = default; | ~TreeGetters() = default; | ||||
| Status Init(std::shared_ptr<DatasetNode> d) override; | Status Init(std::shared_ptr<DatasetNode> d) override; | ||||
| Status GetDatasetSize(int64_t *size); | Status GetDatasetSize(int64_t *size); | ||||
| Status GetOutputTypes(std::vector<DataType> *types); | Status GetOutputTypes(std::vector<DataType> *types); | ||||
| Status GetOutputShapes(std::vector<TensorShape> *shapes); | Status GetOutputShapes(std::vector<TensorShape> *shapes); | ||||
| @@ -175,15 +185,17 @@ class TreeGetters : public TreeConsumer { | |||||
| Status GetNumClasses(int64_t *num_classes); | Status GetNumClasses(int64_t *num_classes); | ||||
| Status GetColumnNames(std::vector<std::string> *output); | Status GetColumnNames(std::vector<std::string> *output); | ||||
| Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); | Status GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing); | ||||
| bool isInitialized(); | |||||
| std::string Name() override { return "TreeGetters"; } | std::string Name() override { return "TreeGetters"; } | ||||
| Status GetRow(TensorRow *r); | |||||
| virtual Status GetRow(TensorRow *r); | |||||
| private: | private: | ||||
| std::shared_ptr<DatasetNode> root_; | |||||
| int64_t dataset_size_; | int64_t dataset_size_; | ||||
| TensorRow row_; | |||||
| TensorRow first_row_; | |||||
| bool init_flag_; // indicate whether the tree has initialized | bool init_flag_; // indicate whether the tree has initialized | ||||
| bool row_flag_; // indicate whether the first row has been stored in row_ | |||||
| Status InternalInit(int8_t type); | |||||
| Status InternalInit(); | |||||
| }; | }; | ||||
| class BuildVocabConsumer : public TreeConsumer { | class BuildVocabConsumer : public TreeConsumer { | ||||
| @@ -197,7 +209,7 @@ class BuildVocabConsumer : public TreeConsumer { | |||||
| /// Start consuming | /// Start consuming | ||||
| /// \return Status error code | /// \return Status error code | ||||
| Status Start(); | |||||
| virtual Status Start(); | |||||
| protected: | protected: | ||||
| /// Method to return the name of the consumer | /// Method to return the name of the consumer | ||||
| @@ -44,9 +44,9 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) { | |||||
| } | } | ||||
| // Constructor of the ConcatOp. | // Constructor of the ConcatOp. | ||||
| ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, | |||||
| std::vector<std::pair<int, int>> children_flag_and_nums, | |||||
| std::vector<std::pair<int, int>> children_start_end_index) | |||||
| ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler, | |||||
| const std::vector<std::pair<int, int>> &children_flag_and_nums, | |||||
| const std::vector<std::pair<int, int>> &children_start_end_index) | |||||
| : PipelineOp(op_connector_size), | : PipelineOp(op_connector_size), | ||||
| children_num_(0), | children_num_(0), | ||||
| sampler_(sampler), | sampler_(sampler), | ||||
| @@ -70,9 +70,9 @@ class ConcatOp : public PipelineOp { | |||||
| // @note The builder class should be used to call it | // @note The builder class should be used to call it | ||||
| // @param op_connector_size - connector size | // @param op_connector_size - connector size | ||||
| explicit ConcatOp(int32_t op_connector_size); | explicit ConcatOp(int32_t op_connector_size); | ||||
| explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler, | |||||
| std::vector<std::pair<int, int>> children_flag_and_nums, | |||||
| std::vector<std::pair<int, int>> children_start_end_index); | |||||
| ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler, | |||||
| const std::vector<std::pair<int, int>> &children_flag_and_nums, | |||||
| const std::vector<std::pair<int, int>> &children_start_end_index); | |||||
| // Destructor | // Destructor | ||||
| ~ConcatOp() = default; | ~ConcatOp() = default; | ||||
| @@ -346,6 +346,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// \return Name of the current Op | /// \return Name of the current Op | ||||
| virtual std::string Name() const = 0; | virtual std::string Name() const = 0; | ||||
| /// Op name and ID getter | |||||
| /// \return Name and ID of the current Op | |||||
| std::string NameWithID() const { return Name() + "(ID:" + std::to_string(id()) + ")"; } | |||||
| /// Execution Tree getter | /// Execution Tree getter | ||||
| /// \return Pointer to the ExecutionTree the current op belongs to, no ownership | /// \return Pointer to the ExecutionTree the current op belongs to, no ownership | ||||
| ExecutionTree *Tree() { return tree_; } | ExecutionTree *Tree() { return tree_; } | ||||
| @@ -205,7 +205,6 @@ Status DeviceQueueOp::SendDataToAscend() { | |||||
| } | } | ||||
| tree_->SetFinished(); | tree_->SetFinished(); | ||||
| MS_LOG(INFO) << "Device queue total batch is " << send_batch; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -39,10 +39,10 @@ using mindspore::device::GpuBufferMgr; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>; | using DATA_INFO = std::vector<std::pair<DataType, TensorShape>>; | ||||
| using DATA_INFO_QUEUE = Queue<DATA_INFO>; | using DATA_INFO_QUEUE = Queue<DATA_INFO>; | ||||
| const int kDataInfoQueueCapacity = 128; | const int kDataInfoQueueCapacity = 128; | ||||
| class DeviceQueueOp : public PipelineOp { | class DeviceQueueOp : public PipelineOp { | ||||
| public: | public: | ||||
| static const uint32_t INVALID_HANDLE = 0xffffffffUL; | static const uint32_t INVALID_HANDLE = 0xffffffffUL; | ||||
| @@ -184,7 +184,6 @@ class DeviceQueueOp : public PipelineOp { | |||||
| #ifdef ENABLE_TDTQUE | #ifdef ENABLE_TDTQUE | ||||
| Status SendDataToAscend(); | Status SendDataToAscend(); | ||||
| bool ascend_keep_waiting_; | bool ascend_keep_waiting_; | ||||
| #endif | #endif | ||||
| #ifdef ENABLE_GPUQUE | #ifdef ENABLE_GPUQUE | ||||
| @@ -169,7 +169,7 @@ Status MapOp::operator()() { | |||||
| } | } | ||||
| // The operator class just starts off threads by calling the tree_ function | // The operator class just starts off threads by calling the tree_ function | ||||
| rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); | |||||
| rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1), NameWithID()); | |||||
| // Synchronize with TaskManager | // Synchronize with TaskManager | ||||
| TaskManager::FindMe()->Post(); | TaskManager::FindMe()->Post(); | ||||
| RETURN_IF_NOT_OK(rc); | RETURN_IF_NOT_OK(rc); | ||||
| @@ -704,6 +704,8 @@ Status CocoOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| } | } | ||||
| if (image_ids_.size() == 0) { | if (image_ids_.size() == 0) { | ||||
| RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); | RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows)); | ||||
| } else { | |||||
| num_rows = image_ids_.size(); | |||||
| } | } | ||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | sample_size = sampler_->CalculateNumSamples(num_rows); | ||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| @@ -480,13 +480,13 @@ Status MindRecordOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| *dataset_size = dataset_size_; | *dataset_size = dataset_size_; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| int64_t num_rows = num_rows_, sample_size; | |||||
| int64_t num_rows = num_rows_; | |||||
| if (num_rows_ <= 0) { | if (num_rows_ <= 0) { | ||||
| std::shared_ptr<ShardOperator> op; | |||||
| // The last operator is the parent sampler | |||||
| std::shared_ptr<ShardOperator> op = operators_.back(); | |||||
| RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_)); | RETURN_IF_NOT_OK(CountTotalRows(dataset_file_, load_dataset_, op, &num_rows, num_padded_)); | ||||
| } | } | ||||
| sample_size = operators_[0]->GetNumSamples(num_rows, 0); | |||||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||||
| *dataset_size = num_rows; | |||||
| dataset_size_ = *dataset_size; | dataset_size_ = *dataset_size; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -1067,6 +1067,19 @@ Status TFReaderOp::PrepareNodePostAction() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Get the file list of the specific shard ID | |||||
| Status TFReaderOp::GetShardFileList(std::vector<std::string> *shard_filenames) { | |||||
| if (!shard_filenames->empty()) { | |||||
| RETURN_STATUS_UNEXPECTED("The initial file list must be empty.\n"); | |||||
| } | |||||
| for (int index = 0; index < dataset_files_list_.size(); index++) { | |||||
| if (index % num_devices_ == device_id_) { | |||||
| shard_filenames->push_back(dataset_files_list_.at(index)); | |||||
| } | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
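GetShardFileList assigns file index i to the device with id i % num_devices_, so each shard counts rows only over its own files. A standalone sketch of that selection rule (ShardFiles and the file names are illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Pick the files that belong to one shard: file index i goes to device
// (i % num_devices), the same rule GetShardFileList applies.
std::vector<std::string> ShardFiles(const std::vector<std::string> &all_files,
                                    int32_t num_devices, int32_t device_id) {
  std::vector<std::string> shard_files;
  for (size_t i = 0; i < all_files.size(); ++i) {
    if (static_cast<int32_t>(i) % num_devices == device_id) shard_files.push_back(all_files[i]);
  }
  return shard_files;
}

int main() {
  std::vector<std::string> files = {"a.tfrecord", "b.tfrecord", "c.tfrecord", "d.tfrecord"};
  for (const auto &f : ShardFiles(files, 2, 1)) std::cout << f << ' ';  // b.tfrecord d.tfrecord
}
```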
| // Get Dataset size | // Get Dataset size | ||||
| Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { | Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { | ||||
| if (dataset_size_ > 0) { | if (dataset_size_ > 0) { | ||||
| @@ -1080,7 +1093,9 @@ Status TFReaderOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); | ||||
| num_rows = num_rows_per_shard_; | num_rows = num_rows_per_shard_; | ||||
| } else { | } else { | ||||
| RETURN_IF_NOT_OK(CountTotalRows(&num_rows, dataset_files_list_)); | |||||
| std::vector<std::string> shard_file_list; | |||||
| RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list)); | |||||
| RETURN_IF_NOT_OK(CountTotalRows(&num_rows, shard_file_list)); | |||||
| } | } | ||||
| } | } | ||||
| sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); | sample_size = total_rows_ != 0 ? total_rows_ : data_schema_->num_rows(); | ||||
| @@ -400,6 +400,11 @@ class TFReaderOp : public ParallelOp { | |||||
| // @return - Status | // @return - Status | ||||
| Status ComputeColMap() override; | Status ComputeColMap() override; | ||||
| // Private function for computing the file list belonging to the specific shard ID. In a distributed scenario, | |||||
| // data is divided into shards by row when equal_rows_per_shard is true, and by file otherwise. | |||||
| // @return - Status - the status code returned. | |||||
| Status GetShardFileList(std::vector<std::string> *shard_filenames); | |||||
| int32_t device_id_; | int32_t device_id_; | ||||
| int32_t num_devices_; | int32_t num_devices_; | ||||
| int64_t rows_per_buffer_; | int64_t rows_per_buffer_; | ||||
| @@ -536,6 +536,8 @@ Status VOCOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | RETURN_IF_NOT_OK(op->ParseImageIds()); | ||||
| num_rows = static_cast<int64_t>(op->image_ids_.size()); | num_rows = static_cast<int64_t>(op->image_ids_.size()); | ||||
| } | } | ||||
| } else { | |||||
| num_rows = image_ids_.size(); | |||||
| } | } | ||||
| sample_size = sampler_->CalculateNumSamples(num_rows); | sample_size = sampler_->CalculateNumSamples(num_rows); | ||||
| *dataset_size = sample_size; | *dataset_size = sample_size; | ||||
| @@ -141,8 +141,6 @@ Status ExecutionTree::Launch() { | |||||
| " Expected state: " + std::to_string(static_cast<int>(kDeTStateReady)); | " Expected state: " + std::to_string(static_cast<int>(kDeTStateReady)); | ||||
| RETURN_STATUS_UNEXPECTED(err_msg); | RETURN_STATUS_UNEXPECTED(err_msg); | ||||
| } | } | ||||
| std::ostringstream ss; | |||||
| ss << *this; | |||||
| // Profiling infrastructures need to be initialized before Op launching | // Profiling infrastructures need to be initialized before Op launching | ||||
| if (profiling_manager_->IsProfilingEnable()) { | if (profiling_manager_->IsProfilingEnable()) { | ||||
| @@ -152,6 +150,8 @@ Status ExecutionTree::Launch() { | |||||
| RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor()); | RETURN_IF_NOT_OK(profiling_manager_->LaunchMonitor()); | ||||
| } | } | ||||
| std::ostringstream ss; | |||||
| ss << *this; | |||||
| MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str(); | MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str(); | ||||
| for (auto itr = this->begin(); itr != this->end(); ++itr) { | for (auto itr = this->begin(); itr != this->end(); ++itr) { | ||||
| // An inlined operator is one that has an output connector size of 0, and it does not | // An inlined operator is one that has an output connector size of 0, and it does not | ||||
| @@ -160,7 +160,7 @@ Status ExecutionTree::Launch() { | |||||
| // the launching tree/user thread. Do not exec any thread for an inlined op. | // the launching tree/user thread. Do not exec any thread for an inlined op. | ||||
| itr->state_ = DatasetOp::OpState::kDeOpRunning; | itr->state_ = DatasetOp::OpState::kDeOpRunning; | ||||
| if (!itr->inlined()) { | if (!itr->inlined()) { | ||||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr))); | |||||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask(itr->NameWithID(), std::ref(*itr))); | |||||
| // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp | // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp | ||||
| } | } | ||||
| } | } | ||||
| @@ -189,10 +189,10 @@ ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_ | |||||
| // Given the number of workers, launches the worker entry function for each. Essentially a | // Given the number of workers, launches the worker entry function for each. Essentially a | ||||
| // wrapper for the TaskGroup handling that is stored inside the execution tree. | // wrapper for the TaskGroup handling that is stored inside the execution tree. | ||||
| Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func) { | |||||
| Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name) { | |||||
| // Launch the workers | // Launch the workers | ||||
| for (int32_t i = 0; i < num_workers; ++i) { | for (int32_t i = 0; i < num_workers; ++i) { | ||||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i))); | |||||
| RETURN_IF_NOT_OK(tg_->CreateAsyncTask(name, std::bind(func, i))); | |||||
| } | } | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
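LaunchWorkers now threads a caller-supplied name (typically NameWithID(), producing something like "MapOp(ID:4)") into every async worker task instead of the generic "Parallel Op Worker" label. A standalone sketch of launching index-bound, named workers with std::thread (the name string here is a made-up example):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

// Launch num_workers copies of func, each bound to its worker index and
// tagged with a human-readable name, echoing LaunchWorkers(..., NameWithID()).
void LaunchWorkers(int32_t num_workers, const std::function<void(uint32_t)> &func,
                   const std::string &name) {
  std::vector<std::thread> workers;
  for (int32_t i = 0; i < num_workers; ++i) {
    workers.emplace_back([&func, name, i]() {
      std::cout << (name + " worker " + std::to_string(i) + " started\n");
      func(static_cast<uint32_t>(i));
    });
  }
  for (auto &w : workers) w.join();
}

int main() {
  LaunchWorkers(3, [](uint32_t id) { (void)id; /* per-worker body */ }, "MapOp(ID:4)");
}
```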
| @@ -150,7 +150,7 @@ class ExecutionTree { | |||||
| // @param num_workers - The number of workers to launch | // @param num_workers - The number of workers to launch | ||||
| // @param func - The function entry point that workers will execute | // @param func - The function entry point that workers will execute | ||||
| // @return Status - The error code return | // @return Status - The error code return | ||||
| Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func); | |||||
| Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = ""); | |||||
| // Getter method | // Getter method | ||||
| // @return shared_ptr to the root operator | // @return shared_ptr to the root operator | ||||
| @@ -1,4 +1,5 @@ | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | ||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| add_library(engine-ir-cache OBJECT | add_library(engine-ir-cache OBJECT | ||||
| dataset_cache_impl.cc) | |||||
| pre_built_dataset_cache.cc | |||||
| dataset_cache_impl.cc) | |||||
| @@ -18,8 +18,8 @@ | |||||
| #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | #include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h" | ||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | #include "minddata/dataset/engine/datasetops/cache_op.h" | ||||
| namespace mindspore::dataset { | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// Method to initialize the DatasetCache by creating an instance of a CacheClient | /// Method to initialize the DatasetCache by creating an instance of a CacheClient | ||||
| /// \return Status Error code | /// \return Status Error code | ||||
| Status DatasetCacheImpl::Build() { | Status DatasetCacheImpl::Build() { | ||||
| @@ -40,5 +40,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace mindspore::dataset | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -24,8 +24,8 @@ | |||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | #include "minddata/dataset/engine/datasetops/cache_op.h" | ||||
| #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | ||||
| namespace mindspore::dataset { | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// DatasetCache is the IR of CacheClient | /// DatasetCache is the IR of CacheClient | ||||
| class DatasetCacheImpl : public DatasetCache { | class DatasetCacheImpl : public DatasetCache { | ||||
| public: | public: | ||||
| @@ -67,6 +67,6 @@ class DatasetCacheImpl : public DatasetCache { | |||||
| std::optional<int32_t> num_connections_; | std::optional<int32_t> num_connections_; | ||||
| std::optional<int32_t> prefetch_sz_; | std::optional<int32_t> prefetch_sz_; | ||||
| }; | }; | ||||
| } // namespace mindspore::dataset | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ | #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_IMPL_H_ | ||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h" | |||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// Method to initialize the DatasetCache; the CacheClient is pre-built, so nothing needs to be created here | |||||
| /// \return Status Error code | |||||
| Status PreBuiltDatasetCache::Build() { | |||||
| // we keep a reference to the pre-built runtime object so it can be shared by different pipelines | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); | |||||
| std::shared_ptr<CacheOp> cache_op = nullptr; | |||||
| RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op)); | |||||
| *ds = cache_op; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_ | |||||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include "minddata/dataset/engine/cache/cache_client.h" | |||||
| #include "minddata/dataset/engine/datasetops/cache_op.h" | |||||
| #include "minddata/dataset/engine/ir/cache/dataset_cache.h" | |||||
| namespace mindspore { | |||||
| namespace dataset { | |||||
| /// PreBuiltDatasetCache is the IR wrapper around an already-created CacheClient | |||||
| class PreBuiltDatasetCache : public DatasetCache { | |||||
| public: | |||||
| /// \brief Constructor | |||||
| /// \param cc a pre-built cache client | |||||
| explicit PreBuiltDatasetCache(std::shared_ptr<CacheClient> cc) : cache_client_(std::move(cc)) {} | |||||
| /// Method to initialize the DatasetCache; uses the pre-built CacheClient instead of creating a new one | |||||
| /// \return Status Error code | |||||
| Status Build() override; | |||||
| Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override; | |||||
| Status ValidateParams() override { return Status::OK(); } | |||||
| private: | |||||
| std::shared_ptr<CacheClient> cache_client_; | |||||
| }; | |||||
| } // namespace dataset | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_ | |||||
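Based only on the constructor and methods declared above, a plausible caller inside the repository might wrap an existing CacheClient as sketched below. This is hypothetical and not from this patch; client and num_workers are assumed inputs, and includes beyond the shown header plus namespace qualifiers are omitted for brevity:

```cpp
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"

// Hypothetical usage sketch: reuse an already-created CacheClient.
Status UsePreBuiltCache(std::shared_ptr<CacheClient> client, int32_t num_workers) {
  auto cache = std::make_shared<PreBuiltDatasetCache>(std::move(client));
  RETURN_IF_NOT_OK(cache->ValidateParams());
  RETURN_IF_NOT_OK(cache->Build());          // keeps the shared client, creates nothing new
  std::shared_ptr<DatasetOp> cache_op;
  RETURN_IF_NOT_OK(cache->CreateCacheOp(num_workers, &cache_op));
  return Status::OK();
}
```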
| @@ -31,7 +31,7 @@ namespace dataset { | |||||
| BucketBatchByLengthNode::BucketBatchByLengthNode( | BucketBatchByLengthNode::BucketBatchByLengthNode( | ||||
| std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | ||||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | ||||
| std::function<TensorRow(TensorRow)> element_length_function, | |||||
| std::shared_ptr<TensorOp> element_length_function, | |||||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, | const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary, | ||||
| bool drop_remainder) | bool drop_remainder) | ||||
| : column_names_(column_names), | : column_names_(column_names), | ||||
| @@ -47,16 +47,13 @@ BucketBatchByLengthNode::BucketBatchByLengthNode( | |||||
| std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() { | ||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| std::shared_ptr<TensorOp> c_func; | |||||
| if (element_length_function_ != nullptr) { | |||||
| c_func = std::make_shared<CFuncOp>(element_length_function_); | |||||
| } else { | |||||
| c_func = nullptr; | |||||
| bucket_boundaries_.insert(bucket_boundaries_.begin(), 0); | |||||
| node_ops.push_back(std::make_shared<BucketBatchByLengthOp>( | |||||
| column_names_, bucket_boundaries_, bucket_batch_sizes_, element_length_function_, pad_info_, | |||||
| pad_to_bucket_boundary_, drop_remainder_, connector_que_size_)); | |||||
| if (bucket_boundaries_[0] == 0) { | |||||
| bucket_boundaries_.erase(bucket_boundaries_.begin()); | |||||
| } | } | ||||
| node_ops.push_back(std::make_shared<BucketBatchByLengthOp>(column_names_, bucket_boundaries_, bucket_batch_sizes_, | |||||
| c_func, pad_info_, pad_to_bucket_boundary_, | |||||
| drop_remainder_, connector_que_size_)); | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
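The reworked Build() both prepends a 0 to bucket_boundaries_ and erases a leading 0 again; the exact ordering is hard to read from this flattened diff, but the apparent intent is to hand the op an augmented boundary list while leaving the stored member as the user supplied it, so Build() stays repeatable. A standalone sketch of that prepend-use-restore pattern (BuildWithSentinel is an illustrative helper):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Prepend a sentinel, hand the augmented list to a consumer, then restore the
// original contents so the operation can be repeated safely.
std::vector<int32_t> BuildWithSentinel(std::vector<int32_t> *boundaries) {
  boundaries->insert(boundaries->begin(), 0);       // sentinel lower bound
  std::vector<int32_t> used_by_op = *boundaries;    // what the op receives
  if ((*boundaries)[0] == 0) boundaries->erase(boundaries->begin());  // restore
  return used_by_op;
}

int main() {
  std::vector<int32_t> boundaries = {10, 20, 30};
  auto used = BuildWithSentinel(&boundaries);
  assert((used == std::vector<int32_t>{0, 10, 20, 30}));
  assert((boundaries == std::vector<int32_t>{10, 20, 30}));  // unchanged after Build
}
```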
| @@ -33,7 +33,7 @@ class BucketBatchByLengthNode : public DatasetNode { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names, | ||||
| const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes, | ||||
| std::function<TensorRow(TensorRow)> element_length_function = nullptr, | |||||
| std::shared_ptr<TensorOp> element_length_function = nullptr, | |||||
| const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, | const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {}, | ||||
| bool pad_to_bucket_boundary = false, bool drop_remainder = false); | bool pad_to_bucket_boundary = false, bool drop_remainder = false); | ||||
| @@ -52,7 +52,7 @@ class BucketBatchByLengthNode : public DatasetNode { | |||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||
| std::vector<int32_t> bucket_boundaries_; | std::vector<int32_t> bucket_boundaries_; | ||||
| std::vector<int32_t> bucket_batch_sizes_; | std::vector<int32_t> bucket_batch_sizes_; | ||||
| std::function<TensorRow(TensorRow)> element_length_function_; | |||||
| std::shared_ptr<TensorOp> element_length_function_; | |||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_; | std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_info_; | ||||
| bool pad_to_bucket_boundary_; | bool pad_to_bucket_boundary_; | ||||
| bool drop_remainder_; | bool drop_remainder_; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/concat_op.h" | #include "minddata/dataset/engine/datasetops/concat_op.h" | ||||
| @@ -27,7 +28,15 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Function to build ConcatOp | // Function to build ConcatOp | ||||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; } | |||||
| ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets, | |||||
| const std::shared_ptr<SamplerObj> &sampler, | |||||
| const std::vector<std::pair<int, int>> &children_flag_and_nums, | |||||
| const std::vector<std::pair<int, int>> &children_start_end_index) | |||||
| : sampler_(sampler), | |||||
| children_flag_and_nums_(children_flag_and_nums), | |||||
| children_start_end_index_(children_start_end_index) { | |||||
| this->children = datasets; | |||||
| } | |||||
| Status ConcatNode::ValidateParams() { | Status ConcatNode::ValidateParams() { | ||||
| if (children.size() < 2) { | if (children.size() < 2) { | ||||
| @@ -42,14 +51,25 @@ Status ConcatNode::ValidateParams() { | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | RETURN_STATUS_SYNTAX_ERROR(err_msg); | ||||
| } | } | ||||
| if ((children_flag_and_nums_.empty() && !children_start_end_index_.empty()) || | |||||
| (!children_flag_and_nums_.empty() && children_start_end_index_.empty())) { | |||||
| std::string err_msg = "ConcatNode: children_flag_and_nums and children_start_end_index should be used together"; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> ConcatNode::Build() { | ||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) { | |||||
| node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_)); | |||||
| } else { | |||||
| node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_, sampler_->Build(), children_flag_and_nums_, | |||||
| children_start_end_index_)); | |||||
| } | |||||
| node_ops.push_back(std::make_shared<ConcatOp>(connector_que_size_)); | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
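ValidateParams now insists that children_flag_and_nums_ and children_start_end_index_ are supplied together, and Build() picks the plain or sampler-aware ConcatOp accordingly. A standalone sketch of the both-or-neither check (FlagsAndIndexConsistent is an illustrative helper, not from this patch):

```cpp
#include <iostream>
#include <utility>
#include <vector>

// The two auxiliary vectors must be provided together: either both empty
// (plain concat) or both populated (distributed concat with a sampler).
bool FlagsAndIndexConsistent(const std::vector<std::pair<int, int>> &children_flag_and_nums,
                             const std::vector<std::pair<int, int>> &children_start_end_index) {
  return children_flag_and_nums.empty() == children_start_end_index.empty();
}

int main() {
  std::cout << FlagsAndIndexConsistent({}, {}) << '\n';                // 1: plain concat
  std::cout << FlagsAndIndexConsistent({{1, 10}}, {{0, 10}}) << '\n';  // 1: both provided
  std::cout << FlagsAndIndexConsistent({{1, 10}}, {}) << '\n';         // 0: rejected by ValidateParams
}
```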
| @@ -19,6 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | ||||
| @@ -29,7 +30,10 @@ namespace dataset { | |||||
| class ConcatNode : public DatasetNode { | class ConcatNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets); | |||||
| explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets, | |||||
| const std::shared_ptr<SamplerObj> &sampler = nullptr, | |||||
| const std::vector<std::pair<int, int>> &children_flag_and_nums = {}, | |||||
| const std::vector<std::pair<int, int>> &children_start_end_index = {}); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~ConcatNode() = default; | ~ConcatNode() = default; | ||||
| @@ -41,6 +45,11 @@ class ConcatNode : public DatasetNode { | |||||
| /// \brief Parameters validation | /// \brief Parameters validation | ||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| private: | |||||
| std::shared_ptr<SamplerObj> sampler_; | |||||
| std::vector<std::pair<int, int>> children_flag_and_nums_; | |||||
| std::vector<std::pair<int, int>> children_start_end_index_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -240,6 +240,7 @@ DatasetNode::DatasetNode() { | |||||
| rows_per_buffer_ = cfg->rows_per_buffer(); | rows_per_buffer_ = cfg->rows_per_buffer(); | ||||
| connector_que_size_ = cfg->op_connector_size(); | connector_que_size_ = cfg->op_connector_size(); | ||||
| worker_connector_size_ = cfg->worker_connector_size(); | worker_connector_size_ = cfg->worker_connector_size(); | ||||
| build_status = Status::OK(); // remove me after changing return val of Build() | |||||
| } | } | ||||
| // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. | // In DFS tree traversal, each node is visited twice. Accept is called on the first visit. | ||||
| @@ -254,5 +255,13 @@ Status DatasetNode::AcceptAfter(NodePass *p, bool *modified) { | |||||
| // This method will only be called if its derived class does not implement one. | // This method will only be called if its derived class does not implement one. | ||||
| return p->VisitAfter(shared_from_this(), modified); | return p->VisitAfter(shared_from_this(), modified); | ||||
| } | } | ||||
| Status DatasetNode::GetShardId(int32_t *shard_id) { | |||||
| if (!Children().empty()) { | |||||
| // Get shard id from the child node | |||||
| return Children()[0]->GetShardId(shard_id); | |||||
| } else { | |||||
| RETURN_STATUS_SYNTAX_ERROR("Get Shard Id failed at source node"); | |||||
| } | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
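The new default GetShardId walks down the first child until a leaf either overrides the lookup or it fails at a source node. A standalone sketch of that delegation (Node and ShardedLeaf are illustrative types, not from this patch):

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Default behaviour sketched by DatasetNode::GetShardId: non-leaf nodes ask
// their first child; a leaf either answers (override) or reports an error.
struct Node {
  std::vector<std::shared_ptr<Node>> children;
  virtual ~Node() = default;
  virtual bool GetShardId(int32_t *shard_id) {
    if (!children.empty()) return children[0]->GetShardId(shard_id);
    return false;  // source node without an override: "Get Shard Id failed"
  }
};

struct ShardedLeaf : Node {  // e.g. a TFRecord-style source that knows its shard
  explicit ShardedLeaf(int32_t id) : id_(id) {}
  bool GetShardId(int32_t *shard_id) override { *shard_id = id_; return true; }
  int32_t id_;
};

int main() {
  auto root = std::make_shared<Node>();  // e.g. a Map/Batch-style intermediate node
  root->children.push_back(std::make_shared<ShardedLeaf>(3));
  int32_t shard_id = -1;
  if (root->GetShardId(&shard_id)) std::cout << "shard id: " << shard_id << '\n';  // 3
}
```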
| @@ -99,9 +99,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| /// \brief Pure virtual function for derived class to get the shard id of specific node | /// \brief Pure virtual function for derived class to get the shard id of specific node | ||||
| /// \return Status Status::OK() if get shard id successfully | /// \return Status Status::OK() if get shard id successfully | ||||
| virtual Status GetShardId(int32_t *shard_id) { | |||||
| return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); | |||||
| } | |||||
| virtual Status GetShardId(int32_t *shard_id); | |||||
| /// \brief Setter function for runtime number of workers | /// \brief Setter function for runtime number of workers | ||||
| /// \param[in] num_workers The number of threads in this operator | /// \param[in] num_workers The number of threads in this operator | ||||
| @@ -126,6 +124,10 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| /// \return Status of the node visit | /// \return Status of the node visit | ||||
| virtual Status AcceptAfter(NodePass *p, bool *modified); | virtual Status AcceptAfter(NodePass *p, bool *modified); | ||||
| /// \brief Method to get status from Node.Build() | |||||
| /// \notes Remove me after changing return val of Build() | |||||
| Status BuildStatus() { return build_status; } | |||||
| protected: | protected: | ||||
| std::vector<std::shared_ptr<DatasetNode>> children; | std::vector<std::shared_ptr<DatasetNode>> children; | ||||
| std::shared_ptr<DatasetCache> cache_; | std::shared_ptr<DatasetCache> cache_; | ||||
| @@ -135,6 +137,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> { | |||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| int32_t connector_que_size_; | int32_t connector_que_size_; | ||||
| int32_t worker_connector_size_; | int32_t worker_connector_size_; | ||||
| Status build_status; // remove me after changing return val of Build() | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
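Because Build() returns a vector of ops rather than a Status, the interim build_status member records any failure (for example from AddCacheOp or LoadSchemaFile) so callers can check BuildStatus() afterwards. A standalone sketch of that record-and-check pattern (MiniNode and the string status are illustrative simplifications of the real Status type):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Interim pattern sketched by build_status: Build() cannot return a Status,
// so it records one and returns an empty vector on failure; the caller then
// consults BuildStatus() before using the result.
class MiniNode {
 public:
  std::vector<std::string> Build(bool fail) {
    if (fail) {
      build_status_ = "AddCacheOp failed";   // stand-in for a real Status
      return {};
    }
    build_status_.clear();
    return {"map_op", "project_op"};
  }
  const std::string &BuildStatus() const { return build_status_; }

 private:
  std::string build_status_;
};

int main() {
  MiniNode node;
  auto ops = node.Build(true);
  if (!node.BuildStatus().empty()) std::cout << "build failed: " << node.BuildStatus() << '\n';
}
```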
| @@ -28,7 +28,7 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor for FilterNode | // Constructor for FilterNode | ||||
| FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate, | |||||
| FilterNode::FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, | |||||
| std::vector<std::string> input_columns) | std::vector<std::string> input_columns) | ||||
| : predicate_(predicate), input_columns_(input_columns) { | : predicate_(predicate), input_columns_(input_columns) { | ||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| @@ -38,10 +38,7 @@ std::vector<std::shared_ptr<DatasetOp>> FilterNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| std::shared_ptr<TensorOp> c_func; | |||||
| c_func = std::make_shared<CFuncOp>(predicate_); | |||||
| node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, c_func)); | |||||
| node_ops.push_back(std::make_shared<FilterOp>(input_columns_, num_workers_, connector_que_size_, predicate_)); | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| @@ -29,7 +29,7 @@ namespace dataset { | |||||
| class FilterNode : public DatasetNode { | class FilterNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| FilterNode(std::shared_ptr<DatasetNode> child, std::function<TensorRow(TensorRow)> predicate, | |||||
| FilterNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<TensorOp> predicate, | |||||
| std::vector<std::string> input_columns = {}); | std::vector<std::string> input_columns = {}); | ||||
| /// \brief Destructor | /// \brief Destructor | ||||
| @@ -44,7 +44,7 @@ class FilterNode : public DatasetNode { | |||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| private: | private: | ||||
| std::function<TensorRow(TensorRow)> predicate_; | |||||
| std::shared_ptr<TensorOp> predicate_; | |||||
| std::vector<std::string> input_columns_; | std::vector<std::string> input_columns_; | ||||
| }; | }; | ||||
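A hedged caller-side sketch of the new FilterNode signature: the predicate is now wrapped into a TensorOp before the IR node is constructed (CFuncOp below is the op the old Build() used to create internally; child_node and the column list are illustrative).

    std::function<TensorRow(TensorRow)> user_predicate = /* user callable */;
    std::shared_ptr<TensorOp> predicate_op = std::make_shared<CFuncOp>(user_predicate);
    auto filter_ir = std::make_shared<FilterNode>(child_node, predicate_op,
                                                  std::vector<std::string>{"label"});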
| @@ -64,7 +64,8 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | |||||
| auto project_op = std::make_shared<ProjectOp>(project_columns_); | auto project_op = std::make_shared<ProjectOp>(project_columns_); | ||||
| node_ops.push_back(project_op); | node_ops.push_back(project_op); | ||||
| } | } | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(map_op); | node_ops.push_back(map_op); | ||||
| return node_ops; | return node_ops; | ||||
| @@ -59,7 +59,8 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() { | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| auto schema = std::make_unique<DataSchema>(); | auto schema = std::make_unique<DataSchema>(); | ||||
| RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_)); | |||||
| build_status = schema->LoadSchemaFile(schema_path_, column_names_); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| // Argument that is not exposed to user in the API. | // Argument that is not exposed to user in the API. | ||||
| std::set<std::string> extensions = {}; | std::set<std::string> extensions = {}; | ||||
| @@ -60,7 +60,8 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() { | |||||
| RETURN_EMPTY_IF_ERROR( | RETURN_EMPTY_IF_ERROR( | ||||
| schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | ||||
| decode_, usage_, extensions_, std::move(schema), | decode_, usage_, extensions_, std::move(schema), | ||||
| @@ -56,7 +56,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() { | |||||
| RETURN_EMPTY_IF_ERROR( | RETURN_EMPTY_IF_ERROR( | ||||
| schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, | node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, | ||||
| dataset_dir_, connector_que_size_, std::move(schema), | dataset_dir_, connector_que_size_, std::move(schema), | ||||
| @@ -54,7 +54,8 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() { | |||||
| RETURN_EMPTY_IF_ERROR( | RETURN_EMPTY_IF_ERROR( | ||||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, | node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, | ||||
| dataset_dir_, connector_que_size_, std::move(schema), | dataset_dir_, connector_que_size_, std::move(schema), | ||||
| @@ -197,18 +197,23 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() { | |||||
| std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( | std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>( | ||||
| num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files, | num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map, sorted_dataset_files, | ||||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); | connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); | ||||
| RETURN_EMPTY_IF_ERROR(clue_op->Init()); | |||||
| build_status = clue_op->Init(); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | ||||
| // Inject ShuffleOp | // Inject ShuffleOp | ||||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | std::shared_ptr<DatasetOp> shuffle_op = nullptr; | ||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| // First, get the number of rows in the dataset | // First, get the number of rows in the dataset | ||||
| RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows)); | |||||
| build_status = ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| // Add the shuffle op after this op | // Add the shuffle op after this op | ||||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op)); | |||||
| build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| node_ops.push_back(shuffle_op); | node_ops.push_back(shuffle_op); | ||||
| } | } | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | ||||
| @@ -111,7 +111,8 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() { | |||||
| std::shared_ptr<CocoOp> op = | std::shared_ptr<CocoOp> op = | ||||
| std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, | std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_, | ||||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(op); | node_ops.push_back(op); | ||||
| @@ -108,18 +108,23 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() { | |||||
| std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, | std::make_shared<CsvOp>(sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, | ||||
| rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, | rows_per_buffer_, num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, | ||||
| num_shards_, shard_id_, std::move(sampler_->Build())); | num_shards_, shard_id_, std::move(sampler_->Build())); | ||||
| RETURN_EMPTY_IF_ERROR(csv_op->Init()); | |||||
| build_status = csv_op->Init(); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | ||||
| // Inject ShuffleOp | // Inject ShuffleOp | ||||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | std::shared_ptr<DatasetOp> shuffle_op = nullptr; | ||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| // First, get the number of rows in the dataset | // First, get the number of rows in the dataset | ||||
| RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows)); | |||||
| build_status = CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| // Add the shuffle op after this op | // Add the shuffle op after this op | ||||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op)); | |||||
| build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| node_ops.push_back(shuffle_op); | node_ops.push_back(shuffle_op); | ||||
| } | } | ||||
| @@ -30,7 +30,25 @@ GeneratorNode::GeneratorNode(py::function generator_function, const std::vector< | |||||
| const std::vector<DataType> &column_types) | const std::vector<DataType> &column_types) | ||||
| : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} | : generator_function_(generator_function), column_names_(column_names), column_types_(column_types) {} | ||||
| GeneratorNode::GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema) | |||||
| : generator_function_(generator_function), schema_(schema) {} | |||||
| std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { | ||||
| std::unique_ptr<DataSchema> data_schema = std::make_unique<DataSchema>(); | |||||
| if (schema_ != nullptr) { | |||||
| column_names_.clear(); | |||||
| column_types_.clear(); | |||||
| std::string schema_json_string = schema_->to_json(); | |||||
| RETURN_EMPTY_IF_ERROR(data_schema->LoadSchemaString(schema_json_string, {})); | |||||
| for (int32_t i = 0; i < data_schema->NumColumns(); i++) { | |||||
| ColDescriptor col = data_schema->column(i); | |||||
| column_names_.push_back(col.name()); | |||||
| column_types_.push_back(col.type()); | |||||
| } | |||||
| } | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by | // GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by | ||||
| @@ -43,6 +61,8 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { | |||||
| // This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which is | // This method can be privatized once we move Init() to Generator's functor. However, that is a bigger change which is | ||||
| // best delivered when the test cases for this API are ready. | // best delivered when the test cases for this API are ready. | ||||
| Status rc = op->Init(); | Status rc = op->Init(); | ||||
| build_status = rc; // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| if (rc.IsOk()) { | if (rc.IsOk()) { | ||||
| node_ops.push_back(op); | node_ops.push_back(op); | ||||
| @@ -56,5 +76,11 @@ std::vector<std::shared_ptr<DatasetOp>> GeneratorNode::Build() { | |||||
| // no validation is needed for generator op. | // no validation is needed for generator op. | ||||
| Status GeneratorNode::ValidateParams() { return Status::OK(); } | Status GeneratorNode::ValidateParams() { return Status::OK(); } | ||||
| Status GeneratorNode::GetShardId(int32_t *shard_id) { | |||||
| RETURN_UNEXPECTED_IF_NULL(shard_id); | |||||
| *shard_id = 0; | |||||
| return Status::OK(); | |||||
| } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,6 +35,9 @@ class GeneratorNode : public DatasetNode { | |||||
| GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | GeneratorNode(py::function generator_function, const std::vector<std::string> &column_names, | ||||
| const std::vector<DataType> &column_types); | const std::vector<DataType> &column_types); | ||||
| /// \brief Constructor | |||||
| GeneratorNode(py::function generator_function, const std::shared_ptr<SchemaObj> &schema); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~GeneratorNode() = default; | ~GeneratorNode() = default; | ||||
| @@ -46,10 +49,15 @@ class GeneratorNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| /// \brief Get the shard id of this node; it is always 0 because GeneratorNode does not support sharding | |||||
| /// \return Status Status::OK() if the shard id is obtained successfully | |||||
| Status GetShardId(int32_t *shard_id) override; | |||||
| private: | private: | ||||
| py::function generator_function_; | py::function generator_function_; | ||||
| std::vector<std::string> column_names_; | std::vector<std::string> column_names_; | ||||
| std::vector<DataType> column_types_; | std::vector<DataType> column_types_; | ||||
| std::shared_ptr<SchemaObj> schema_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
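A hedged usage sketch of the new schema-based constructor (column names, types, and shapes are illustrative; add_column follows the Status-returning SchemaObj interface declared later in this change).

    std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>();
    RETURN_IF_NOT_OK(schema->add_column("image", "uint8", {28, 28, 1}));
    RETURN_IF_NOT_OK(schema->add_column("label", "uint32", {1}));
    auto generator_ir = std::make_shared<GeneratorNode>(generator_function /* py::function */, schema);
    // Build() derives column_names_/column_types_ from schema->to_json() via DataSchema::LoadSchemaString().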
| @@ -62,7 +62,8 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() { | |||||
| RETURN_EMPTY_IF_ERROR( | RETURN_EMPTY_IF_ERROR( | ||||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); | schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, | ||||
| recursive_, decode_, exts_, class_indexing_, std::move(schema), | recursive_, decode_, exts_, class_indexing_, std::move(schema), | ||||
| @@ -79,7 +79,8 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() { | |||||
| manifest_op = | manifest_op = | ||||
| std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, | std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_, | ||||
| class_index_, std::move(schema), std::move(sampler_->Build()), usage_); | class_index_, std::move(schema), std::move(sampler_->Build()), usage_); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(manifest_op); | node_ops.push_back(manifest_op); | ||||
| @@ -138,7 +138,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() { | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| std::vector<std::shared_ptr<ShardOperator>> operators_; | std::vector<std::shared_ptr<ShardOperator>> operators_; | ||||
| RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_)); | |||||
| build_status = BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| std::shared_ptr<MindRecordOp> mindrecord_op; | std::shared_ptr<MindRecordOp> mindrecord_op; | ||||
| // If pass a string to MindData(), it will be treated as a pattern to search for matched files, | // If pass a string to MindData(), it will be treated as a pattern to search for matched files, | ||||
| @@ -154,7 +155,8 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() { | |||||
| padded_sample_, sample_bytes_); | padded_sample_, sample_bytes_); | ||||
| } | } | ||||
| RETURN_EMPTY_IF_ERROR(mindrecord_op->Init()); | |||||
| build_status = mindrecord_op->Init(); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(mindrecord_op); | node_ops.push_back(mindrecord_op); | ||||
| return node_ops; | return node_ops; | ||||
| @@ -51,7 +51,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() { | |||||
| TensorShape scalar = TensorShape::CreateScalar(); | TensorShape scalar = TensorShape::CreateScalar(); | ||||
| RETURN_EMPTY_IF_ERROR( | RETURN_EMPTY_IF_ERROR( | ||||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, | node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, | ||||
| connector_que_size_, std::move(schema), std::move(sampler_->Build()))); | connector_que_size_, std::move(schema), std::move(sampler_->Build()))); | ||||
| @@ -98,7 +98,8 @@ std::vector<std::shared_ptr<DatasetOp>> RandomNode::Build() { | |||||
| std::shared_ptr<RandomDataOp> op; | std::shared_ptr<RandomDataOp> op; | ||||
| op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, | op = std::make_shared<RandomDataOp>(num_workers_, connector_que_size_, rows_per_buffer_, total_rows_, | ||||
| std::move(data_schema), std::move(sampler_->Build())); | std::move(data_schema), std::move(sampler_->Build())); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(op); | node_ops.push_back(op); | ||||
| @@ -78,7 +78,8 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() { | |||||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | ||||
| num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, | num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files, | ||||
| connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); | connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(sampler_->Build())); | ||||
| RETURN_EMPTY_IF_ERROR(text_file_op->Init()); | |||||
| build_status = text_file_op->Init(); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | ||||
| // Inject ShuffleOp | // Inject ShuffleOp | ||||
| @@ -86,14 +87,17 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileNode::Build() { | |||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| // First, get the number of rows in the dataset | // First, get the number of rows in the dataset | ||||
| RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows)); | |||||
| build_status = TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| // Add the shuffle op after this op | // Add the shuffle op after this op | ||||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op)); | |||||
| build_status = AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| node_ops.push_back(shuffle_op); | node_ops.push_back(shuffle_op); | ||||
| } | } | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| // Add TextFileOp | // Add TextFileOp | ||||
| node_ops.push_back(text_file_op); | node_ops.push_back(text_file_op); | ||||
| @@ -118,7 +118,8 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { | |||||
| std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, | std::move(data_schema), connector_que_size_, columns_list_, shuffle_files, num_shards_, | ||||
| shard_id_, shard_equal_rows_, std::move(sampler_->Build())); | shard_id_, shard_equal_rows_, std::move(sampler_->Build())); | ||||
| RETURN_EMPTY_IF_ERROR(tf_reader_op->Init()); | |||||
| build_status = tf_reader_op->Init(); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | if (cache_ == nullptr && shuffle_ == ShuffleMode::kGlobal) { | ||||
| // Inject ShuffleOp | // Inject ShuffleOp | ||||
| @@ -127,14 +128,17 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordNode::Build() { | |||||
| int64_t num_rows = 0; | int64_t num_rows = 0; | ||||
| // First, get the number of rows in the dataset | // First, get the number of rows in the dataset | ||||
| RETURN_EMPTY_IF_ERROR(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files)); | |||||
| build_status = TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| // Add the shuffle op after this op | // Add the shuffle op after this op | ||||
| RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op)); | |||||
| build_status = AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, | |||||
| rows_per_buffer_, &shuffle_op); | |||||
| RETURN_EMPTY_IF_ERROR(build_status); // remove me after changing return val of Build() | |||||
| node_ops.push_back(shuffle_op); | node_ops.push_back(shuffle_op); | ||||
| } | } | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| // Add TFReaderOp | // Add TFReaderOp | ||||
| node_ops.push_back(tf_reader_op); | node_ops.push_back(tf_reader_op); | ||||
| @@ -106,7 +106,8 @@ std::vector<std::shared_ptr<DatasetOp>> VOCNode::Build() { | |||||
| std::shared_ptr<VOCOp> voc_op; | std::shared_ptr<VOCOp> voc_op; | ||||
| voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_, | ||||
| connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build())); | ||||
| RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops)); | |||||
| build_status = AddCacheOp(&node_ops); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| node_ops.push_back(voc_op); | node_ops.push_back(voc_op); | ||||
| return node_ops; | return node_ops; | ||||
| @@ -27,9 +27,8 @@ namespace mindspore { | |||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor for SyncWaitNode | // Constructor for SyncWaitNode | ||||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, | |||||
| py::function callback) | |||||
| : condition_name_(condition_name), num_batch_(num_batch), callback_(callback) { | |||||
| SyncWaitNode::SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback) | |||||
| : condition_name_(condition_name), callback_(callback) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -38,20 +37,16 @@ std::vector<std::shared_ptr<DatasetOp>> SyncWaitNode::Build() { | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | // A vector containing shared pointer to the Dataset Ops that this object will create | ||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | std::vector<std::shared_ptr<DatasetOp>> node_ops; | ||||
| node_ops.push_back(std::make_shared<BarrierOp>(num_batch_, connector_que_size_, condition_name_, callback_)); | |||||
| // Right now barrier should only take num_rows_per_buffer = 1 | |||||
| // This is because other values can lead to blocking issues | |||||
| // See barrier_op.h for more details | |||||
| int32_t rows_per_buffer = 1; | |||||
| node_ops.push_back(std::make_shared<BarrierOp>(rows_per_buffer, connector_que_size_, condition_name_, callback_)); | |||||
| return node_ops; | return node_ops; | ||||
| } | } | ||||
| // Function to validate the parameters for SyncWaitNode | // Function to validate the parameters for SyncWaitNode | ||||
| Status SyncWaitNode::ValidateParams() { | |||||
| if (num_batch_ <= 0) { | |||||
| std::string err_msg = "SyncWaitNode: num_batch must be greater than 0, num_batch: " + std::to_string(num_batch_); | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status SyncWaitNode::ValidateParams() { return Status::OK(); } | |||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,8 +31,7 @@ namespace dataset { | |||||
| class SyncWaitNode : public DatasetNode { | class SyncWaitNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, int32_t num_batch, | |||||
| py::function callback); | |||||
| SyncWaitNode(std::shared_ptr<DatasetNode> child, const std::string &condition_name, py::function callback); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~SyncWaitNode() = default; | ~SyncWaitNode() = default; | ||||
| @@ -47,7 +46,6 @@ class SyncWaitNode : public DatasetNode { | |||||
| private: | private: | ||||
| std::string condition_name_; | std::string condition_name_; | ||||
| int32_t num_batch_; | |||||
| py::function callback_; | py::function callback_; | ||||
| }; | }; | ||||
| @@ -18,72 +18,80 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/engine/datasetops/device_queue_op.h" | #include "minddata/dataset/engine/datasetops/device_queue_op.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| // Constructor for TransferNode | // Constructor for TransferNode | ||||
| TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end) | |||||
| : prefetch_size_(16), send_epoch_end_(send_epoch_end), total_batch_(0) { | |||||
| TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, | |||||
| bool send_epoch_end, int32_t total_batch, bool create_data_info_queue) | |||||
| : prefetch_size_(16), | |||||
| queue_name_(std::move(queue_name)), | |||||
| device_type_(std::move(device_type)), | |||||
| send_epoch_end_(send_epoch_end), | |||||
| total_batch_(total_batch), | |||||
| create_data_info_queue_(create_data_info_queue), | |||||
| device_id_(0) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| // Validator for TransferNode | // Validator for TransferNode | ||||
| Status TransferNode::ValidateParams() { | Status TransferNode::ValidateParams() { | ||||
| // Check if device_type_ is in {"CPU", "GPU", "Ascend"} | |||||
| RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"})); | |||||
| if (total_batch_ < 0) { | |||||
| std::string err_msg = "TransferNode: Total batches should be >= 0, value given: "; | |||||
| MS_LOG(ERROR) << err_msg << total_batch_; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| // Function to build TransferNode | // Function to build TransferNode | ||||
| std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() { | ||||
| // Get a uuid for queue name | |||||
| queue_name_ = Services::GetUniqueID(); | |||||
| // TODO(CRC): | |||||
| // Get device type from ms context | |||||
| device_type_ = "CPU"; | |||||
| // Get device ID from children | |||||
| device_id_ = 0; | |||||
| RETURN_EMPTY_IF_ERROR(TransferNode::get_distribution(shared_from_this(), &device_id_)); | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| if (queue_name_.empty()) { | |||||
| // Get a uuid for queue name | |||||
| queue_name_ = Services::GetUniqueID(); | |||||
| } | |||||
| if (device_type_.empty()) { | |||||
| auto context = MsContext::GetInstance(); | |||||
| if (context == nullptr) { | |||||
| device_type_ = kCPUDevice; | |||||
| } else { | |||||
| device_type_ = context->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| } | |||||
| } | |||||
| // Get device type from ms context | |||||
| // Convert device_type_ from string to DeviceType | // Convert device_type_ from string to DeviceType | ||||
| DeviceQueueOp::DeviceType type; | DeviceQueueOp::DeviceType type; | ||||
| if (device_type_ == "CPU") { | |||||
| if (device_type_ == kCPUDevice) { | |||||
| type = DeviceQueueOp::DeviceType::CPU; | type = DeviceQueueOp::DeviceType::CPU; | ||||
| } else if (device_type_ == "GPU") { | |||||
| } else if (device_type_ == kGPUDevice) { | |||||
| type = DeviceQueueOp::DeviceType::GPU; | type = DeviceQueueOp::DeviceType::GPU; | ||||
| } else if (device_type_ == "Ascend") { | |||||
| } else if (device_type_ == kAscendDevice) { | |||||
| type = DeviceQueueOp::DeviceType::Ascend; | type = DeviceQueueOp::DeviceType::Ascend; | ||||
| } else { | |||||
| MS_LOG(ERROR) << "Unknown device target."; | |||||
| return {}; | |||||
| } | } | ||||
| node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, | |||||
| total_batch_, false)); | |||||
| return node_ops; | |||||
| } | |||||
| // Function to get the device_id | |||||
| Status TransferNode::get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id) { | |||||
| // Get device id according to the type of dataset | |||||
| Status rc = ds->GetShardId(device_id); | |||||
| if (rc != Status::OK()) { | |||||
| // Get device id from the child node | |||||
| if (ds->Children().size()) { | |||||
| ds = ds->Children()[0]; | |||||
| return TransferNode::get_distribution(ds, device_id); | |||||
| } else { | |||||
| std::string err_msg = "Unknown dataset type."; | |||||
| MS_LOG(ERROR) << err_msg; | |||||
| RETURN_STATUS_SYNTAX_ERROR(err_msg); | |||||
| } | |||||
| } | |||||
| // Get device ID (shard ID) from children | |||||
| device_id_ = 0; | |||||
| build_status = this->GetShardId(&device_id_); // remove me after changing return val of Build() | |||||
| RETURN_EMPTY_IF_ERROR(build_status); | |||||
| return Status::OK(); | |||||
| // A vector containing shared pointer to the Dataset Ops that this object will create | |||||
| std::vector<std::shared_ptr<DatasetOp>> node_ops; | |||||
| node_ops.push_back(std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, | |||||
| total_batch_, create_data_info_queue_)); | |||||
| return node_ops; | |||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -29,7 +29,8 @@ namespace dataset { | |||||
| class TransferNode : public DatasetNode { | class TransferNode : public DatasetNode { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| TransferNode(std::shared_ptr<DatasetNode> child, bool send_epoch_end); | |||||
| TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type, bool send_epoch_end, | |||||
| int32_t total_batch, bool create_data_info_queue); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~TransferNode() = default; | ~TransferNode() = default; | ||||
| @@ -42,8 +43,6 @@ class TransferNode : public DatasetNode { | |||||
| /// \return Status Status::OK() if all the parameters are valid | /// \return Status Status::OK() if all the parameters are valid | ||||
| Status ValidateParams() override; | Status ValidateParams() override; | ||||
| static Status get_distribution(std::shared_ptr<DatasetNode> ds, int32_t *device_id); | |||||
| private: | private: | ||||
| std::string queue_name_; | std::string queue_name_; | ||||
| int32_t device_id_; | int32_t device_id_; | ||||
| @@ -51,6 +50,7 @@ class TransferNode : public DatasetNode { | |||||
| int32_t prefetch_size_; | int32_t prefetch_size_; | ||||
| bool send_epoch_end_; | bool send_epoch_end_; | ||||
| int32_t total_batch_; | int32_t total_batch_; | ||||
| bool create_data_info_queue_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -40,21 +40,7 @@ Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<TakeOp> node, bool *mo | |||||
| } | } | ||||
| Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { | ||||
| if (type_ == kOutputShapeAndType) { | |||||
| nodes_to_clear_callback_.push_back(node); | |||||
| } else if (type_ == kDatasetSize) { | |||||
| nodes_to_remove_.push_back(node); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) { | |||||
| if (type_ == kDatasetSize) nodes_to_remove_.push_back(node); | |||||
| return Status::OK(); | |||||
| } | |||||
| Status GetterPass::GetterNodes::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) { | |||||
| if (type_ == kDatasetSize) nodes_to_remove_.push_back(node); | |||||
| nodes_to_clear_callback_.push_back(node); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -83,5 +69,6 @@ Status GetterPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,6 +34,10 @@ class GetterPass : public TreePass { | |||||
| enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 }; | enum GetterType { kDatasetSize = 1, kOutputShapeAndType = 2 }; | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| explicit GetterPass(GetterType tp) : pass_(tp) {} | explicit GetterPass(GetterType tp) : pass_(tp) {} | ||||
| /// \brief Default copy constructor | |||||
| explicit GetterPass(const GetterPass &) = default; | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~GetterPass() = default; | ~GetterPass() = default; | ||||
| @@ -51,11 +55,10 @@ class GetterPass : public TreePass { | |||||
| Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; | ||||
| Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override; | ||||
| Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override { return Status::OK(); } | |||||
| Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override; | ||||
| Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override; | ||||
| Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override; | ||||
| Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override; | |||||
| Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override; | |||||
| // Whether this is Run or PreRun does not matter here; however, only Accept() is defined in ConcatOp | // Whether this is Run or PreRun does not matter here; however, only Accept() is defined in ConcatOp | ||||
| Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | Status PreRunOnNode(std::shared_ptr<ConcatOp> node, bool *modified) override; | ||||
| @@ -67,7 +70,7 @@ class GetterPass : public TreePass { | |||||
| std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_; | std::list<std::shared_ptr<DatasetOp>> nodes_to_clear_callback_; | ||||
| std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_; | std::list<std::shared_ptr<DatasetOp>> nodes_to_remove_; | ||||
| }; | }; | ||||
| // outter class needs only to own the inner class object since it automatically has access to its private variables | |||||
| // outer class needs only to own the inner class object since it automatically has access to its private variables | |||||
| GetterNodes pass_; | GetterNodes pass_; | ||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -19,7 +19,14 @@ | |||||
| namespace mindspore::dataset { | namespace mindspore::dataset { | ||||
| Status PythonRuntimeContext::Terminate() { return TerminateImpl(); } | |||||
| Status PythonRuntimeContext::Terminate() { | |||||
| MS_LOG(INFO) << "Terminating a PythonRuntime"; | |||||
| if (tree_consumer_ != nullptr) { | |||||
| return TerminateImpl(); | |||||
| } | |||||
| MS_LOG(WARNING) << "TreeConsumer was not initialized"; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status PythonRuntimeContext::TerminateImpl() { | Status PythonRuntimeContext::TerminateImpl() { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | ||||
| @@ -22,7 +22,14 @@ namespace mindspore::dataset { | |||||
| void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { | void RuntimeContext::AssignConsumer(std::shared_ptr<TreeConsumer> tree_consumer) { | ||||
| tree_consumer_ = std::move(tree_consumer); | tree_consumer_ = std::move(tree_consumer); | ||||
| } | } | ||||
| Status NativeRuntimeContext::Terminate() { return TerminateImpl(); } | |||||
| Status NativeRuntimeContext::Terminate() { | |||||
| MS_LOG(INFO) << "Terminating a NativeRuntime"; | |||||
| if (tree_consumer_ != nullptr) { | |||||
| return TerminateImpl(); | |||||
| } | |||||
| MS_LOG(WARNING) << "TreeConsumer was not initialized"; | |||||
| return Status::OK(); | |||||
| } | |||||
| Status NativeRuntimeContext::TerminateImpl() { | Status NativeRuntimeContext::TerminateImpl() { | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | CHECK_FAIL_RETURN_UNEXPECTED(tree_consumer_ != nullptr, " Tree Consumer is not initialized"); | ||||
| @@ -97,6 +97,8 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) { | |||||
| Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) { | Status TreeAdapter::BuildExecutionTree(std::shared_ptr<DatasetNode> ir, std::shared_ptr<DatasetOp> *op) { | ||||
| // Build the DatasetOp ExecutionTree from the optimized IR tree | // Build the DatasetOp ExecutionTree from the optimized IR tree | ||||
| std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build(); | std::vector<std::shared_ptr<DatasetOp>> ops = ir->Build(); | ||||
| RETURN_IF_NOT_OK(ir->BuildStatus()); // remove me after changing return val of Build() | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node."); | CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build node."); | ||||
| (*op) = ops.front(); // return the first op to be added as child by the caller of this function | (*op) = ops.front(); // return the first op to be added as child by the caller of this function | ||||
| @@ -141,6 +143,8 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep | |||||
| RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op)); | RETURN_IF_NOT_OK(BuildExecutionTree(root_ir, &root_op)); | ||||
| RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); | RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); | ||||
| if (pre_pass_override_) tree_->SetPrePassOverride(pre_pass_override_); | |||||
| // Note: We will gradually move the pre pass, optimizer pass, and post pass | // Note: We will gradually move the pre pass, optimizer pass, and post pass | ||||
| // on ExecutionTree to perform on IR tree. | // on ExecutionTree to perform on IR tree. | ||||
| // Prepare the tree | // Prepare the tree | ||||
| @@ -149,6 +153,11 @@ Status TreeAdapter::Compile(std::shared_ptr<DatasetNode> root_ir, int32_t num_ep | |||||
| // After the tree is prepared, the col_name_id_map can safely be obtained | // After the tree is prepared, the col_name_id_map can safely be obtained | ||||
| column_name_map_ = tree_->root()->column_name_id_map(); | column_name_map_ = tree_->root()->column_name_id_map(); | ||||
| // Profiling parameters init | |||||
| cur_batch_num_ = 0; | |||||
| cur_connector_size_ = 0; | |||||
| cur_connector_capacity_ = 0; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -156,21 +165,55 @@ Status TreeAdapter::GetNext(TensorRow *row) { | |||||
| RETURN_UNEXPECTED_IF_NULL(tree_); | RETURN_UNEXPECTED_IF_NULL(tree_); | ||||
| RETURN_UNEXPECTED_IF_NULL(row); | RETURN_UNEXPECTED_IF_NULL(row); | ||||
| row->clear(); // make sure row is empty | row->clear(); // make sure row is empty | ||||
| bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable(); | |||||
| // When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree | // When cur_db_ is a nullptr, it means this is the first call to get_next, launch ExecutionTree | ||||
| if (cur_db_ == nullptr) { | if (cur_db_ == nullptr) { | ||||
| RETURN_IF_NOT_OK(tree_->Launch()); | RETURN_IF_NOT_OK(tree_->Launch()); | ||||
| // Profiling | |||||
| std::shared_ptr<Tracing> node; | |||||
| Status s = tree_->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node); | |||||
| if (s.IsOk()) { | |||||
| tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node); | |||||
| } | |||||
| if (tracing_ != nullptr) { | |||||
| cur_connector_size_ = tree_->root()->ConnectorSize(); | |||||
| cur_connector_capacity_ = tree_->root()->ConnectorCapacity(); | |||||
| } | |||||
| RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag | RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); // first buf can't be eof or empty buf with none flag | ||||
| RETURN_OK_IF_TRUE(cur_db_->eoe()); // return empty tensor if 1st buf is a ctrl buf (no rows) | |||||
| if (cur_db_->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows) | |||||
| MS_LOG(INFO) << "End of data iteration."; | |||||
| if (isProfilingEnable) { | |||||
| tree_->SetEpochEnd(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| } | } | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached."); | CHECK_FAIL_RETURN_UNEXPECTED(!cur_db_->eof(), "EOF has already been reached."); | ||||
| if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf | if (cur_db_->NumRows() == 0) { // a new row is fetched if cur buf is empty or a ctrl buf | ||||
| RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); | RETURN_IF_NOT_OK(tree_->root()->GetNextBuffer(&cur_db_)); | ||||
| RETURN_OK_IF_TRUE(cur_db_->eoe() || cur_db_->eof()); // return empty if this new buffer is a ctrl flag | |||||
| if (cur_db_->eoe()) { // return empty if this new buffer is a ctrl flag | |||||
| MS_LOG(INFO) << "End of data iteration."; | |||||
| if (isProfilingEnable) { | |||||
| tree_->SetEpochEnd(); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| if (cur_db_->eof()) { | |||||
| tree_->SetFinished(); | |||||
| std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."; | |||||
| RETURN_STATUS_UNEXPECTED(err); | |||||
| } | |||||
| } | } | ||||
| RETURN_IF_NOT_OK(cur_db_->PopRow(row)); | RETURN_IF_NOT_OK(cur_db_->PopRow(row)); | ||||
| // Record profiling info | |||||
| if (tracing_ != nullptr) { | |||||
| cur_batch_num_++; | |||||
| tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_); | |||||
| } | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | ||||
| #include "minddata/dataset/engine/perf/dataset_iterator_tracing.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| @@ -60,6 +61,9 @@ class TreeAdapter { | |||||
| // Set optional optimization pass | // Set optional optimization pass | ||||
| void SetOptimize(bool value) { optimize_ = value; } | void SetOptimize(bool value) { optimize_ = value; } | ||||
| // Function to override the pre-pass | |||||
| void SetPrePassOverride(std::function<OptPass(OptPass)> pre_pass_override) { pre_pass_override_ = pre_pass_override; } | |||||
| // Optional optimizations status | // Optional optimizations status | ||||
| bool OptimizationEnabled() const { return optimize_; } | bool OptimizationEnabled() const { return optimize_; } | ||||
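A hedged sketch of the new override hook; treating OptPass as a pass container that can be extended is an assumption here, and CustomPrePass is a hypothetical pass.

    tree_adapter->SetPrePassOverride([](OptPass pre_pass) {
      pre_pass.push_back(std::make_unique<CustomPrePass>());  // hypothetical extra pre-pass
      return pre_pass;
    });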
| @@ -82,9 +86,14 @@ class TreeAdapter { | |||||
| std::unique_ptr<DataBuffer> cur_db_; | std::unique_ptr<DataBuffer> cur_db_; | ||||
| std::unordered_map<std::string, int32_t> column_name_map_; | std::unordered_map<std::string, int32_t> column_name_map_; | ||||
| std::unique_ptr<ExecutionTree> tree_; | std::unique_ptr<ExecutionTree> tree_; | ||||
| int32_t num_epochs_; | int32_t num_epochs_; | ||||
| bool optimize_; // Flag to enable optional optimization pass | |||||
| bool optimize_; // Flag to enable optional optimization pass | |||||
| std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data | |||||
| int32_t cur_batch_num_; // current batch number, used for profiling | |||||
| int32_t cur_connector_size_; // current connector size of root op, used for profiling | |||||
| int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling | |||||
| std::function<OptPass(OptPass)> pre_pass_override_; // function ptr that overrides pre pass, called in PrePrepare() | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -145,9 +145,16 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| /// \brief Function to transfer data through a device. | /// \brief Function to transfer data through a device. | ||||
| /// \notes If device is Ascend, features of data will be transferred one by one. The limitation | /// \notes If device is Ascend, features of data will be transferred one by one. The limitation | ||||
| /// of data transmission per transfer is 256M. | /// of data transmission per transfer is 256M. | ||||
| /// \param[in] queue_name Channel name (default="", create new unique name). | |||||
| /// \param[in] device_type Type of device (default="", get from MSContext). | |||||
| /// \param[in] num_epochs Number of epochs (default=-1, infinite epochs). | |||||
| /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true). | /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true). | ||||
| /// \param[in] total_batches Number of batches to be sent to the device (default=0, all data). | |||||
| /// \param[in] create_data_info_queue Whether to create queue which stores types and shapes | |||||
| /// of data or not (default=false). | |||||
| /// \return Returns true if no error encountered else false. | /// \return Returns true if no error encountered else false. | ||||
| bool DeviceQueue(bool send_epoch_end = true); | |||||
| bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t num_epochs = -1, | |||||
| bool send_epoch_end = true, int32_t total_batches = 0, bool create_data_info_queue = false); | |||||
| /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline | /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline | ||||
| /// \note Usage restrictions: | /// \note Usage restrictions: | ||||
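A hedged example of a call with the extended parameter list (the pipeline itself is elided and all values are illustrative; empty queue_name/device_type fall back to a generated name and the MsContext device target, as documented above).

    std::shared_ptr<Dataset> ds = /* ... build a dataset pipeline ... */;
    bool ok = ds->DeviceQueue(/*queue_name=*/"", /*device_type=*/"Ascend", /*num_epochs=*/1,
                              /*send_epoch_end=*/true, /*total_batches=*/0,
                              /*create_data_info_queue=*/false);
    if (!ok) {
      MS_LOG(ERROR) << "DeviceQueue launch failed.";
    }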
| @@ -371,21 +378,34 @@ class SchemaObj { | |||||
| /// \brief SchemaObj init function | /// \brief SchemaObj init function | ||||
| /// \return bool true if schema init success | /// \return bool true if schema init success | ||||
| bool init(); | |||||
| Status init(); | |||||
| /// \brief Add new column to the schema with unknown shape of rank 1 | |||||
| /// \param[in] name name of the column. | |||||
| /// \param[in] de_type data type of the column(TypeId). | |||||
| /// \return Status Status::OK() if the column is successfully added | |||||
| Status add_column(std::string name, TypeId de_type); | |||||
| /// \brief Add new column to the schema with unknown shape of rank 1 | |||||
| /// \param[in] name name of the column. | |||||
| /// \param[in] de_type data type of the column(std::string). | |||||
| /// \return Status Status::OK() if the column is successfully added | |||||
| Status add_column(std::string name, std::string de_type); | |||||
| /// \brief Add new column to the schema | /// \brief Add new column to the schema | ||||
| /// \param[in] name name of the column. | /// \param[in] name name of the column. | ||||
| /// \param[in] de_type data type of the column(TypeId). | /// \param[in] de_type data type of the column(TypeId). | ||||
| /// \param[in] shape shape of the column. | /// \param[in] shape shape of the column. | ||||
| /// \return bool true if schema init success | /// \return bool true if schema init success | ||||
| bool add_column(std::string name, TypeId de_type, std::vector<int32_t> shape); | |||||
| Status add_column(std::string name, TypeId de_type, std::vector<int32_t> shape); | |||||
| /// \brief Add new column to the schema | /// \brief Add new column to the schema | ||||
| /// \param[in] name name of the column. | /// \param[in] name name of the column. | ||||
| /// \param[in] de_type data type of the column(std::string). | /// \param[in] de_type data type of the column(std::string). | ||||
| /// \param[in] shape shape of the column. | /// \param[in] shape shape of the column. | ||||
| /// \return bool true if schema init success | /// \return bool true if schema init success | ||||
| bool add_column(std::string name, std::string de_type, std::vector<int32_t> shape); | |||||
| Status add_column(std::string name, std::string de_type, std::vector<int32_t> shape); | |||||
| /// \brief Get a JSON string of the schema | /// \brief Get a JSON string of the schema | ||||
| /// \return JSON string of the schema | /// \return JSON string of the schema | ||||
| @@ -395,25 +415,27 @@ class SchemaObj { | |||||
| std::string to_string() { return to_json(); } | std::string to_string() { return to_json(); } | ||||
| /// \brief set a new value to dataset_type | /// \brief set a new value to dataset_type | ||||
| inline void set_dataset_type(std::string dataset_type) { dataset_type_ = dataset_type; } | |||||
| inline void set_dataset_type(std::string dataset_type) { dataset_type_ = std::move(dataset_type); } | |||||
| /// \brief set a new value to num_rows | /// \brief set a new value to num_rows | ||||
| inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; } | inline void set_num_rows(int32_t num_rows) { num_rows_ = num_rows; } | ||||
| /// \brief get the current num_rows | /// \brief get the current num_rows | ||||
| inline int32_t get_num_rows() { return num_rows_; } | |||||
| inline int32_t get_num_rows() const { return num_rows_; } | |||||
| Status FromJSONString(const std::string &json_string); | |||||
| private: | private: | ||||
| /// \brief Parse the columns and add it to columns | /// \brief Parse the columns and add it to columns | ||||
| /// \param[in] columns dataset attribution information, decoded from schema file. | /// \param[in] columns dataset attribution information, decoded from schema file. | ||||
| /// support both nlohmann::json::value_t::array and nlohmann::json::value_t::object. | /// support both nlohmann::json::value_t::array and nlohmann::json::value_t::object. | ||||
| /// \return JSON string of the schema | /// \return JSON string of the schema | ||||
| bool parse_column(nlohmann::json columns); | |||||
| Status parse_column(nlohmann::json columns); | |||||
| /// \brief Get schema file from json file | /// \brief Get schema file from json file | ||||
| /// \param[in] json_obj object of json parsed. | /// \param[in] json_obj object of json parsed. | ||||
| /// \return bool true if json dump success | /// \return bool true if json dump success | ||||
| bool from_json(nlohmann::json json_obj); | |||||
| Status from_json(nlohmann::json json_obj); | |||||
| int32_t num_rows_; | int32_t num_rows_; | ||||
| std::string dataset_type_; | std::string dataset_type_; | ||||
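With init(), add_column(), parse_column(), and from_json() now returning Status, schema errors can be propagated directly instead of being collapsed into bool; a hedged sketch using the new FromJSONString() entry point (the JSON content and row count are illustrative).

    Status LoadSchema(const std::shared_ptr<SchemaObj> &schema, const std::string &json_string) {
      RETURN_IF_NOT_OK(schema->FromJSONString(json_string));  // replaces the old bool-returning checks
      schema->set_num_rows(1000);
      return Status::OK();
    }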
| @@ -61,6 +61,7 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> { | |||||
| class DistributedSamplerObj; | class DistributedSamplerObj; | ||||
| class PKSamplerObj; | class PKSamplerObj; | ||||
| class PreBuiltSamplerObj; | |||||
| class RandomSamplerObj; | class RandomSamplerObj; | ||||
| class SequentialSamplerObj; | class SequentialSamplerObj; | ||||
| class SubsetRandomSamplerObj; | class SubsetRandomSamplerObj; | ||||
| @@ -171,6 +172,31 @@ class PKSamplerObj : public SamplerObj { | |||||
| int64_t num_samples_; | int64_t num_samples_; | ||||
| }; | }; | ||||
| class PreBuiltSamplerObj : public SamplerObj { | |||||
| public: | |||||
| #ifndef ENABLE_ANDROID | |||||
| explicit PreBuiltSamplerObj(std::shared_ptr<SamplerRT> sampler); | |||||
| explicit PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator> sampler); | |||||
| #endif | |||||
| ~PreBuiltSamplerObj() = default; | |||||
| std::shared_ptr<SamplerRT> Build() override; | |||||
| #ifndef ENABLE_ANDROID | |||||
| std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override; | |||||
| #endif | |||||
| bool ValidateParams() override; | |||||
| private: | |||||
| std::shared_ptr<SamplerRT> sp_; | |||||
| #ifndef ENABLE_ANDROID | |||||
| std::shared_ptr<mindrecord::ShardOperator> sp_minddataset_; | |||||
| #endif | |||||
| }; | |||||
| class RandomSamplerObj : public SamplerObj { | class RandomSamplerObj : public SamplerObj { | ||||
| public: | public: | ||||
| RandomSamplerObj(bool replacement, int64_t num_samples); | RandomSamplerObj(bool replacement, int64_t num_samples); | ||||
| @@ -70,6 +70,7 @@ namespace transforms { | |||||
| class ComposeOperation; | class ComposeOperation; | ||||
| class DuplicateOperation; | class DuplicateOperation; | ||||
| class OneHotOperation; | class OneHotOperation; | ||||
| class PreBuiltOperation; | |||||
| class RandomApplyOperation; | class RandomApplyOperation; | ||||
| class RandomChoiceOperation; | class RandomChoiceOperation; | ||||
| class TypeCastOperation; | class TypeCastOperation; | ||||
| @@ -164,6 +165,20 @@ class OneHotOperation : public TensorOperation { | |||||
| float num_classes_; | float num_classes_; | ||||
| }; | }; | ||||
| class PreBuiltOperation : public TensorOperation { | |||||
| public: | |||||
| explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op); | |||||
| ~PreBuiltOperation() = default; | |||||
| std::shared_ptr<TensorOp> Build() override; | |||||
| Status ValidateParams() override; | |||||
| private: | |||||
| std::shared_ptr<TensorOp> op_; | |||||
| }; | |||||
| class RandomApplyOperation : public TensorOperation { | class RandomApplyOperation : public TensorOperation { | ||||
| public: | public: | ||||
| explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob); | explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob); | ||||
| @@ -192,7 +207,6 @@ class RandomChoiceOperation : public TensorOperation { | |||||
| private: | private: | ||||
| std::vector<std::shared_ptr<TensorOperation>> transforms_; | std::vector<std::shared_ptr<TensorOperation>> transforms_; | ||||
| }; | }; | ||||
| class TypeCastOperation : public TensorOperation { | class TypeCastOperation : public TensorOperation { | ||||
| public: | public: | ||||
| explicit TypeCastOperation(std::string data_type); | explicit TypeCastOperation(std::string data_type); | ||||
| @@ -71,6 +71,15 @@ namespace dataset { | |||||
| return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ | return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \ | ||||
| } while (false) | } while (false) | ||||
| #define RETURN_SECOND_IF_ERROR(_s, _r) \ | |||||
| do { \ | |||||
| Status __rc = (_s); \ | |||||
| if (__rc.IsError()) { \ | |||||
| MS_LOG(ERROR) << __rc; \ | |||||
| return _r; \ | |||||
| } \ | |||||
| } while (false) | |||||
| enum class StatusCode : char { | enum class StatusCode : char { | ||||
| kOK = 0, | kOK = 0, | ||||
| kOutOfMemory = 1, | kOutOfMemory = 1, | ||||
| @@ -138,7 +138,9 @@ Status Task::Join(WaitFlag blocking) { | |||||
| while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { | while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { | ||||
| // We can't tell which conditional_variable this thread is waiting on. So we may need | // We can't tell which conditional_variable this thread is waiting on. So we may need | ||||
| // to interrupt everything one more time. | // to interrupt everything one more time. | ||||
| MS_LOG(INFO) << "Some threads not responding. Interrupt again"; | |||||
| std::stringstream ss; | |||||
| ss << get_id(); | |||||
| MS_LOG(ERROR) << MyName() << " Thread ID " << ss.str() << " is not responding. Interrupt again"; | |||||
| interrupt_svc->InterruptAll(); | interrupt_svc->InterruptAll(); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -21,7 +21,8 @@ import numpy | |||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers', | __all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers', | ||||
| 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load'] | |||||
| 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load', | |||||
| 'get_callback_timeout'] | |||||
| INT32_MAX = 2147483647 | INT32_MAX = 2147483647 | ||||
| UINT32_MAX = 4294967295 | UINT32_MAX = 4294967295 | ||||
| @@ -65,5 +65,7 @@ def mstypelist_to_detypelist(type_list): | |||||
| for index, _ in enumerate(type_list): | for index, _ in enumerate(type_list): | ||||
| if type_list[index] is not None: | if type_list[index] is not None: | ||||
| type_list[index] = mstype_to_detype(type_list[index]) | type_list[index] = mstype_to_detype(type_list[index]) | ||||
| else: | |||||
| type_list[index] = cde.DataType("") | |||||
| return type_list | return type_list | ||||
| @@ -15,17 +15,13 @@ | |||||
| """Built-in iterators. | """Built-in iterators. | ||||
| """ | """ | ||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| import copy | |||||
| import weakref | import weakref | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore._c_dataengine import DEPipeline | |||||
| from mindspore._c_dataengine import OpName | |||||
| import mindspore._c_dataengine as cde | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from . import datasets as de | |||||
| _ITERATOR_CLEANUP = False | _ITERATOR_CLEANUP = False | ||||
| @@ -57,29 +53,6 @@ def _cleanup(): | |||||
| itr.release() | itr.release() | ||||
| def alter_tree(node): | |||||
| """Traversing the Python dataset tree/graph to perform some alteration to some specific nodes.""" | |||||
| if not node.children: | |||||
| return _alter_node(node) | |||||
| converted_children = [] | |||||
| for input_op in node.children: | |||||
| converted_children.append(alter_tree(input_op)) | |||||
| node.children = converted_children | |||||
| return _alter_node(node) | |||||
| def _alter_node(node): | |||||
| """DEPRECATED""" | |||||
| # Please check ccsrc/dataset/engine/opt for tree transformation. | |||||
| if isinstance(node, de.MapDataset): | |||||
| if node.python_multiprocessing: | |||||
| # Bootstrap can only be performed on a copy of the original dataset node. | |||||
| # Bootstrap on original dataset node will make all iterators share the same process pool | |||||
| node.iterator_bootstrap() | |||||
| return node | |||||
| class Iterator: | class Iterator: | ||||
| """ | """ | ||||
| General Iterator over a dataset. | General Iterator over a dataset. | ||||
| @@ -89,185 +62,62 @@ class Iterator: | |||||
| """ | """ | ||||
| def __init__(self, dataset, num_epochs=-1, output_numpy=False): | def __init__(self, dataset, num_epochs=-1, output_numpy=False): | ||||
| self.num_epochs = num_epochs | |||||
| self.output_numpy = output_numpy | |||||
| ITERATORS_LIST.append(weakref.ref(self)) | |||||
| _unset_iterator_cleanup() | |||||
| self._col_names = None | |||||
| # create a copy of tree and work on it. | # create a copy of tree and work on it. | ||||
| self.dataset = copy.deepcopy(dataset) | |||||
| self.ori_dataset = dataset | self.ori_dataset = dataset | ||||
| self.parent_subtree = [] | |||||
| # The dataset passed into the iterator is not the root of the tree. | |||||
| # Trim the tree by saving the parent subtree into self.parent_subtree and | |||||
| # restore it after launching our C++ pipeline. | |||||
| if self.dataset.parent: | |||||
| logger.info("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.") | |||||
| self.parent_subtree = self.dataset.parent | |||||
| self.dataset.parent = [] | |||||
| self.dataset = alter_tree(self.dataset) | |||||
| if not self.__is_tree(): | |||||
| raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers).") | |||||
| self.depipeline = DEPipeline() | |||||
| # for manifest temporary use | |||||
| self.__batch_node(self.dataset, 0) | |||||
| root = self.__convert_node_postorder(self.dataset) | |||||
| self.depipeline.AssignRootNode(root) | |||||
| self.depipeline.PrepareTree(self.num_epochs) | |||||
| self.ir_tree, self.dataset = dataset.create_ir_tree() | |||||
| self._runtime_context = cde.PythonRuntimeContext() | |||||
| self._runtime_context.Init() | |||||
| consumer = cde.PythonIteratorConsumer(num_epochs) | |||||
| consumer.Init(self.ir_tree) | |||||
| self._runtime_context.AssignConsumer(consumer) | |||||
| self._iterator = self._runtime_context.GetConsumer() | |||||
| self._transform_tensor = lambda t: t.as_array() | |||||
| if not output_numpy: | |||||
| self._transform_tensor = lambda t: Tensor(t.as_array()) | |||||
| self._index = 0 | self._index = 0 | ||||
| # todo remove next when ContextManager is done | |||||
| ITERATORS_LIST.append(weakref.ref(self)) | |||||
| _unset_iterator_cleanup() | |||||
| ####### | |||||
| def __iter__(self): | |||||
| return self | |||||
| def stop(self): | def stop(self): | ||||
| """ | """ | ||||
| Manually terminate Python iterator instead of relying on out of scope destruction. | Manually terminate Python iterator instead of relying on out of scope destruction. | ||||
| """ | """ | ||||
| logger.info("Terminating Python iterator. This will also terminate C++ pipeline.") | logger.info("Terminating Python iterator. This will also terminate C++ pipeline.") | ||||
| if hasattr(self, 'depipeline') and self.depipeline: | |||||
| del self.depipeline | |||||
| def __is_tree_node(self, node): | |||||
| """Check if a node is tree node.""" | |||||
| if not node.children: | |||||
| if len(node.parent) > 1: | |||||
| return False | |||||
| if len(node.parent) > 1: | |||||
| return False | |||||
| for input_node in node.children: | |||||
| cls = self.__is_tree_node(input_node) | |||||
| if not cls: | |||||
| return False | |||||
| return True | |||||
| def __is_tree(self): | |||||
| return self.__is_tree_node(self.dataset) | |||||
| @staticmethod | |||||
| def __get_dataset_type(dataset): | |||||
| """Get the dataset type.""" | |||||
| op_type = None | |||||
| if isinstance(dataset, de.ShuffleDataset): | |||||
| op_type = OpName.SHUFFLE | |||||
| elif isinstance(dataset, de.MindDataset): | |||||
| op_type = OpName.MINDRECORD | |||||
| elif isinstance(dataset, de.BatchDataset): | |||||
| op_type = OpName.BATCH | |||||
| elif isinstance(dataset, de.BucketBatchByLengthDataset): | |||||
| op_type = OpName.BUCKETBATCH | |||||
| elif isinstance(dataset, de.SyncWaitDataset): | |||||
| op_type = OpName.BARRIER | |||||
| elif isinstance(dataset, de.ZipDataset): | |||||
| op_type = OpName.ZIP | |||||
| elif isinstance(dataset, de.ConcatDataset): | |||||
| op_type = OpName.CONCAT | |||||
| elif isinstance(dataset, de.MapDataset): | |||||
| op_type = OpName.MAP | |||||
| elif isinstance(dataset, de.FilterDataset): | |||||
| op_type = OpName.FILTER | |||||
| elif isinstance(dataset, de.RepeatDataset): | |||||
| op_type = OpName.REPEAT | |||||
| elif isinstance(dataset, de.SkipDataset): | |||||
| op_type = OpName.SKIP | |||||
| elif isinstance(dataset, de.TakeDataset): | |||||
| op_type = OpName.TAKE | |||||
| elif isinstance(dataset, de.ImageFolderDataset): | |||||
| op_type = OpName.IMAGEFOLDER | |||||
| elif isinstance(dataset, de.GeneratorDataset): | |||||
| op_type = OpName.GENERATOR | |||||
| elif isinstance(dataset, de.TransferDataset): | |||||
| op_type = OpName.DEVICEQUEUE | |||||
| elif isinstance(dataset, de.RenameDataset): | |||||
| op_type = OpName.RENAME | |||||
| elif isinstance(dataset, de.TFRecordDataset): | |||||
| op_type = OpName.TFREADER | |||||
| elif isinstance(dataset, de.ProjectDataset): | |||||
| op_type = OpName.PROJECT | |||||
| elif isinstance(dataset, de.MnistDataset): | |||||
| op_type = OpName.MNIST | |||||
| elif isinstance(dataset, de.ManifestDataset): | |||||
| op_type = OpName.MANIFEST | |||||
| elif isinstance(dataset, de.VOCDataset): | |||||
| op_type = OpName.VOC | |||||
| elif isinstance(dataset, de.CocoDataset): | |||||
| op_type = OpName.COCO | |||||
| elif isinstance(dataset, de.Cifar10Dataset): | |||||
| op_type = OpName.CIFAR10 | |||||
| elif isinstance(dataset, de.Cifar100Dataset): | |||||
| op_type = OpName.CIFAR100 | |||||
| elif isinstance(dataset, de.CelebADataset): | |||||
| op_type = OpName.CELEBA | |||||
| elif isinstance(dataset, de.RandomDataset): | |||||
| op_type = OpName.RANDOMDATA | |||||
| elif isinstance(dataset, de.TextFileDataset): | |||||
| op_type = OpName.TEXTFILE | |||||
| elif isinstance(dataset, de.BuildVocabDataset): | |||||
| op_type = OpName.BUILDVOCAB | |||||
| elif isinstance(dataset, de.BuildSentencePieceVocabDataset): | |||||
| op_type = OpName.SENTENCEPIECEVOCAB | |||||
| elif isinstance(dataset, de.CLUEDataset): | |||||
| op_type = OpName.CLUE | |||||
| elif isinstance(dataset, de.CSVDataset): | |||||
| op_type = OpName.CSV | |||||
| else: | |||||
| raise ValueError("Unsupported DatasetOp.") | |||||
| return op_type | |||||
| # Convert Python node into C node and add to C layer execution tree in postorder traversal. | |||||
| def __convert_node_postorder(self, node): | |||||
| self.check_node_type(node) | |||||
| op_type = self.__get_dataset_type(node) | |||||
| c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args()) | |||||
| for py_child in node.children: | |||||
| c_child = self.__convert_node_postorder(py_child) | |||||
| self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"]) | |||||
| return c_nodes["top"] | |||||
| def __batch_node(self, dataset, level): | |||||
| """Recursively get batch node in the dataset tree.""" | |||||
| if isinstance(dataset, de.BatchDataset): | |||||
| return | |||||
| for input_op in dataset.children: | |||||
| self.__batch_node(input_op, level + 1) | |||||
| @staticmethod | |||||
| def __print_local(dataset, level): | |||||
| """Recursively print the name and address of nodes in the dataset tree.""" | |||||
| name = dataset.__class__.__name__ | |||||
| ptr = hex(id(dataset)) | |||||
| for _ in range(level): | |||||
| logger.info("\t", end='') | |||||
| if not dataset.children: | |||||
| logger.info("-%s (%s)", name, ptr) | |||||
| else: | |||||
| logger.info("+%s (%s)", name, ptr) | |||||
| for input_op in dataset.children: | |||||
| Iterator.__print_local(input_op, level + 1) | |||||
| def print(self): | |||||
| """Print the dataset tree""" | |||||
| self.__print_local(self.dataset, 0) | |||||
| if hasattr(self, '_runtime_context') and self._runtime_context: | |||||
| if hasattr(self, '_iterator') and self._iterator: | |||||
| self._runtime_context.Terminate() | |||||
| del self._iterator | |||||
| del self._runtime_context | |||||
| del self.dataset | |||||
| def release(self): | def release(self): | ||||
| if hasattr(self, 'depipeline') and self.depipeline: | |||||
| del self.depipeline | |||||
| self.stop() | |||||
| def __del__(self): | |||||
| self.release() | |||||
| @abstractmethod | @abstractmethod | ||||
| def get_next(self): | |||||
| def _get_next(self): | |||||
| raise RuntimeError("Calling base class Iterator's get_next is invalid.") | raise RuntimeError("Calling base class Iterator's get_next is invalid.") | ||||
| def __next__(self): | def __next__(self): | ||||
| if not self.depipeline: | |||||
| if not self._runtime_context: | |||||
| logger.warning("Iterator does not have a running C++ pipeline." + | logger.warning("Iterator does not have a running C++ pipeline." + | ||||
| "It might because Iterator stop() had been called, or C++ pipeline crashed silently.") | "It might because Iterator stop() had been called, or C++ pipeline crashed silently.") | ||||
| raise RuntimeError("Iterator does not have a running C++ pipeline.") | raise RuntimeError("Iterator does not have a running C++ pipeline.") | ||||
| data = self.get_next() | |||||
| data = self._get_next() | |||||
| if not data: | if not data: | ||||
| if self._index == 0: | if self._index == 0: | ||||
| logger.warning("No records available.") | logger.warning("No records available.") | ||||
| @@ -277,100 +127,56 @@ class Iterator: | |||||
| self._index += 1 | self._index += 1 | ||||
| return data | return data | ||||
| @abstractmethod | |||||
| def check_node_type(self, node): | |||||
| pass | |||||
| def get_output_shapes(self): | |||||
| return [t for t in self.depipeline.GetOutputShapes()] | |||||
| def get_output_types(self): | |||||
| return [t for t in self.depipeline.GetOutputTypes()] | |||||
| def get_dataset_size(self): | |||||
| return self.depipeline.GetDatasetSize() | |||||
| def get_batch_size(self): | |||||
| return self.depipeline.GetBatchSize() | |||||
| def get_repeat_count(self): | |||||
| return self.depipeline.GetRepeatCount() | |||||
| def num_classes(self): | |||||
| return self.depipeline.GetNumClasses() | |||||
| def get_col_names(self): | |||||
| return self.depipeline.GetColumnNames() | |||||
| def __deepcopy__(self, memo): | def __deepcopy__(self, memo): | ||||
| return self | return self | ||||
| def _getters(self): | |||||
| """ | |||||
| Get pipeline information. | |||||
| """ | |||||
| getter = cde.TreeGetters() | |||||
| getter.Init(self.ir_tree) | |||||
| self._runtime_context.AssignConsumer(getter) | |||||
| self._col_names = getter.GetColumnNames() | |||||
| class SaveOp(Iterator): | |||||
| """ | |||||
| The derived class of Iterator with dict type. | |||||
| """ | |||||
| def __init__(self, dataset, num_epochs=-1): | |||||
| super().__init__(dataset, num_epochs) | |||||
| self.depipeline.LaunchTreeExec() | |||||
| def get_next(self): | |||||
| pass | |||||
| def check_node_type(self, node): | |||||
| if isinstance(node, (de.ShuffleDataset, de.RepeatDataset, de.BatchDataset)): | |||||
| logger.warning("Used shuffle, repeat, batch before save operator.") | |||||
| def save(self, file_names, file_type): | |||||
| return self.depipeline.SaveDataset(file_names, file_type) | |||||
| def get_col_names(self): | |||||
| """ | |||||
| Get names of the columns in the dataset | |||||
| """ | |||||
| if self._col_names is None: | |||||
| self._getters() | |||||
| return self._col_names | |||||
| class DictIterator(Iterator): | class DictIterator(Iterator): | ||||
| """ | """ | ||||
| The derived class of Iterator with dict type. | The derived class of Iterator with dict type. | ||||
| """ | """ | ||||
| def __init__(self, dataset, num_epochs=-1, output_numpy=False): | |||||
| super().__init__(dataset, num_epochs, output_numpy) | |||||
| self.depipeline.LaunchTreeExec() | |||||
| def check_node_type(self, node): | |||||
| pass | |||||
| def __iter__(self): | |||||
| return self | |||||
| def get_next(self): | |||||
| def _get_next(self): | |||||
| """ | """ | ||||
| Returns the next record in the dataset as a dictionary | Returns the next record in the dataset as a dictionary | ||||
| Returns: | Returns: | ||||
| Dict, the next record in the dataset. | Dict, the next record in the dataset. | ||||
| """ | """ | ||||
| if self.output_numpy: | |||||
| return {k: v.as_array() for k, v in self.depipeline.GetNextAsMap().items()} | |||||
| return {k: Tensor(v.as_array()) for k, v in self.depipeline.GetNextAsMap().items()} | |||||
| return {k: self._transform_tensor(t) for k, t in self._iterator.GetNextAsMap().items()} | |||||
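For orientation, here is a minimal usage sketch of the dict-style iterator over the new IR-backed pipeline; the MNIST directory below is hypothetical and any supported dataset source behaves the same way.

```python
import mindspore.dataset as ds

# Hypothetical data directory.
data = ds.MnistDataset("/path/to/mnist", num_samples=4)

# create_dict_iterator() builds a DictIterator, which drives the C++ pipeline through
# PythonRuntimeContext/PythonIteratorConsumer and yields one dict per row.
for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(row["image"].shape, row["label"])
```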
| class TupleIterator(Iterator): | class TupleIterator(Iterator): | ||||
| """ | """ | ||||
| The derived class of Iterator with list type. | The derived class of Iterator with list type. | ||||
| """ | """ | ||||
| def check_node_type(self, node): | |||||
| pass | |||||
| def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False): | def __init__(self, dataset, columns=None, num_epochs=-1, output_numpy=False): | ||||
| if columns is not None: | if columns is not None: | ||||
| if not isinstance(columns, list): | if not isinstance(columns, list): | ||||
| columns = [columns] | columns = [columns] | ||||
| # todo: move next to IR | |||||
| dataset = dataset.project(columns) | dataset = dataset.project(columns) | ||||
| super().__init__(dataset, num_epochs, output_numpy) | super().__init__(dataset, num_epochs, output_numpy) | ||||
| self.depipeline.LaunchTreeExec() | |||||
| def __iter__(self): | |||||
| return self | |||||
| def get_next(self): | |||||
| def _get_next(self): | |||||
| """ | """ | ||||
| Returns the next record in the dataset as a list | Returns the next record in the dataset as a list | ||||
| @@ -378,15 +184,14 @@ class TupleIterator(Iterator): | |||||
| List, the next record in the dataset. | List, the next record in the dataset. | ||||
| """ | """ | ||||
| if self.output_numpy: | |||||
| return [t.as_array() for t in self.depipeline.GetNextAsList()] | |||||
| return [Tensor(t.as_array()) for t in self.depipeline.GetNextAsList()] | |||||
| return [self._transform_tensor(t) for t in self._iterator.GetNextAsList()] | |||||
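The tuple-style iterator is analogous; `columns` is applied as a project before the consumer is created. A sketch under the same hypothetical data path:

```python
import mindspore.dataset as ds

data = ds.MnistDataset("/path/to/mnist", num_samples=4)  # hypothetical path

# Rows come back as lists in the projected column order.
for image, label in data.create_tuple_iterator(columns=["image", "label"],
                                               num_epochs=1, output_numpy=True):
    print(image.shape, label)
```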
| class DummyIterator: | class DummyIterator: | ||||
| """ | """ | ||||
| A DummyIterator only works when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED" | A DummyIterator only works when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED" | ||||
| """ | """ | ||||
| def __init__(self, dataset, mode): | def __init__(self, dataset, mode): | ||||
| self.mode = mode | self.mode = mode | ||||
| self.shapes = dataset.output_shapes() | self.shapes = dataset.output_shapes() | ||||
| @@ -283,9 +283,12 @@ def create_node(node): | |||||
| node.get('shard_id'), sampler) | node.get('shard_id'), sampler) | ||||
| elif dataset_op == 'TFRecordDataset': | elif dataset_op == 'TFRecordDataset': | ||||
| shuffle = node.get('shuffle') | |||||
| if shuffle is not None and isinstance(shuffle, str): | |||||
| shuffle = de.Shuffle(shuffle) | |||||
| pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'), | pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('column_list'), | ||||
| node.get('num_samples'), node.get('num_parallel_workers'), | node.get('num_samples'), node.get('num_parallel_workers'), | ||||
| de.Shuffle(node.get('shuffle')), node.get('num_shards'), node.get('shard_id')) | |||||
| shuffle, node.get('num_shards'), node.get('shard_id')) | |||||
| elif dataset_op == 'ManifestDataset': | elif dataset_op == 'ManifestDataset': | ||||
| sampler = construct_sampler(node.get('sampler')) | sampler = construct_sampler(node.get('sampler')) | ||||
| @@ -293,14 +293,38 @@ def check_save(method): | |||||
| return new_method | return new_method | ||||
| def check_iterator(method): | |||||
| def check_tuple_iterator(method): | |||||
| """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" | """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" | ||||
| @wraps(method) | @wraps(method) | ||||
| def new_method(self, *args, **kwargs): | def new_method(self, *args, **kwargs): | ||||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||||
| [columns, num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs) | |||||
| nreq_param_bool = ['output_numpy'] | nreq_param_bool = ['output_numpy'] | ||||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | validate_dataset_param_value(nreq_param_bool, param_dict, bool) | ||||
| if num_epochs is not None: | |||||
| type_check(num_epochs, (int,), "num_epochs") | |||||
| check_value(num_epochs, [-1, INT32_MAX], "num_epochs") | |||||
| if columns is not None: | |||||
| check_columns(columns, "column_names") | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | |||||
| def check_dict_iterator(method): | |||||
| """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| [num_epochs, _], param_dict = parse_user_args(method, *args, **kwargs) | |||||
| nreq_param_bool = ['output_numpy'] | |||||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | |||||
| if num_epochs is not None: | |||||
| type_check(num_epochs, (int,), "num_epochs") | |||||
| check_value(num_epochs, [-1, INT32_MAX], "num_epochs") | |||||
| return method(self, *args, **kwargs) | return method(self, *args, **kwargs) | ||||
| return new_method | return new_method | ||||
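With the checker split as above, invalid iterator arguments are rejected before the pipeline is launched. A hedged sketch of what the validation is expected to catch (the exact messages come from type_check/check_value, and the data path is hypothetical):

```python
import mindspore.dataset as ds

data = ds.MnistDataset("/path/to/mnist")  # hypothetical path

# num_epochs must be an int in [-1, INT32_MAX]; a string should fail type_check.
try:
    data.create_dict_iterator(num_epochs="1")
except TypeError as err:
    print(err)
```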
| @@ -523,6 +547,8 @@ def check_batch(method): | |||||
| sig = ins.signature(batch_size) | sig = ins.signature(batch_size) | ||||
| if len(sig.parameters) != 1: | if len(sig.parameters) != 1: | ||||
| raise ValueError("callable batch_size should take one parameter (BatchInfo).") | raise ValueError("callable batch_size should take one parameter (BatchInfo).") | ||||
| else: | |||||
| check_pos_int32(int(batch_size), "batch_size") | |||||
| if num_parallel_workers is not None: | if num_parallel_workers is not None: | ||||
| check_num_parallel_workers(num_parallel_workers) | check_num_parallel_workers(num_parallel_workers) | ||||
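check_batch now also range-checks a plain integer batch_size (see test_batch_exception_15 below), while still accepting a callable that takes a single BatchInfo. A brief sketch of both forms, using a hypothetical data path:

```python
import mindspore.dataset as ds

# Fixed batch size: must be a positive int32.
data1 = ds.MnistDataset("/path/to/mnist").batch(batch_size=2, drop_remainder=True)

# Callable form: the size of each batch may vary, driven by the BatchInfo argument.
data2 = ds.MnistDataset("/path/to/mnist").batch(batch_size=lambda info: info.get_batch_num() + 1)
```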
| @@ -807,6 +833,21 @@ def check_project(method): | |||||
| return new_method | return new_method | ||||
| def check_schema(method): | |||||
| """check the input arguments of Schema.__init__.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| [schema_file], _ = parse_user_args(method, *args, **kwargs) | |||||
| if schema_file is not None: | |||||
| type_check(schema_file, (str,), "schema_file") | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | |||||
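check_schema only guards the optional schema_file argument of Schema.__init__; columns can still be added programmatically. A minimal sketch (the JSON path is hypothetical):

```python
import mindspore.dataset as ds
import mindspore.common.dtype as mstype

schema = ds.Schema()  # no schema_file: define the columns in code
schema.add_column("image", de_type=mstype.uint8, shape=[-1])  # variable-length column
schema.add_column("label", de_type=mstype.int32)

# Alternatively, load a schema file; check_schema requires the path to be a str.
# schema = ds.Schema("/path/to/datasetSchema.json")
```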
| def check_add_column(method): | def check_add_column(method): | ||||
| """check the input arguments of add_column.""" | """check the input arguments of add_column.""" | ||||
| @@ -1261,3 +1302,23 @@ def check_cache_option(cache): | |||||
| """Sanity check for cache parameter""" | """Sanity check for cache parameter""" | ||||
| if cache is not None: | if cache is not None: | ||||
| type_check(cache, (cache_client.DatasetCache,), "cache") | type_check(cache, (cache_client.DatasetCache,), "cache") | ||||
| def check_to_device_send(method): | |||||
| """A wrapper that wraps a parameter checker around the check_to_device_send.""" | |||||
| @wraps(method) | |||||
| def new_method(self, *args, **kwargs): | |||||
| [num_epochs], _ = parse_user_args(method, *args, **kwargs) | |||||
| if num_epochs is not None: | |||||
| type_check(num_epochs, (int,), "num_epochs") | |||||
| check_value(num_epochs, [-1, INT32_MAX], "num_epochs") | |||||
| return method(self, *args, **kwargs) | |||||
| return new_method | |||||
| def replace_none(value, default): | |||||
| return value if value is not None else default | |||||
| @@ -18,13 +18,13 @@ use to_bytes and to_str to encode and decode strings into a specified format. | |||||
| """ | """ | ||||
| from enum import IntEnum | from enum import IntEnum | ||||
| import copy | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore._c_dataengine as cde | import mindspore._c_dataengine as cde | ||||
| from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \ | from .validators import check_from_file, check_from_list, check_from_dict, check_from_dataset, \ | ||||
| check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model | check_from_dataset_sentencepiece, check_from_file_sentencepiece, check_save_model | ||||
| __all__ = [ | __all__ = [ | ||||
| "Vocab", "SentencePieceVocab", "to_str", "to_bytes" | "Vocab", "SentencePieceVocab", "to_str", "to_bytes" | ||||
| ] | ] | ||||
| @@ -39,8 +39,7 @@ class Vocab(cde.Vocab): | |||||
| @classmethod | @classmethod | ||||
| @check_from_dataset | @check_from_dataset | ||||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, | |||||
| special_first=True): | |||||
| def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, special_first=True): | |||||
| """ | """ | ||||
| Build a vocab from a dataset. | Build a vocab from a dataset. | ||||
| @@ -69,21 +68,7 @@ class Vocab(cde.Vocab): | |||||
| Returns: | Returns: | ||||
| Vocab, Vocab object built from dataset. | Vocab, Vocab object built from dataset. | ||||
| """ | """ | ||||
| vocab = Vocab() | |||||
| if columns is None: | |||||
| columns = [] | |||||
| if not isinstance(columns, list): | |||||
| columns = [columns] | |||||
| if freq_range is None: | |||||
| freq_range = (None, None) | |||||
| if special_tokens is None: | |||||
| special_tokens = [] | |||||
| root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) | |||||
| for d in root.create_dict_iterator(num_epochs=1): | |||||
| if d is not None: | |||||
| raise ValueError("from_dataset should receive data other than None.") | |||||
| return vocab | |||||
| return dataset.build_vocab(columns, freq_range, top_k, special_tokens, special_first) | |||||
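With the change above, Vocab.from_dataset is a thin wrapper over dataset.build_vocab instead of driving a copied pipeline itself; usage from Python stays the same. A sketch with a hypothetical corpus file:

```python
import mindspore.dataset as ds
import mindspore.dataset.text as text

data = ds.TextFileDataset("/path/to/corpus.txt", shuffle=False)  # hypothetical path
data = data.map(text.WhitespaceTokenizer(), input_columns=["text"])

vocab = text.Vocab.from_dataset(data, columns=["text"], freq_range=None, top_k=None,
                                special_tokens=["<pad>", "<unk>"], special_first=True)
```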
| @classmethod | @classmethod | ||||
| @check_from_list | @check_from_list | ||||
| @@ -143,6 +128,7 @@ class SentencePieceVocab(cde.SentencePieceVocab): | |||||
| """ | """ | ||||
| SentencePiece object that is used to segment words | SentencePiece object that is used to segment words | ||||
| """ | """ | ||||
| @classmethod | @classmethod | ||||
| @check_from_dataset_sentencepiece | @check_from_dataset_sentencepiece | ||||
| def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params): | def from_dataset(cls, dataset, col_names, vocab_size, character_coverage, model_type, params): | ||||
| @@ -164,13 +150,8 @@ class SentencePieceVocab(cde.SentencePieceVocab): | |||||
| SentencePiece, SentencePiece object from dataset. | SentencePiece, SentencePiece object from dataset. | ||||
| """ | """ | ||||
| vocab = SentencePieceVocab() | |||||
| root = copy.deepcopy(dataset).build_sentencepiece_vocab(vocab, col_names, vocab_size, character_coverage, | |||||
| model_type, params) | |||||
| for d in root.create_dict_iterator(num_epochs=1): | |||||
| if d is None: | |||||
| raise ValueError("from_dataset should receive data other than None.") | |||||
| return vocab | |||||
| return dataset.build_sentencepiece_vocab(col_names, vocab_size, character_coverage, | |||||
| DE_C_INTER_SENTENCEPIECE_MODE[model_type], params) | |||||
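Similarly, SentencePieceVocab.from_dataset now translates model_type through DE_C_INTER_SENTENCEPIECE_MODE and delegates to dataset.build_sentencepiece_vocab. A sketch, again with a hypothetical corpus file:

```python
import mindspore.dataset as ds
from mindspore.dataset.text import SentencePieceVocab, SentencePieceModel

data = ds.TextFileDataset("/path/to/corpus.txt", shuffle=False)  # hypothetical path
vocab = SentencePieceVocab.from_dataset(data, ["text"], vocab_size=100,
                                        character_coverage=0.9995,
                                        model_type=SentencePieceModel.UNIGRAM, params={})
```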
| @classmethod | @classmethod | ||||
| @check_from_file_sentencepiece | @check_from_file_sentencepiece | ||||
| @@ -270,6 +251,7 @@ class SentencePieceModel(IntEnum): | |||||
| CHAR = 2 | CHAR = 2 | ||||
| WORD = 3 | WORD = 3 | ||||
| DE_C_INTER_SENTENCEPIECE_MODE = { | DE_C_INTER_SENTENCEPIECE_MODE = { | ||||
| SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM, | SentencePieceModel.UNIGRAM: cde.SentencePieceModel.DE_SENTENCE_PIECE_UNIGRAM, | ||||
| SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE, | SentencePieceModel.BPE: cde.SentencePieceModel.DE_SENTENCE_PIECE_BPE, | ||||
| @@ -432,7 +432,7 @@ def check_from_dataset_sentencepiece(method): | |||||
| [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs) | [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs) | ||||
| if col_names is not None: | if col_names is not None: | ||||
| type_check(col_names, (list,), "col_names") | |||||
| type_check_list(col_names, (str,), "col_names") | |||||
| if vocab_size is not None: | if vocab_size is not None: | ||||
| check_uint32(vocab_size, "vocab_size") | check_uint32(vocab_size, "vocab_size") | ||||
| @@ -146,6 +146,7 @@ if (BUILD_MINDDATA STREQUAL "full") | |||||
| list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_IR_CACHE_SRC_FILES | ||||
| "${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc" | "${MINDDATA_DIR}/engine/ir/cache/dataset_cache_impl.cc" | ||||
| "${MINDDATA_DIR}/engine/ir/cache/pre_built_dataset_cache.cc" | |||||
| ) | ) | ||||
| list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | list(REMOVE_ITEM MINDDATA_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | ||||
| @@ -123,6 +123,7 @@ def connect_network_with_dataset(network, dataset_helper): | |||||
| network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name) | ||||
| return network | return network | ||||
| class DatasetHelper: | class DatasetHelper: | ||||
| """ | """ | ||||
| DatasetHelper is a class to process the MindData dataset and it provides the information of dataset. | DatasetHelper is a class to process the MindData dataset and it provides the information of dataset. | ||||
| @@ -197,7 +198,6 @@ class DatasetHelper: | |||||
| def get_data_info(self): | def get_data_info(self): | ||||
| return self.iter.get_data_info() | return self.iter.get_data_info() | ||||
| class _DatasetIter: | class _DatasetIter: | ||||
| """Base iter for dataset helper""" | """Base iter for dataset helper""" | ||||
| @@ -331,7 +331,6 @@ class _DatasetIterPSLite(_DatasetIter): | |||||
| class _DatasetIterNormal: | class _DatasetIterNormal: | ||||
| """Iter for normal(non sink) mode, feed the data from host.""" | """Iter for normal(non sink) mode, feed the data from host.""" | ||||
| def __init__(self, dataset, epoch_num=-1): | def __init__(self, dataset, epoch_num=-1): | ||||
| self.dataset = dataset | self.dataset = dataset | ||||
| self.device_num = _get_device_num() | self.device_num = _get_device_num() | ||||
| @@ -61,15 +61,15 @@ class MindData: | |||||
| def send(self, num_epochs=-1): | def send(self, num_epochs=-1): | ||||
| pass | pass | ||||
| def get_data_info(self): | |||||
| pass | |||||
| def stop_send(self): | def stop_send(self): | ||||
| pass | pass | ||||
| def continue_send(self): | def continue_send(self): | ||||
| pass | pass | ||||
| def get_data_info(self): | |||||
| pass | |||||
| def __len__(self): | def __len__(self): | ||||
| return self._size | return self._size | ||||
| @@ -177,8 +177,8 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) { | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | ||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| } | } | ||||
| // 5 batches of size 2 | |||||
| EXPECT_EQ(i, 5); | |||||
| // With 2 boundaries, 3 buckets are created | |||||
| EXPECT_EQ(i, 3); | |||||
| // Manually terminate the pipeline | // Manually terminate the pipeline | ||||
| iter->Stop(); | iter->Stop(); | ||||
| @@ -132,6 +132,6 @@ TEST_F(MindDataTestOptimizationPass, MindDataTestDatasetSizePass) { | |||||
| // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not | // verify that Shuffle and RepeatOp are removed, but Batch and ProjectOp are not | ||||
| EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); | EXPECT_EQ(ss_str.find("ShuffleOp"), ss_str.npos); | ||||
| EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); | EXPECT_NE(ss_str.find("RepeatOp"), ss_str.npos); | ||||
| EXPECT_EQ(ss_str.find("ProjectOp"), ss_str.npos); | |||||
| EXPECT_NE(ss_str.find("ProjectOp"), ss_str.npos); | |||||
| EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | EXPECT_NE(ss_str.find("BatchOp"), ss_str.npos); | ||||
| } | } | ||||
| @@ -63,7 +63,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) { | |||||
| const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}}; | const std::unordered_map<std::string, int32_t> map = {{"label", 1}, {"image", 0}}; | ||||
| EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | ||||
| std::vector<size_t> row_sizes = {2, 2, 0, 0}; | |||||
| std::vector<size_t> row_sizes = {2, 2, 0}; | |||||
| TensorRow row; | TensorRow row; | ||||
| for (size_t sz : row_sizes) { | for (size_t sz : row_sizes) { | ||||
| @@ -75,7 +75,7 @@ TEST_F(MindDataTestTreeAdapter, TestSimpleTreeAdapter) { | |||||
| rc = tree_adapter.GetNext(&row); | rc = tree_adapter.GetNext(&row); | ||||
| EXPECT_TRUE(rc.IsError()); | EXPECT_TRUE(rc.IsError()); | ||||
| const std::string err_msg = rc.ToString(); | const std::string err_msg = rc.ToString(); | ||||
| EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); | |||||
| EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos); | |||||
| } | } | ||||
| TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | ||||
| @@ -97,7 +97,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | |||||
| const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap(); | const std::unordered_map<std::string, int32_t> map = tree_adapter.GetColumnNameMap(); | ||||
| EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | ||||
| std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0, 0}; | |||||
| std::vector<size_t> row_sizes = {2, 2, 0, 2, 2, 0}; | |||||
| TensorRow row; | TensorRow row; | ||||
| for (size_t sz : row_sizes) { | for (size_t sz : row_sizes) { | ||||
| @@ -107,7 +107,7 @@ TEST_F(MindDataTestTreeAdapter, TestTreeAdapterWithRepeat) { | |||||
| } | } | ||||
| rc = tree_adapter.GetNext(&row); | rc = tree_adapter.GetNext(&row); | ||||
| const std::string err_msg = rc.ToString(); | const std::string err_msg = rc.ToString(); | ||||
| EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); | |||||
| EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos); | |||||
| } | } | ||||
| TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | ||||
| @@ -135,7 +135,7 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | |||||
| const std::unordered_map<std::string, int32_t> map = {{"label", 0}}; | const std::unordered_map<std::string, int32_t> map = {{"label", 0}}; | ||||
| EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | EXPECT_EQ(tree_adapter.GetColumnNameMap(), map); | ||||
| std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0, 0}; | |||||
| std::vector<size_t> row_sizes = {1, 1, 0, 1, 1, 0}; | |||||
| TensorRow row; | TensorRow row; | ||||
| for (size_t sz : row_sizes) { | for (size_t sz : row_sizes) { | ||||
| @@ -145,5 +145,5 @@ TEST_F(MindDataTestTreeAdapter, TestProjectMapTreeAdapter) { | |||||
| } | } | ||||
| rc = tree_adapter.GetNext(&row); | rc = tree_adapter.GetNext(&row); | ||||
| const std::string err_msg = rc.ToString(); | const std::string err_msg = rc.ToString(); | ||||
| EXPECT_TRUE(err_msg.find("EOF has already been reached") != err_msg.npos); | |||||
| EXPECT_TRUE(err_msg.find("EOF buffer encountered.") != err_msg.npos); | |||||
| } | } | ||||
| @@ -451,6 +451,10 @@ def test_batch_exception_13(): | |||||
| def test_batch_exception_14(): | def test_batch_exception_14(): | ||||
| """ | |||||
| Test per_batch_map and input column name | |||||
| """ | |||||
| logger.info("test_batch_exception_14") | |||||
| batch_size = 2 | batch_size = 2 | ||||
| input_columns = ["num"] | input_columns = ["num"] | ||||
| data1 = ds.TFRecordDataset(DATA_DIR) | data1 = ds.TFRecordDataset(DATA_DIR) | ||||
| @@ -460,6 +464,22 @@ def test_batch_exception_14(): | |||||
| assert "per_batch_map and input_columns need to be passed in together." in str(e) | assert "per_batch_map and input_columns need to be passed in together." in str(e) | ||||
| def test_batch_exception_15(): | |||||
| """ | |||||
| Test batch_size = int32 max value + 1 | |||||
| """ | |||||
| logger.info("test_batch_exception_15") | |||||
| batch_size = 2147483647 + 1 | |||||
| input_columns = ["num"] | |||||
| data1 = ds.TFRecordDataset(DATA_DIR) | |||||
| err_msg = "" | |||||
| try: | |||||
| _ = data1.batch(batch_size=batch_size, input_columns=input_columns) | |||||
| except ValueError as e: | |||||
| err_msg = str(e) | |||||
| assert "batch_size is not within the required interval of (1 to 2147483647)" in err_msg | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_batch_01() | test_batch_01() | ||||
| test_batch_02() | test_batch_02() | ||||
| @@ -486,4 +506,5 @@ if __name__ == '__main__': | |||||
| test_batch_exception_12() | test_batch_exception_12() | ||||
| test_batch_exception_13() | test_batch_exception_13() | ||||
| test_batch_exception_14() | test_batch_exception_14() | ||||
| test_batch_exception_15() | |||||
| logger.info('\n') | logger.info('\n') | ||||
| @@ -12,7 +12,8 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| import os | |||||
| import pytest | |||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| @@ -354,6 +355,18 @@ def test_clue_to_device(): | |||||
| data.send() | data.send() | ||||
| def test_clue_invalid_files(): | |||||
| """ | |||||
| Test CLUE with invalid files | |||||
| """ | |||||
| AFQMC_DIR = '../data/dataset/testCLUE/afqmc' | |||||
| afqmc_train_json = os.path.join(AFQMC_DIR) | |||||
| with pytest.raises(ValueError) as info: | |||||
| _ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False) | |||||
| assert "The following patterns did not match any files" in str(info.value) | |||||
| assert AFQMC_DIR in str(info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_clue() | test_clue() | ||||
| test_clue_num_shards() | test_clue_num_shards() | ||||
| @@ -366,3 +379,4 @@ if __name__ == "__main__": | |||||
| test_clue_tnews() | test_clue_tnews() | ||||
| test_clue_wsc() | test_clue_wsc() | ||||
| test_clue_to_device() | test_clue_to_device() | ||||
| test_clue_invalid_files() | |||||
| @@ -195,30 +195,42 @@ def test_csv_dataset_size(): | |||||
| assert data.get_dataset_size() == 5 | assert data.get_dataset_size() == 5 | ||||
| def test_csv_dataset_exception(): | |||||
| def test_csv_dataset_type_error(): | |||||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | TEST_FILE = '../data/dataset/testCSV/exception.csv' | ||||
| data = ds.CSVDataset( | data = ds.CSVDataset( | ||||
| TEST_FILE, | TEST_FILE, | ||||
| column_defaults=["", "", "", ""], | |||||
| column_defaults=["", 0, "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | column_names=['col1', 'col2', 'col3', 'col4'], | ||||
| shuffle=False) | shuffle=False) | ||||
| with pytest.raises(Exception) as err: | with pytest.raises(Exception) as err: | ||||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | ||||
| pass | pass | ||||
| assert "failed to parse file" in str(err.value) | |||||
| assert "type does not match" in str(err.value) | |||||
| def test_csv_dataset_type_error(): | |||||
| def test_csv_dataset_exception(): | |||||
| TEST_FILE = '../data/dataset/testCSV/exception.csv' | TEST_FILE = '../data/dataset/testCSV/exception.csv' | ||||
| data = ds.CSVDataset( | data = ds.CSVDataset( | ||||
| TEST_FILE, | TEST_FILE, | ||||
| column_defaults=["", 0, "", ""], | |||||
| column_defaults=["", "", "", ""], | |||||
| column_names=['col1', 'col2', 'col3', 'col4'], | column_names=['col1', 'col2', 'col3', 'col4'], | ||||
| shuffle=False) | shuffle=False) | ||||
| with pytest.raises(Exception) as err: | with pytest.raises(Exception) as err: | ||||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | ||||
| pass | pass | ||||
| assert "type does not match" in str(err.value) | |||||
| assert "failed to parse file" in str(err.value) | |||||
| def test_csv_dataset_duplicate_columns(): | |||||
| data = ds.CSVDataset( | |||||
| DATA_FILE, | |||||
| column_defaults=["1", "2", "3", "4"], | |||||
| column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'], | |||||
| shuffle=False) | |||||
| with pytest.raises(RuntimeError) as info: | |||||
| _ = data.create_dict_iterator(num_epochs=1, output_numpy=True) | |||||
| assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value) | |||||
| assert "column_names" in str(info.value) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -234,5 +246,6 @@ if __name__ == "__main__": | |||||
| test_csv_dataset_header() | test_csv_dataset_header() | ||||
| test_csv_dataset_number() | test_csv_dataset_number() | ||||
| test_csv_dataset_size() | test_csv_dataset_size() | ||||
| test_csv_dataset_exception() | |||||
| test_csv_dataset_type_error() | test_csv_dataset_type_error() | ||||
| test_csv_dataset_exception() | |||||
| test_csv_dataset_duplicate_columns() | |||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================== | # ============================================================================== | ||||
| import mindspore.dataset as ds | import mindspore.dataset as ds | ||||
| import mindspore.dataset.vision.c_transforms as vision | |||||
| IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train" | IMAGENET_RAWDATA_DIR = "../data/dataset/testImageNetData2/train" | ||||
| IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", | ||||
| @@ -21,9 +22,18 @@ IMAGENET_TFFILE_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-000 | |||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", | ||||
| "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] | ||||
| MNIST_DATA_DIR = "../data/dataset/testMnistData" | MNIST_DATA_DIR = "../data/dataset/testMnistData" | ||||
| MIND_CV_FILE_NAME = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord" | |||||
| SCHEMA_FILE = "../data/dataset/test_tf_file_3_images/datasetSchema.json" | |||||
| MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" | MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" | ||||
| CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data" | CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data" | ||||
| CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data" | CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data" | ||||
| VOC_DATA_DIR = "../data/dataset/testVOC2012" | |||||
| COCO_DATA_DIR = "../data/dataset/testCOCO/train/" | |||||
| ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" | |||||
| CELEBA_DATA_DIR = "../data/dataset/testCelebAData/" | |||||
| CLUE_FILE = '../data/dataset/testCLUE/afqmc/train.json' | |||||
| CSV_FILE = '../data/dataset/testCSV/1.csv' | |||||
| TEXT_DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" | |||||
| def test_imagenet_rawdata_dataset_size(): | def test_imagenet_rawdata_dataset_size(): | ||||
| @@ -50,8 +60,15 @@ def test_imagenet_tf_file_dataset_size(): | |||||
| ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0) | ds_shard_2_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=2, shard_id=0) | ||||
| assert ds_shard_2_0.get_dataset_size() == 6 | assert ds_shard_2_0.get_dataset_size() == 6 | ||||
| # FIXME: dataset_size == 6 looks wrong, but it seems intended to match the current code behavior. | |||||
| # The correct answer should be 12 / 3 = 4; the underlying code issue should be addressed. | |||||
| ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0) | ds_shard_3_0 = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, num_shards=3, shard_id=0) | ||||
| assert ds_shard_3_0.get_dataset_size() == 4 | |||||
| assert ds_shard_3_0.get_dataset_size() == 6 | |||||
| count = 0 | |||||
| for _ in ds_shard_3_0.create_dict_iterator(): | |||||
| count += 1 | |||||
| assert ds_shard_3_0.get_dataset_size() == count | |||||
| def test_mnist_dataset_size(): | def test_mnist_dataset_size(): | ||||
| @@ -76,6 +93,14 @@ def test_mnist_dataset_size(): | |||||
| assert ds_shard_3_0.get_dataset_size() == 3334 | assert ds_shard_3_0.get_dataset_size() == 3334 | ||||
| def test_mind_dataset_size(): | |||||
| dataset = ds.MindDataset(MIND_CV_FILE_NAME + "0") | |||||
| assert dataset.get_dataset_size() == 20 | |||||
| dataset_shard_2_0 = ds.MindDataset(MIND_CV_FILE_NAME + "0", num_shards=2, shard_id=0) | |||||
| assert dataset_shard_2_0.get_dataset_size() == 10 | |||||
| def test_manifest_dataset_size(): | def test_manifest_dataset_size(): | ||||
| ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE) | ds_total = ds.ManifestDataset(MANIFEST_DATA_FILE) | ||||
| assert ds_total.get_dataset_size() == 4 | assert ds_total.get_dataset_size() == 4 | ||||
| @@ -95,10 +120,11 @@ def test_cifar10_dataset_size(): | |||||
| assert ds_total.get_dataset_size() == 10000 | assert ds_total.get_dataset_size() == 10000 | ||||
| # test get_dataset_size with usage flag | # test get_dataset_size with usage flag | ||||
| train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size() | |||||
| assert train_size == 0 | |||||
| train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size() | train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size() | ||||
| assert train_size == 10000 | assert train_size == 10000 | ||||
| test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size() | |||||
| assert test_size == 0 | |||||
| all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size() | all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size() | ||||
| assert all_size == 10000 | assert all_size == 10000 | ||||
| @@ -120,8 +146,6 @@ def test_cifar100_dataset_size(): | |||||
| assert ds_total.get_dataset_size() == 10000 | assert ds_total.get_dataset_size() == 10000 | ||||
| # test get_dataset_size with usage flag | # test get_dataset_size with usage flag | ||||
| train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size() | |||||
| assert train_size == 0 | |||||
| test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size() | test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size() | ||||
| assert test_size == 10000 | assert test_size == 10000 | ||||
| all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size() | all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size() | ||||
| @@ -137,10 +161,97 @@ def test_cifar100_dataset_size(): | |||||
| assert ds_shard_3_0.get_dataset_size() == 3334 | assert ds_shard_3_0.get_dataset_size() == 3334 | ||||
| def test_voc_dataset_size(): | |||||
| dataset = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True) | |||||
| assert dataset.get_dataset_size() == 10 | |||||
| dataset_shard_2_0 = ds.VOCDataset(VOC_DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True, | |||||
| num_shards=2, shard_id=0) | |||||
| assert dataset_shard_2_0.get_dataset_size() == 5 | |||||
| def test_coco_dataset_size(): | |||||
| dataset = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", | |||||
| decode=True, shuffle=False) | |||||
| assert dataset.get_dataset_size() == 6 | |||||
| dataset_shard_2_0 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True, | |||||
| shuffle=False, num_shards=2, shard_id=0) | |||||
| assert dataset_shard_2_0.get_dataset_size() == 3 | |||||
| def test_celeba_dataset_size(): | |||||
| dataset = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) | |||||
| assert dataset.get_dataset_size() == 4 | |||||
| dataset_shard_2_0 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=2, shard_id=0) | |||||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||||
| def test_clue_dataset_size(): | |||||
| dataset = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False) | |||||
| assert dataset.get_dataset_size() == 3 | |||||
| dataset_shard_2_0 = ds.CLUEDataset(CLUE_FILE, task='AFQMC', usage='train', shuffle=False, num_shards=2, shard_id=0) | |||||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||||
| def test_csv_dataset_size(): | |||||
| dataset = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'], | |||||
| shuffle=False) | |||||
| assert dataset.get_dataset_size() == 3 | |||||
| dataset_shard_2_0 = ds.CSVDataset(CSV_FILE, column_defaults=["0", 0, 0.0, "0"], column_names=['1', '2', '3', '4'], | |||||
| shuffle=False, num_shards=2, shard_id=0) | |||||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||||
| def test_text_file_dataset_size(): | |||||
| dataset = ds.TextFileDataset(TEXT_DATA_FILE) | |||||
| assert dataset.get_dataset_size() == 3 | |||||
| dataset_shard_2_0 = ds.TextFileDataset(TEXT_DATA_FILE, num_shards=2, shard_id=0) | |||||
| assert dataset_shard_2_0.get_dataset_size() == 2 | |||||
| def test_padded_dataset_size(): | |||||
| dataset = ds.PaddedDataset([{"data": [1, 2, 3]}, {"data": [1, 0, 1]}]) | |||||
| assert dataset.get_dataset_size() == 2 | |||||
| def test_pipeline_get_dataset_size(): | |||||
| dataset = ds.TFRecordDataset(IMAGENET_TFFILE_DIR, SCHEMA_FILE, columns_list=["image"], shuffle=False) | |||||
| assert dataset.get_dataset_size() == 12 | |||||
| dataset = dataset.shuffle(buffer_size=3) | |||||
| assert dataset.get_dataset_size() == 12 | |||||
| decode_op = vision.Decode() | |||||
| resize_op = vision.RandomResize(10) | |||||
| dataset = dataset.map([decode_op, resize_op], input_columns=["image"]) | |||||
| assert dataset.get_dataset_size() == 12 | |||||
| dataset = dataset.batch(batch_size=3) | |||||
| assert dataset.get_dataset_size() == 4 | |||||
| dataset = dataset.repeat(count=2) | |||||
| assert dataset.get_dataset_size() == 8 | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_imagenet_rawdata_dataset_size() | test_imagenet_rawdata_dataset_size() | ||||
| test_imagenet_tf_file_dataset_size() | test_imagenet_tf_file_dataset_size() | ||||
| test_mnist_dataset_size() | test_mnist_dataset_size() | ||||
| test_mind_dataset_size() | |||||
| test_manifest_dataset_size() | test_manifest_dataset_size() | ||||
| test_cifar10_dataset_size() | test_cifar10_dataset_size() | ||||
| test_cifar100_dataset_size() | test_cifar100_dataset_size() | ||||
| test_voc_dataset_size() | |||||
| test_coco_dataset_size() | |||||
| test_celeba_dataset_size() | |||||
| test_clue_dataset_size() | |||||
| test_csv_dataset_size() | |||||
| test_text_file_dataset_size() | |||||
| test_padded_dataset_size() | |||||
| test_pipeline_get_dataset_size() | |||||
| @@ -521,7 +521,7 @@ def test_chained_sampler_04(): | |||||
| # Verify dataset size | # Verify dataset size | ||||
| data1_size = data1.get_dataset_size() | data1_size = data1.get_dataset_size() | ||||
| logger.info("dataset size is: {}".format(data1_size)) | logger.info("dataset size is: {}".format(data1_size)) | ||||
| assert data1_size == 24 | |||||
| assert data1_size == 6 | |||||
| # Verify number of iterations | # Verify number of iterations | ||||
| num_iter = 0 | num_iter = 0 | ||||
| @@ -182,6 +182,15 @@ def test_voc_exception(): | |||||
| pass | pass | ||||
| def test_voc_num_classes(): | |||||
| data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) | |||||
| assert data1.num_classes() is None | |||||
| class_index = {'car': 0, 'cat': 1, 'train': 5} | |||||
| data2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True) | |||||
| assert data2.num_classes() is None | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_voc_segmentation() | test_voc_segmentation() | ||||
| test_voc_detection() | test_voc_detection() | ||||
| @@ -191,3 +200,4 @@ if __name__ == '__main__': | |||||
| test_case_1() | test_case_1() | ||||
| test_case_2() | test_case_2() | ||||
| test_voc_exception() | test_voc_exception() | ||||
| test_voc_num_classes() | |||||
@@ -107,7 +107,7 @@ def test_decode_op():
     # Expect a AttributeError since iter1 has been stopped.
     with pytest.raises(AttributeError) as info:
         iter1.__next__()
-    assert "object has no attribute 'depipeline'" in str(info.value)
+    assert "object has no attribute '_runtime_context'" in str(info.value)

     with pytest.raises(RuntimeError) as info:
         iter2.__next__()
@@ -205,7 +205,7 @@ def test_generator_dict_3():
     # Expect a AttributeError since iter1 has been stopped.
     with pytest.raises(AttributeError) as info:
         iter1.__next__()
-    assert "object has no attribute 'depipeline'" in str(info.value)
+    assert "object has no attribute '_runtime_context'" in str(info.value)


 def test_generator_dict_4():
@@ -396,7 +396,7 @@ def test_generator_tuple_3():
     # Expect a AttributeError since iter1 has been stopped.
     with pytest.raises(AttributeError) as info:
         iter1.__next__()
-    assert "object has no attribute 'depipeline'" in str(info.value)
+    assert "object has no attribute '_runtime_context'" in str(info.value)


 def test_generator_tuple_4():
@@ -546,7 +546,7 @@ def test_generator_tuple_repeat_repeat_2():
     # Expect a AttributeError since iter1 has been stopped.
     with pytest.raises(AttributeError) as info:
         iter1.__next__()
-    assert "object has no attribute 'depipeline'" in str(info.value)
+    assert "object has no attribute '_runtime_context'" in str(info.value)


 def test_generator_tuple_repeat_repeat_3():
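
The four hunks above only update the expected AttributeError message: once a stopped iterator has released its internals, the first attribute that __next__ touches is now _runtime_context rather than depipeline. A generic, self-contained illustration of the assertion pattern follows; DummyIterator is made up for this sketch, and only the '_runtime_context' name comes from the updated tests.

import pytest

class DummyIterator:
    def __init__(self):
        self._runtime_context = object()

    def stop(self):
        # Releasing the runtime context is what makes a later __next__ fail.
        del self._runtime_context

    def __next__(self):
        return self._runtime_context

it = DummyIterator()
it.stop()
with pytest.raises(AttributeError) as info:
    it.__next__()
assert "object has no attribute '_runtime_context'" in str(info.value)
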
@@ -74,9 +74,11 @@ def test_case2():
 def test_case3():
-    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
-    data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(5)
-    data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2)
+    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
+        ["col_sint64"], ["a1"])
+    data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(5).rename(
+        ["col_sint64"], ["a2"])
+    data3 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).rename(["col_sint64"], ["a3"])
     data4 = ds.zip((data1, data2, data3))
@@ -84,8 +86,9 @@ def test_case3():
 def test_case4():
-    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
-    data2 = ds.TFRecordDataset(FILES)
+    data1 = ds.TFRecordDataset(FILES, SCHEMA_FILE, columns_list=["col_sint64"]).batch(2).repeat(10).rename(
+        ["col_sint64"], ["a1"])
+    data2 = ds.TFRecordDataset(FILES, columns_list=["col_sint64"]).rename(["col_sint64"], ["a2"])
     assert data2.get_dataset_size() == 12
     data2 = data2.batch(2)
     assert data2.get_dataset_size() == 6
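
Each branch in test_case3 and test_case4 now selects col_sint64 and renames it (a1/a2/a3) before zipping, presumably so the zipped pipeline does not carry duplicate column names, while the size assertions simply track 12 rows turning into 6 batches of 2. A plain-Python analogy, not MindSpore code:

# Zipping keeps every input's columns side by side, so duplicate names would
# collide; renaming each branch first avoids that.
row1 = {"a1": [1, 2]}   # a data1 row after rename(["col_sint64"], ["a1"])
row2 = {"a2": [3, 4]}   # a data2 row after rename(["col_sint64"], ["a2"])
row3 = {"a3": [5, 6]}   # a data3 row after rename(["col_sint64"], ["a3"])

zipped = {**row1, **row2, **row3}
assert set(zipped) == {"a1", "a2", "a3"}   # no clash once columns are renamed

# Size bookkeeping checked by test_case4: 12 rows, then batch(2) -> 6 batches.
assert 12 // 2 == 6
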