From: @ziruiwu Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -499,9 +499,10 @@ FilterDataset::FilterDataset(std::shared_ptr<Dataset> input, std::function<Tenso | |||||
| MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, | MapDataset::MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, | ||||
| std::vector<std::string> input_columns, std::vector<std::string> output_columns, | std::vector<std::string> input_columns, std::vector<std::string> output_columns, | ||||
| const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache) { | |||||
| auto ds = | |||||
| std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns, cache); | |||||
| const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache, | |||||
| std::vector<std::shared_ptr<DSCallback>> callbacks) { | |||||
| auto ds = std::make_shared<MapNode>(input->IRNode(), operations, input_columns, output_columns, project_columns, | |||||
| cache, callbacks); | |||||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | ||||
| } | } | ||||
| @@ -44,7 +44,12 @@ bool Iterator::GetNextRow(TensorVec *row) { | |||||
| } | } | ||||
| // Shut down the data pipeline. | // Shut down the data pipeline. | ||||
| void Iterator::Stop() { runtime_context_->Terminate(); } | |||||
| void Iterator::Stop() { | |||||
| Status rc = runtime_context_->Terminate(); | |||||
| if (rc.IsError()) { | |||||
| MS_LOG(ERROR) << rc.ToString(); | |||||
| } | |||||
| } | |||||
| // Function to build and launch the execution tree. | // Function to build and launch the execution tree. | ||||
| Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) { | ||||
| @@ -385,6 +385,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||||
| /// \return Status | /// \return Status | ||||
| virtual Status WaitForWorkers() { return Status::OK(); } | virtual Status WaitForWorkers() { return Status::OK(); } | ||||
| /// \brief Add callback to DatasetOp, only MapOp supports Callback at the moment | |||||
| void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) { callback_manager_.AddCallbacks(callbacks); } | |||||
| protected: | protected: | ||||
| /// \brief Removes a parent operator from this operator | /// \brief Removes a parent operator from this operator | ||||
| /// \notes External callers do not have access to this function | /// \notes External callers do not have access to this function | ||||
| @@ -13,9 +13,9 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <cstring> | #include <cstring> | ||||
| #include <iostream> | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| @@ -26,8 +26,6 @@ | |||||
| #include "minddata/dataset/engine/data_buffer.h" | #include "minddata/dataset/engine/data_buffer.h" | ||||
| #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" | #include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h" | ||||
| #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" | #include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h" | ||||
| #include "minddata/dataset/engine/datasetops/map_op/map_op.h" | |||||
| #include "minddata/dataset/engine/execution_tree.h" | |||||
| #include "minddata/dataset/engine/opt/pass.h" | #include "minddata/dataset/engine/opt/pass.h" | ||||
| #include "minddata/dataset/kernels/tensor_op.h" | #include "minddata/dataset/kernels/tensor_op.h" | ||||
| #include "minddata/dataset/util/task_manager.h" | #include "minddata/dataset/util/task_manager.h" | ||||
| @@ -60,7 +58,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) { | |||||
| RETURN_IF_NOT_OK(sanityCheck()); | RETURN_IF_NOT_OK(sanityCheck()); | ||||
| *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_), | *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_), | ||||
| std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_); | std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_); | ||||
| (*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_)); | |||||
| (*ptr)->AddCallbacks(std::move(builder_callbacks_)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -31,13 +31,15 @@ namespace dataset { | |||||
| // constructor #1, called by Pybind | // constructor #1, called by Pybind | ||||
| BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, | BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, | ||||
| const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names, | const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names, | ||||
| py::function batch_size_func, py::function batch_map_func, | |||||
| const std::vector<std::string> &col_order, py::function batch_size_func, | |||||
| py::function batch_map_func, | |||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map) | std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map) | ||||
| : batch_size_(batch_size), | : batch_size_(batch_size), | ||||
| drop_remainder_(drop_remainder), | drop_remainder_(drop_remainder), | ||||
| pad_(pad), | pad_(pad), | ||||
| in_col_names_(in_col_names), | in_col_names_(in_col_names), | ||||
| out_col_names_(out_col_names), | out_col_names_(out_col_names), | ||||
| col_order_(col_order), | |||||
| batch_size_func_(batch_size_func), | batch_size_func_(batch_size_func), | ||||
| batch_map_func_(batch_map_func), | batch_map_func_(batch_map_func), | ||||
| pad_map_(pad_map) { | pad_map_(pad_map) { | ||||
| @@ -83,8 +85,8 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() { | |||||
| in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, | in_col_names_, out_col_names_, batch_size_func_, batch_map_func_, | ||||
| pad_map_)); | pad_map_)); | ||||
| // need to insert a project when per_batch_func changes the number of columns | // need to insert a project when per_batch_func changes the number of columns | ||||
| if (!out_col_names_.empty()) { | |||||
| auto project_op = std::make_shared<ProjectOp>(out_col_names_); | |||||
| if (!col_order_.empty()) { | |||||
| auto project_op = std::make_shared<ProjectOp>(col_order_); | |||||
| node_ops.push_back(project_op); | node_ops.push_back(project_op); | ||||
| } | } | ||||
| #else | #else | ||||
| @@ -34,7 +34,7 @@ class BatchNode : public DatasetNode { | |||||
| /// \brief Constructor #1, for Python API to create a BatchNode | /// \brief Constructor #1, for Python API to create a BatchNode | ||||
| BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, | BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad, | ||||
| const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names, | const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names, | ||||
| py::function batch_size_func, py::function batch_map_func, | |||||
| const std::vector<std::string> &col_order, py::function batch_size_func, py::function batch_map_func, | |||||
| std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map); | std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map); | ||||
| #endif | #endif | ||||
| @@ -58,6 +58,7 @@ class BatchNode : public DatasetNode { | |||||
| bool pad_; | bool pad_; | ||||
| std::vector<std::string> in_col_names_; | std::vector<std::string> in_col_names_; | ||||
| std::vector<std::string> out_col_names_; | std::vector<std::string> out_col_names_; | ||||
| std::vector<std::string> col_order_; | |||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| py::function batch_size_func_; | py::function batch_size_func_; | ||||
| py::function batch_map_func_; | py::function batch_map_func_; | ||||
| @@ -29,12 +29,14 @@ namespace dataset { | |||||
| MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | ||||
| std::vector<std::string> input_columns, std::vector<std::string> output_columns, | std::vector<std::string> input_columns, std::vector<std::string> output_columns, | ||||
| const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache) | |||||
| const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache, | |||||
| std::vector<std::shared_ptr<DSCallback>> callbacks) | |||||
| : operations_(operations), | : operations_(operations), | ||||
| input_columns_(input_columns), | input_columns_(input_columns), | ||||
| output_columns_(output_columns), | output_columns_(output_columns), | ||||
| project_columns_(project_columns), | project_columns_(project_columns), | ||||
| DatasetNode(std::move(cache)) { | |||||
| DatasetNode(std::move(cache)), | |||||
| callbacks_(callbacks) { | |||||
| this->children.push_back(child); | this->children.push_back(child); | ||||
| } | } | ||||
| @@ -53,6 +55,11 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() { | |||||
| // This parameter will be removed with next rebase | // This parameter will be removed with next rebase | ||||
| std::vector<std::string> col_orders; | std::vector<std::string> col_orders; | ||||
| auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_); | auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_); | ||||
| if (!callbacks_.empty()) { | |||||
| map_op->AddCallbacks(callbacks_); | |||||
| } | |||||
| if (!project_columns_.empty()) { | if (!project_columns_.empty()) { | ||||
| auto project_op = std::make_shared<ProjectOp>(project_columns_); | auto project_op = std::make_shared<ProjectOp>(project_columns_); | ||||
| node_ops.push_back(project_op); | node_ops.push_back(project_op); | ||||
| @@ -31,7 +31,8 @@ class MapNode : public DatasetNode { | |||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations, | ||||
| std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {}, | ||||
| const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr); | |||||
| const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr, | |||||
| std::vector<std::shared_ptr<DSCallback>> callbacks = {}); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~MapNode() = default; | ~MapNode() = default; | ||||
| @@ -49,6 +50,7 @@ class MapNode : public DatasetNode { | |||||
| std::vector<std::string> input_columns_; | std::vector<std::string> input_columns_; | ||||
| std::vector<std::string> output_columns_; | std::vector<std::string> output_columns_; | ||||
| std::vector<std::string> project_columns_; | std::vector<std::string> project_columns_; | ||||
| std::vector<std::shared_ptr<DSCallback>> callbacks_; | |||||
| }; | }; | ||||
| } // namespace dataset | } // namespace dataset | ||||
| @@ -149,7 +149,7 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) { | |||||
| // We finish the walk of this RepeatOp's descendent nodes. | // We finish the walk of this RepeatOp's descendent nodes. | ||||
| // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. | // The total repeats of nodes above this Repeat(n) have nothing to do with this RepeatOp's parameter n. | ||||
| // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, | // But num_repeats_ has been multiplied by n during this Repeat(n)'s PreRunOnNode, | ||||
| // so we devide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. | |||||
| // so we divide num_repeats_ by n to be able to correctly set total repeats for nodes above this RepeatOp. | |||||
| num_repeats_ /= node->num_repeats(); | num_repeats_ /= node->num_repeats(); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -120,7 +120,7 @@ Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr< | |||||
| RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); | ||||
| } | } | ||||
| // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf | |||||
| // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf | |||||
| // and save it for use by cache op in ascendant tree. | // and save it for use by cache op in ascendant tree. | ||||
| if (is_caching_) { | if (is_caching_) { | ||||
| RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); | ||||
| @@ -261,7 +261,8 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { | |||||
| // Then, execute the transform for each pair | // Then, execute the transform for each pair | ||||
| for (auto cache_pair : cache_pass.cache_pairs()) { | for (auto cache_pair : cache_pass.cache_pairs()) { | ||||
| MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; | MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; | ||||
| ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); | |||||
| RETURN_IF_NOT_OK( | |||||
| ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client())); | |||||
| } | } | ||||
| MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; | MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -60,95 +60,95 @@ class CacheTransformPass : public TreePass { | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<AlbumOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override; | ||||
| #ifdef ENABLE_PYTHON | #ifdef ENABLE_PYTHON | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override; | ||||
| #endif | #endif | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override; | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override; | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Perform leaf node cache tranform identifications | |||||
| /// \brief Perform leaf node cache transform identifications | |||||
| /// \param[in] node The node being visited | /// \param[in] node The node being visited | ||||
| /// \param[inout] modified Indicator if the node was changed at all | /// \param[inout] modified Indicator if the node was changed at all | ||||
| /// \return Status The error code return | /// \return Status The error code return | ||||
| @@ -276,9 +276,10 @@ class Dataset : public std::enable_shared_from_this<Dataset> { | |||||
| std::vector<std::string> input_columns = {}, | std::vector<std::string> input_columns = {}, | ||||
| std::vector<std::string> output_columns = {}, | std::vector<std::string> output_columns = {}, | ||||
| const std::vector<std::string> &project_columns = {}, | const std::vector<std::string> &project_columns = {}, | ||||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||||
| const std::shared_ptr<DatasetCache> &cache = nullptr, | |||||
| std::vector<std::shared_ptr<DSCallback>> callbacks = {}) { | |||||
| return std::make_shared<MapDataset>(shared_from_this(), operations, input_columns, output_columns, project_columns, | return std::make_shared<MapDataset>(shared_from_this(), operations, input_columns, output_columns, project_columns, | ||||
| cache); | |||||
| cache, callbacks); | |||||
| } | } | ||||
| /// \brief Function to create a Project Dataset | /// \brief Function to create a Project Dataset | ||||
| @@ -443,7 +444,8 @@ class MapDataset : public Dataset { | |||||
| public: | public: | ||||
| MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, | MapDataset(std::shared_ptr<Dataset> input, std::vector<std::shared_ptr<TensorOperation>> operations, | ||||
| std::vector<std::string> input_columns, std::vector<std::string> output_columns, | std::vector<std::string> input_columns, std::vector<std::string> output_columns, | ||||
| const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache); | |||||
| const std::vector<std::string> &project_columns, const std::shared_ptr<DatasetCache> &cache, | |||||
| std::vector<std::shared_ptr<DSCallback>> callbacks); | |||||
| }; | }; | ||||
| class ProjectDataset : public Dataset { | class ProjectDataset : public Dataset { | ||||
| @@ -21,6 +21,8 @@ | |||||
| #include "minddata/dataset/callback/ds_callback.h" | #include "minddata/dataset/callback/ds_callback.h" | ||||
| #include "minddata/dataset/core/client.h" | #include "minddata/dataset/core/client.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | ||||
| #include "minddata/dataset/include/datasets.h" | |||||
| #include "minddata/dataset/include/transforms.h" | |||||
| #include "minddata/dataset/kernels/data/no_op.h" | #include "minddata/dataset/kernels/data/no_op.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| @@ -149,7 +151,7 @@ TEST_F(MindDataTestCallback, TestBasicCallback) { | |||||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | ||||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | TensorShape shape({}); // empty shape is a 1-value scalar Tensor | ||||
| ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | ||||
| schema->AddColumn(col); | |||||
| ASSERT_OK(schema->AddColumn(col)); | |||||
| std::shared_ptr<RandomDataOp> leaf; | std::shared_ptr<RandomDataOp> leaf; | ||||
| rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf); | rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf); | ||||
| EXPECT_TRUE(rc.IsOk()); | EXPECT_TRUE(rc.IsOk()); | ||||
| @@ -196,7 +198,7 @@ TEST_F(MindDataTestCallback, TestMutiEpochCallback) { | |||||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | ||||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | TensorShape shape({}); // empty shape is a 1-value scalar Tensor | ||||
| ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | ||||
| schema->AddColumn(col); | |||||
| ASSERT_OK(schema->AddColumn(col)); | |||||
| std::shared_ptr<RandomDataOp> leaf; | std::shared_ptr<RandomDataOp> leaf; | ||||
| rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); | rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); | ||||
| EXPECT_TRUE(rc.IsOk()); | EXPECT_TRUE(rc.IsOk()); | ||||
| @@ -253,7 +255,7 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) { | |||||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | ||||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | TensorShape shape({}); // empty shape is a 1-value scalar Tensor | ||||
| ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | ||||
| schema->AddColumn(col); | |||||
| ASSERT_OK(schema->AddColumn(col)); | |||||
| std::shared_ptr<RandomDataOp> leaf; | std::shared_ptr<RandomDataOp> leaf; | ||||
| rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); | rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); | ||||
| EXPECT_TRUE(rc.IsOk()); | EXPECT_TRUE(rc.IsOk()); | ||||
| @@ -296,3 +298,34 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) { | |||||
| EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs); | EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs); | ||||
| EXPECT_EQ(tst_cb->all_step_nums(len), all_steps); | EXPECT_EQ(tst_cb->all_step_nums(len), all_steps); | ||||
| } | } | ||||
| TEST_F(MindDataTestCallback, TestCAPICallback) { | |||||
| MS_LOG(INFO) << "Doing: MindDataTestCallback-TestCAPICallback"; | |||||
| // config callback | |||||
| std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64); | |||||
| std::shared_ptr<DSCallback> cb1 = tst_cb; | |||||
| // config leaf_op, use random_data to avoid I/O | |||||
| std::shared_ptr<SchemaObj> schema = std::make_shared<SchemaObj>(); | |||||
| ASSERT_TRUE(schema->add_column("label", "uint32", {})); | |||||
| std::shared_ptr<Dataset> ds = RandomData(44, schema); | |||||
| ds = ds->Map({transforms::TypeCast("uint64")}, {"label"}, {}, {}, nullptr, {cb1}); | |||||
| ds = ds->Repeat(2); | |||||
| TreeAdapter tree_adapter; | |||||
| // using tree_adapter to set num_epoch = 1 | |||||
| ASSERT_OK(tree_adapter.Compile(ds->IRNode(), 1)); | |||||
| TensorRow row; | |||||
| ASSERT_OK(tree_adapter.GetNext(&row)); | |||||
| while (!row.empty()) { | |||||
| ASSERT_OK(tree_adapter.GetNext(&row)); | |||||
| } | |||||
| std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"}; | |||||
| std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88}; | |||||
| std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1}; | |||||
| // doing resize to make sure no unexpected epoch_end or extra epoch_begin is called | |||||
| size_t len = 7; | |||||
| EXPECT_EQ(tst_cb->all_names(len), callback_names); | |||||
| EXPECT_EQ(tst_cb->all_step_nums(len), all_steps); | |||||
| EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs); | |||||
| } | |||||