implement pause in MapOp; add DSCallback support

- Initial drop of Python DSCallback
- Pybind DSCallback; added callback to MapOp
- de_pipeline DSCallback; add test case; fix segfault
- update callback test case (now works); use Builder class for MapOp callback
- better test case; minor fixes; add comments and minor cleanups
- get rid of nullptr in MapOp, use another flag instead
- fix a bug: ParseMapOp only takes 1 callback
- Added WaitedDSCallback; refactor callback param; fix incorrect number in test case
- added testing; fix cpp test case; revert lenet changes
- cleanup test_callbacks.py; fix CI stage I and II; update epoch counter
- add validation and more testing; test_callbacks.py uses a random data op for tests
- adjust when to call EpochBegin/End; add repeat with callback
- addressing reviewers' comments; docstring and CI fixes
- rebase with upstream/master; fix cpp test case; address review comments; add test case

tags/v0.7.0-beta
@@ -58,6 +58,7 @@ add_subdirectory(kernels)
add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(text)
add_subdirectory(callback)
######################################################################
add_dependencies(utils core)
add_dependencies(kernels-image core)
@@ -74,6 +75,7 @@ add_dependencies(engine-cache-server core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
add_dependencies(callback core)
add_dependencies(text core)
add_dependencies(text-kernels core)
add_dependencies(cpp-API core)
@@ -87,6 +89,7 @@ endif ()
################### Create _c_dataengine Library ######################
set(submodules
        $<TARGET_OBJECTS:core>
        $<TARGET_OBJECTS:callback>
        $<TARGET_OBJECTS:utils>
        $<TARGET_OBJECTS:kernels>
        $<TARGET_OBJECTS:kernels-image>
@@ -135,14 +138,14 @@ endif()
target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar)
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
    if (ENABLE_PYTHON)
        target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
    else()
        target_link_libraries(_c_dataengine PRIVATE mindspore::protobuf ${SECUREC_LIBRARY})
    endif()
else()
    set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n)
    if (ENABLE_PYTHON)
        target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
    else()
        target_link_libraries(_c_dataengine PRIVATE -ldl mindspore::protobuf ${SECUREC_LIBRARY})
    endif()
@@ -7,6 +7,7 @@ if (ENABLE_PYTHON)
        python/bindings.cc
        python/bindings/dataset/engine/cache/bindings.cc
        python/bindings/dataset/core/bindings.cc
        python/bindings/dataset/callback/bindings.cc
        python/bindings/dataset/kernels/data/bindings.cc
        python/bindings/dataset/kernels/bindings.cc
        python/bindings/dataset/engine/datasetops/bindings.cc
@@ -0,0 +1,45 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "pybind11/pybind11.h"
#include "pybind11/stl_bind.h"

#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/callback/ds_callback.h"

namespace mindspore {
namespace dataset {

PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) {
                  (void)py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>(*m, "PyDSCallback")
                    .def(py::init<int32_t>())
                    .def("set_begin", &PyDSCallback::setBegin)
                    .def("set_end", &PyDSCallback::setEnd)
                    .def("set_epoch_begin", &PyDSCallback::setEpochBegin)
                    .def("set_epoch_end", &PyDSCallback::setEpochEnd)
                    .def("set_step_begin", &PyDSCallback::setStepBegin)
                    .def("set_step_end", &PyDSCallback::setStepEnd);
                }));

PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) {
                  (void)py::class_<CallbackParam, std::shared_ptr<CallbackParam>>(*m, "CallbackParam")
                    .def(py::init<int64_t, int64_t, int64_t>())
                    .def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_)
                    .def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_)
                    .def_readonly("cur_step_num", &CallbackParam::cur_step_num_);
                }));

}  // namespace dataset
}  // namespace mindspore
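Editor's note: these bindings are what DSCallback.create_runtime_obj (shown later in this change) drives under the hood. As a minimal sketch, assuming the _c_dataengine extension has been built, the exposed setters can also be exercised directly from Python:

# Illustrative only: normally DSCallback.create_runtime_obj() does this wiring.
from mindspore._c_dataengine import PyDSCallback

def log_step(cb_param):
    # cb_param mirrors the read-only fields bound from the C++ CallbackParam
    print(cb_param.cur_epoch_num, cb_param.cur_step_num_in_epoch, cb_param.cur_step_num)

cb = PyDSCallback(1)         # step_size = 1
cb.set_step_begin(log_step)  # register only the hooks that are actually needed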
@@ -20,6 +20,7 @@
#include <map>

#include "utils/ms_utils.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/dataset_iterator.h"
@@ -738,8 +739,13 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
      (void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
    } else if (key == "cache") {
      cache_client = value.cast<std::shared_ptr<CacheClient>>();
    } else if (key == "callbacks") {
      std::vector<std::shared_ptr<DSCallback>> callbacks;
      std::transform(value.begin(), value.end(), std::back_inserter(callbacks),
                     [](py::handle cb) { return cb.cast<std::shared_ptr<PyDSCallback>>(); });
      (void)map_builder.AddCallbacks(callbacks);
    } else {
      RETURN_STATUS_UNEXPECTED("Error in parsing MapOp: Unhandled key: " + key);
    }
  }
}
@@ -0,0 +1,14 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

if (ENABLE_PYTHON)
    add_library(callback OBJECT
        callback_manager.cc
        py_ds_callback.cc
    )
else ()
    add_library(callback OBJECT
        callback_manager.cc
    )
endif ()
@@ -0,0 +1,160 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"

namespace mindspore {
namespace dataset {

void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
  callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
}

Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
  RETURN_UNEXPECTED_IF_NULL(op);
  op_ = op;
  // turn the flag on if callback is set
  enabled_ = !callbacks_.empty();

  // error check for each of the callbacks
  for (auto &cb : callbacks_) {
    CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
  }
  return Status::OK();
}

Status CallbackManager::Begin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no begin is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no epoch_begin is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
      callback_inds.push_back(ind);
  }
  // return Status::OK() if no step_begin is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::End(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no end is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
  }
  // return Status::OK() if no epoch_end is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
  }
  return Status::OK();
}

Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);
  std::vector<size_t> callback_inds;
  // go through all callback functions to see if each function is needed
  for (size_t ind = 0; ind < callbacks_.size(); ind++) {
    if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
      callback_inds.push_back(ind);
  }
  // return Status::OK() if no step_end is needed
  RETURN_OK_IF_TRUE(callback_inds.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Now do the actual callback
  for (size_t ind : callback_inds) {
    RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
  }
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
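Editor's note: all six hooks above share the same three-step shape: collect the callbacks that want this event, pause the workers only if at least one does, then dispatch. A compact Python model of that shared pattern (names are illustrative, not part of the C++ API):

# Minimal model of the per-hook dispatch in CallbackManager (illustrative only).
def dispatch(callbacks, is_needed, run, pause_from_master, cb_param):
    needed = [cb for cb in callbacks if is_needed(cb)]
    if not needed:
        return                 # skip the (comparatively expensive) pause entirely
    pause_from_master()        # quiesce all worker threads before running user code
    for cb in needed:
        run(cb, cb_param)      # e.g. cb.DSEpochBegin(cb_param)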
@@ -0,0 +1,79 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H

#include <memory>
#include <vector>

#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

// forward declare to avoid cyclic include of dataset_op.h
class DatasetOp;

/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
class CallbackManager {
 public:
  /// CallbackManager default constructor. Init needs to be called before using the created instance.
  CallbackManager() : enabled_(false) {}

  /// \brief Append a list of callbacks to the manager
  /// \param[in] callbacks list of callbacks to perform
  void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);

  /// \brief A DatasetOp needs to call Init if it wishes to use callbacks; Init sets enabled_ when callbacks exist
  /// \param[in] op this pointer is used by the CallbackManager to pause worker threads
  /// \return Status
  Status Init(std::shared_ptr<DatasetOp> op);

  /// \brief callback function called at the start of the first row
  /// \return Status
  Status Begin(const CallbackParam &);

  /// \brief callback function called at the start of each epoch
  /// \return Status
  Status EpochBegin(const CallbackParam &);

  /// \brief callback function called at the start of each row
  /// \return Status
  Status StepBegin(const CallbackParam &);

  /// \brief callback function called after the last row is processed
  /// \return Status
  Status End(const CallbackParam &);

  /// \brief callback function called at the end of each epoch
  /// \return Status
  Status EpochEnd(const CallbackParam &);

  /// \brief callback function called at the end of each row
  /// \return Status
  Status StepEnd(const CallbackParam &);

 private:
  bool enabled_;                                        // flag to enable callback; if false, all functions return immediately
  std::shared_ptr<DatasetOp> op_;                       // back pointer to DatasetOp; each DatasetOp has only 1 CallbackManager
  std::vector<std::shared_ptr<DSCallback>> callbacks_;  // list of callbacks the DatasetOp needs to call
};

}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H

#include <nlohmann/json.hpp>

namespace mindspore {
namespace dataset {

/// CallbackParam is the object a DatasetOp uses to pass run-time information to a user-defined function.
/// This is a prototype for now; more fields will be added.
class CallbackParam {
 public:
  CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num)
      : cur_epoch_num_(epoch_num), cur_epoch_step_num_(cur_epoch_step), cur_step_num_(total_step_num) {}

  // these are constant public fields for easy access and consistency with the Python cb_param
  // the names and order are consistent with batchInfo
  const int64_t cur_epoch_num_;       // current epoch
  const int64_t cur_epoch_step_num_;  // step number of the current epoch
  const int64_t cur_step_num_;        // step number since the first row
};

}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
@@ -0,0 +1,100 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H

#include <memory>
#include <utility>
#include <vector>

#include "minddata/dataset/callback/callback_param.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

class DSCallback {
 public:
  /// \brief constructor of DSCallback; this is the base class for all front-end specific callbacks
  /// \param step_size number of steps between successive DSNStepBegin()/DSNStepEnd() calls
  explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}

  /// \brief actual callback function for begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSBegin(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for step_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSEnd(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for epoch_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0;

  /// \brief actual callback function for step_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0;

  /// \brief predicate function, whether begin callback is needed
  /// \return bool
  virtual bool IsBeginNeeded() = 0;

  /// \brief predicate function, whether epoch_begin callback is needed
  /// \return bool
  virtual bool IsEpochBeginNeeded() = 0;

  /// \brief predicate function, whether step_begin callback is needed
  /// \return bool
  virtual bool IsNStepBeginNeeded() = 0;

  /// \brief predicate function, whether end callback is needed
  /// \return bool
  virtual bool IsEndNeeded() = 0;

  /// \brief predicate function, whether epoch_end callback is needed
  /// \return bool
  virtual bool IsEpochEndNeeded() = 0;

  /// \brief predicate function, whether step_end callback is needed
  /// \return bool
  virtual bool IsNStepEndNeeded() = 0;

  /// \brief getter
  /// \return step_size
  int32_t step_size() const { return step_size_; }

 protected:
  int32_t step_size_;  // step begin/end will be called every step_size_ steps
};

}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
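Editor's note: per CallbackManager::StepBegin/StepEnd above, a callback with step_size s fires on epoch steps 1, s+1, 2s+1, ..., because the gate is (cur_epoch_step_num_ - 1) % step_size == 0. A tiny Python check of that gating (the helper name is hypothetical):

def nstep_fires(cur_epoch_step, step_size):
    # mirrors (cur_epoch_step_num_ - 1) % step_size == 0 in CallbackManager
    return (cur_epoch_step - 1) % step_size == 0

# with step_size=3 the N-step hooks fire on epoch steps 1, 4, 7, ...
assert [s for s in range(1, 10) if nstep_fires(s, 3)] == [1, 4, 7]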
@@ -0,0 +1,86 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

Status PyDSCallback::DSBegin(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(begin_func_, cb_param);
}
Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(epoch_begin_func_, cb_param);
}
Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(step_begin_func_, cb_param);
}
Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return PyDSCallback::ExecutePyfunc(end_func_, cb_param); }
Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(epoch_end_func_, cb_param);
}
Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) {
  return PyDSCallback::ExecutePyfunc(step_end_func_, cb_param);
}

bool PyDSCallback::IsBeginNeeded() { return begin_needed_; }
bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; }
bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; }
bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; }
bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; }
bool PyDSCallback::IsEndNeeded() { return end_needed_; }

Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) {
  {
    // Acquire Python GIL
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    f(cb_param);
  }
  return Status::OK();
}
void PyDSCallback::setBegin(py::function f) {
  begin_func_ = f;
  begin_needed_ = true;
}
void PyDSCallback::setEnd(py::function f) {
  end_func_ = f;
  end_needed_ = true;
}
void PyDSCallback::setEpochBegin(py::function f) {
  epoch_begin_func_ = f;
  epoch_begin_needed_ = true;
}
void PyDSCallback::setEpochEnd(py::function f) {
  epoch_end_func_ = f;
  epoch_end_needed_ = true;
}
void PyDSCallback::setStepBegin(py::function f) {
  step_begin_func_ = f;
  step_begin_needed_ = true;
}
void PyDSCallback::setStepEnd(py::function f) {
  step_end_func_ = f;
  step_end_needed_ = true;
}

}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,130 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H

#include <memory>
#include <utility>
#include <vector>

#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "pybind11/pybind11.h"

namespace mindspore {
namespace dataset {

namespace py = pybind11;

class PyDSCallback : public DSCallback {
 public:
  /// \brief constructor for PyDSCallback. This callback is for the Python front end
  explicit PyDSCallback(int32_t step_size = 1)
      : DSCallback(step_size),
        begin_needed_(false),
        epoch_begin_needed_(false),
        step_begin_needed_(false),
        end_needed_(false),
        epoch_end_needed_(false),
        step_end_needed_(false) {}

  void setBegin(py::function f);
  void setEnd(py::function f);
  void setEpochBegin(py::function f);
  void setEpochEnd(py::function f);
  void setStepBegin(py::function f);
  void setStepEnd(py::function f);

  /// \brief actual callback function for begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSBegin(const CallbackParam &cb_param) override;

  /// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSEpochBegin(const CallbackParam &cb_param) override;

  /// \brief actual callback function for step_begin, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSNStepBegin(const CallbackParam &cb_param) override;

  /// \brief actual callback function for end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSEnd(const CallbackParam &cb_param) override;

  /// \brief actual callback function for epoch_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSEpochEnd(const CallbackParam &cb_param) override;

  /// \brief actual callback function for step_end, needs to be overridden in the derived class
  /// \param cb_param callback parameter passed in from DatasetOp when calling the callback
  /// \return Status
  Status DSNStepEnd(const CallbackParam &cb_param) override;

  /// \brief predicate function, whether begin callback is needed
  /// \return bool
  bool IsBeginNeeded() override;

  /// \brief predicate function, whether epoch_begin callback is needed
  /// \return bool
  bool IsEpochBeginNeeded() override;

  /// \brief predicate function, whether step_begin callback is needed
  /// \return bool
  bool IsNStepBeginNeeded() override;

  /// \brief predicate function, whether end callback is needed
  /// \return bool
  bool IsEndNeeded() override;

  /// \brief predicate function, whether epoch_end callback is needed
  /// \return bool
  bool IsEpochEndNeeded() override;

  /// \brief predicate function, whether step_end callback is needed
  /// \return bool
  bool IsNStepEndNeeded() override;

  /// \brief helper function to acquire the GIL and then execute a pyfunc
  /// \param f the python function
  /// \param cb_param
  /// \return Status
  static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param);

 private:
  py::function begin_func_;
  py::function epoch_begin_func_;
  py::function step_begin_func_;
  py::function end_func_;
  py::function epoch_end_func_;
  py::function step_end_func_;

  bool begin_needed_;
  bool epoch_begin_needed_;
  bool step_begin_needed_;
  bool end_needed_;
  bool epoch_end_needed_;
  bool step_end_needed_;
};

}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
@@ -21,6 +21,8 @@
#include <string>
#include <unordered_map>
#include <vector>

#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/util/status.h"
@@ -358,6 +360,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
  /// \return boolean returns true if it's last iteration
  bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }

  /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp.
  /// When invoked, it blocks until all the workers have finished their remaining work and gone to sleep.
  /// Since all ParallelOps use a QueueList to sync with the master, workers automatically wait on the
  /// QueueList when they are done; hence, for now, an Unpause() function is not needed. Only ParallelOp
  /// needs to override this function.
  /// \return Status
  virtual Status PauseFromMaster() { return Status::OK(); }

 protected:
  /// \brief Removes a parent operator from this operator
  /// \notes External callers do not have access to this function
@@ -394,6 +404,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
  std::unique_ptr<DbConnector> out_connector_;                   // Output Connector
  std::unordered_map<std::string, int32_t> column_name_id_map_;  // Mapping between col index and col name
  std::mutex column_name_map_mutex_;                             // For protecting shared access to the column map
  CallbackManager callback_manager_;                             // Manages callbacks associated with a DatasetOp

 private:
  /// Sets the operator id.
@@ -15,25 +15,23 @@
 */
#include <algorithm>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <vector>
#include "minddata/dataset/callback/callback_param.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
@@ -58,6 +56,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
  RETURN_IF_NOT_OK(sanityCheck());
  *ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
                                 std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_);
  (*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_));
  return Status::OK();
}
@@ -164,7 +163,10 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job)
Status MapOp::operator()() {
  // Create and register the local queues.
  local_queues_.Init(num_workers_, oc_queue_size_);
  // init callback
  RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
  Status rc = local_queues_.Register(tree_->AllTasks());
  RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks()));
  if (rc.IsError()) {
    TaskManager::FindMe()->Post();
    return rc;
@@ -175,28 +177,51 @@ Status MapOp::operator()() {
  // Synchronize with TaskManager
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(rc);

  // num_buf counts the buffers received so far (including eoe); ep_step and total_step track steps
  // within the current epoch and since the first row.
  int64_t num_buf = 0, ep_step = 0, total_step = 0;

  RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));

  std::unique_ptr<DataBuffer> buff;

  // Drain the output connector of the previous op, generate jobs for the worker threads, and
  // distribute them via the local queues. Stop when EOF is received.
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
  while (!buff->eof()) {
    if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
      RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
    }
    while (!buff->eoe()) {
      ep_step++;
      total_step++;
      RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
      // Create an empty map worker job to be populated by a databuffer and map jobs
      std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
      // Populate map worker job for a worker to execute
      RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
      // Push map worker job to the corresponding worker's queue
      RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
      RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
    }
    // reset epoch_step when a new epoch is about to start
    if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
      RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
      ep_step = 0;
    }
    // send the eoe buffer to a worker
    std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
    RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
    UpdateRepeatAndEpochCounter();
    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
  }
  // the last eoe increments the eoe count by 1, but this shouldn't be reflected in the End() callback
  // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
  // handle eof logic: forward the eof buffer to a worker
  std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
  RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));

  return Status::OK();
}
@@ -213,25 +238,19 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
  // Fetch next data buffer and map job list
  RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));

  // Now that init work is done, drop into the main fetching loop.
  // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself
  // rather than use the base-class defaults.
  while (true) {
    // Handle the pause logic. A pause is triggered when a buffer with id -1, no special flag and no rows is received.
    if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) {
      // when a worker receives the signal from the master thread, it increments an atomic int
      // the last worker to increment the counter wakes up the master thread
      if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
      // this blocks the worker until the master thread gives it new work
      RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
      continue;
    } else if (in_buffer->eoe()) {
      // Calling base class EoeReceived to forward eoe buffer.
      RETURN_IF_NOT_OK(EoeReceived(worker_id));
      // Fetch next data buffer and map job list
@@ -243,6 +262,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
      break;
    }
    CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer.");
    std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>());
    // Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
    RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get(), job_list));
@@ -281,9 +301,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
  std::vector<TensorRow> result_table;
  // Executing the list of jobs
  for (size_t i = 0; i < job_list.size(); i++) {
    // Execute MapJob.
    RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
    // Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
    if (i + 1 < job_list.size()) {
      job_input_table = std::move(result_table);
    }
@@ -428,5 +448,20 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(shared_from_base<MapOp>(), modified);
}

Status MapOp::PauseFromMaster() {
  // reset the number of paused workers to 0
  num_workers_paused_ = 0;
  for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
    // a special buffer (id=-1, empty, none flag) is used to signal that the worker needs to pause
    RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(
      std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone))));
  }
  // wait until all workers are done processing their work in local_queue_
  RETURN_IF_NOT_OK(master_pause_wp_.Wait());
  // clear the WaitPost for the next Wait()
  master_pause_wp_.Clear();
  return Status::OK();
}

}  // namespace dataset
}  // namespace mindspore
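Editor's note: the pause handshake between MapOp::PauseFromMaster and WorkerEntry above works like this: the master sends one sentinel buffer per worker, each worker bumps an atomic counter when it sees the sentinel, the last worker wakes the master, and every worker then blocks on its queue until new work arrives. A minimal Python model of the same protocol, purely illustrative (queue.Queue and threading.Event stand in for QueueList and WaitPost; all names are assumptions, not the real API):

import queue
import threading

NUM_WORKERS = 4
PAUSE = object()                     # stands in for the id=-1 empty DataBuffer
local_queues = [queue.Queue() for _ in range(NUM_WORKERS)]
pause_lock = threading.Lock()
num_paused = 0
master_pause_wp = threading.Event()  # plays the role of the WaitPost

def worker(q):
    global num_paused
    while True:
        job = q.get()
        if job is PAUSE:
            with pause_lock:
                num_paused += 1
                if num_paused == NUM_WORKERS:  # last worker wakes the master
                    master_pause_wp.set()
            continue                           # block on q.get() again
        # normal per-buffer work would happen here

def pause_from_master():
    global num_paused
    num_paused = 0
    for q in local_queues:           # one pause token per worker
        q.put(PAUSE)
    master_pause_wp.wait()           # sleep until the last worker signals
    master_pause_wp.clear()          # ready for the next pause

for q in local_queues:
    threading.Thread(target=worker, args=(q,), daemon=True).start()
pause_from_master()                  # returns once all workers are parked
print("all workers paused")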
@@ -16,15 +16,19 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_

#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {
@@ -108,6 +112,13 @@ class MapOp : public ParallelOp {
      return *this;
    }

    // Setter method.
    // @return Builder setter method returns reference to the builder.
    Builder &AddCallbacks(const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
      builder_callbacks_.insert(builder_callbacks_.end(), callbacks.begin(), callbacks.end());
      return *this;
    }

    // The builder "build" method creates the final object.
    // @param ptr The shared_ptr to the new MapOp object
    // @return Status
@@ -116,6 +127,7 @@ class MapOp : public ParallelOp {
   private:
    std::vector<std::string> build_in_col_names_;
    std::vector<std::string> build_out_col_names_;
    std::vector<std::shared_ptr<DSCallback>> builder_callbacks_;
    std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
    int32_t build_num_workers_;
    int32_t build_op_connector_size_;
@@ -186,6 +198,7 @@ class MapOp : public ParallelOp {
  // A unit of job for a map worker thread.
  // MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob.
  struct MapWorkerJob {
    explicit MapWorkerJob(std::unique_ptr<DataBuffer> db) : databuffer(std::move(db)) {}
    std::vector<std::shared_ptr<MapJob>> jobs;
    std::unique_ptr<DataBuffer> databuffer;
  };
@@ -215,6 +228,12 @@ class MapOp : public ParallelOp {
  // Indices of the columns to process.
  std::vector<size_t> to_process_indices_;

  // wait post used to perform the pausing logic in MapOp
  WaitPost master_pause_wp_;

  // count of the workers that have signaled the master
  std::atomic_int num_workers_paused_;

  // Private function for worker/thread to loop continuously. It comprises the main
  // logic of MapOp: getting the data from previous Op, validating user specified column names,
  // applying a list of TensorOps to each of the data, process the results and then
@@ -247,6 +266,13 @@ class MapOp : public ParallelOp {
  // Private function for initializing private variables such as in_columns_, out_columns_.
  // @return - Status
  Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map);

  // This function should only be called from the master thread. It suspends the operation of all workers
  // and has them wait on the QueueList. The master thread sends a token to each worker, then waits on a
  // WaitPost. Upon receiving the suspension token, each worker increments an atomic count; the last worker
  // to do the increment wakes up the master.
  // @return - Status
  Status PauseFromMaster() override;
};

}  // namespace dataset
}  // namespace mindspore
@@ -34,7 +34,7 @@ class Semaphore {
  /// \brief Decrement the internal counter. Will be blocked if the value is 0.
  /// \return Error code. Can get interrupt.
  Status P();
  /// \brief Increment the internal counter. Wake up one of the waiters if any.
  void V();
  /// \brief Peek the internal value
  /// \return The internal value
@@ -59,6 +59,13 @@ namespace dataset {
  } \
} while (false)

#define RETURN_OK_IF_TRUE(_condition) \
  do {                                \
    if (_condition) {                 \
      return Status::OK();            \
    }                                 \
  } while (false)

enum class StatusCode : char {
  kOK = 0,
  kOutOfMemory = 1,
| @@ -0,0 +1,18 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """init file for python callback""" | |||
| from .ds_callback import DSCallback, WaitedDSCallback | |||
| __all__ = ["DSCallback", "WaitedDSCallback"] | |||
@@ -0,0 +1,232 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Python callback class
"""
import threading
from mindspore._c_dataengine import PyDSCallback
from mindspore.train.callback import Callback
from .validators import check_callback


class DSCallback:
    """
    Abstract base class used to build a dataset callback class.

    Args:
        step_size (int, optional): The number of steps between adjacent step_begin/step_end calls (Default=1).

    Examples:
        >>> class PrintInfo(DSCallback):
        >>>     def ds_epoch_end(self, ds_run_context):
        >>>         print(ds_run_context.cur_epoch_num)
        >>>         print(ds_run_context.cur_step_num)
        >>>
        >>> data = data.map(operations=op, callbacks=PrintInfo())
    """

    @check_callback
    def __init__(self, step_size=1):
        self.step_size = step_size

    def ds_begin(self, ds_run_context):
        """
        Called before the data pipeline is started.

        Args:
            ds_run_context (RunContext): Includes some information of the pipeline.
        """

    def ds_epoch_begin(self, ds_run_context):
        """
        Called before a new epoch is started.

        Args:
            ds_run_context (RunContext): Includes some information of the pipeline.
        """

    def ds_epoch_end(self, ds_run_context):
        """
        Called after an epoch is finished.

        Args:
            ds_run_context (RunContext): Includes some information of the pipeline.
        """

    def ds_step_begin(self, ds_run_context):
        """
        Called before n steps are started.

        Args:
            ds_run_context (RunContext): Includes some information of the pipeline.
        """

    def ds_step_end(self, ds_run_context):
        """
        Called after n steps are finished.

        Args:
            ds_run_context (RunContext): Includes some information of the pipeline.
        """

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user.

        Returns:
            _c_dataengine.PyDSCallback
        """
        c_cb = PyDSCallback(self.step_size)
        at_least_one = False

        if self.__class__.ds_begin != DSCallback.ds_begin:
            c_cb.set_begin(self.ds_begin)
            at_least_one = True

        if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
            c_cb.set_epoch_begin(self.ds_epoch_begin)
            at_least_one = True
        if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
            c_cb.set_epoch_end(self.ds_epoch_end)
            at_least_one = True

        if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
            c_cb.set_step_begin(self.ds_step_begin)
            at_least_one = True
        if self.__class__.ds_step_end != DSCallback.ds_step_end:
            c_cb.set_step_end(self.ds_step_end)
            at_least_one = True

        if not at_least_one:
            raise AttributeError("Provided Callback class did not override any of the 5 callback methods.")

        return c_cb
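Editor's note: create_runtime_obj registers a hook with the C++ layer only when the subclass actually overrides it, by comparing the functions stored on the classes. A small standalone illustration of that idiom (class names are hypothetical):

class Base:
    def hook(self):
        """Default no-op, analogous to DSCallback.ds_begin."""

class Overrides(Base):
    def hook(self):
        print("overridden")

class Inherits(Base):
    pass

# the same test create_runtime_obj uses: compare the functions on the classes
assert Overrides.hook != Base.hook   # would be registered with PyDSCallback
assert Inherits.hook == Base.hook    # inherited default, so it is skipped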
| class WaitedDSCallback(Callback, DSCallback): | |||
| """ | |||
| Abstract base class used to build a dataset callback class that are synchronized with the training callback. | |||
| This class can be used to execute a user defined logic right after the previous step or epoch. | |||
| For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters. | |||
| Examples: | |||
| >>> my_cb = MyWaitedCallback(32) | |||
| >>> data = data.map(operations=AugOp(), callbacks=my_cb) | |||
| >>> data = data.batch(32) | |||
| >>> # define the model | |||
| >>> model.train(epochs, data, callbacks=[my_cb]) | |||
| Args: | |||
| step_size: the number of rows in each step. | |||
| Usually the step size will be equal to the batch size (Default=1) | |||
| """ | |||
| def __init__(self, step_size=1): | |||
| super().__init__() | |||
| self.step_size = step_size | |||
| self.step_event = threading.Event() | |||
| self.step_run_context = None | |||
| self.epoch_event = threading.Event() | |||
| self.epoch_run_context = None | |||
| def sync_epoch_begin(self, train_run_context, ds_run_context): | |||
| """ | |||
| Called before a new dataset epoch is started and after the previous training epoch is ended. | |||
| Args: | |||
| train_run_context: Include some information of the model with feedback from the previous epoch. | |||
| ds_run_context: Include some information of the dataset pipeline. | |||
| """ | |||
| def sync_step_begin(self, train_run_context, ds_run_context): | |||
| """ | |||
| Called before a new dataset step is started and after the previous training step is ended. | |||
| Args: | |||
| train_run_context: Include some information of the model with feedback from the previous step. | |||
| ds_run_context: Include some information of the dataset pipeline. | |||
| """ | |||
| def epoch_end(self, run_context): | |||
| """ | |||
| Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin. | |||
| Args: | |||
| run_context: Include some information of the model. | |||
| """ | |||
| self.epoch_run_context = run_context | |||
| self.epoch_event.set() | |||
| self.epoch_event.clear() | |||
| def ds_epoch_begin(self, ds_run_context): | |||
| """ | |||
| Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback. | |||
| Args: | |||
| ds_run_context: Include some information of the pipeline. | |||
| """ | |||
| if ds_run_context.cur_epoch_num > 1: | |||
| if self.epoch_run_context is None: | |||
| self.epoch_event.wait() | |||
| self.sync_epoch_begin(self.epoch_run_context, ds_run_context) | |||
| self.epoch_run_context = None | |||
| def step_end(self, run_context): | |||
| """ | |||
| Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin. | |||
| Args: | |||
| run_context: Include some information of the model. | |||
| """ | |||
| self.step_run_context = run_context | |||
| self.step_event.set() | |||
| self.step_event.clear() | |||
| def ds_step_begin(self, ds_run_context): | |||
| """ | |||
| Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback. | |||
| Args: | |||
| ds_run_context: Include some information of the pipeline. | |||
| """ | |||
| if ds_run_context.cur_step_num > self.step_size: | |||
| if self.step_run_context is None: | |||
| self.step_event.wait() | |||
| self.sync_step_begin(self.step_run_context, ds_run_context) | |||
| self.step_run_context = None | |||
| def create_runtime_obj(self): | |||
| """ | |||
| Creates a runtime (C++) object from the callback methods defined by the user. This method is internal. | |||
| Returns: _c_dataengine.PyDSCallback | |||
| """ | |||
| c_cb = PyDSCallback(self.step_size) | |||
| at_least_one = False | |||
| if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin: | |||
| c_cb.set_step_begin(self.ds_step_begin) | |||
| at_least_one = True | |||
| if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin: | |||
| c_cb.set_epoch_begin(self.ds_epoch_begin) | |||
| at_least_one = True | |||
| if not at_least_one: | |||
| raise AttributeError("Provided Callback class did not override any of the 2 callback methods.") | |||
| return c_cb | |||
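| A minimal sketch of the MyWaitedCallback used in the docstring example above could look as follows (the class name and method bodies are illustrative assumptions, not part of the shipped API): | |||
| class MyWaitedCallback(WaitedDSCallback): | |||
|     def __init__(self, batch_size): | |||
|         # step_size equals the batch size so sync_step_begin fires once per batch | |||
|         super().__init__(step_size=batch_size) | |||
|     def sync_epoch_begin(self, train_run_context, ds_run_context): | |||
|         # train_run_context carries feedback (e.g. the loss) from the epoch that just | |||
|         # finished; use it here to update augmentation parameters for the next epoch | |||
|         pass | |||
|     def sync_step_begin(self, train_run_context, ds_run_context): | |||
|         pass | |||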
| @@ -0,0 +1,34 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Built-in validators. | |||
| """ | |||
| from functools import wraps | |||
| from ..core.validator_helpers import parse_user_args, check_pos_int32 | |||
| def check_callback(method): | |||
| """check the input arguments of DSCallback.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [step_size], _ = parse_user_args(method, *args, **kwargs) | |||
| check_pos_int32(step_size, "step_size") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
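| Presumably this decorator guards the DSCallback constructor; a sketch of the intended use, assuming DSCallback.__init__ takes only step_size: | |||
| class DSCallback: | |||
|     @check_callback | |||
|     def __init__(self, step_size=1): | |||
|         # parse_user_args/check_pos_int32 above reject a non-positive or | |||
|         # out-of-int32-range step_size before this body runs | |||
|         self.step_size = step_size | |||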
| @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||
| check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ | |||
| check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ | |||
| check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\ | |||
| check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \ | |||
| check_paddeddataset | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE | |||
| @@ -395,7 +395,7 @@ class Dataset: | |||
| @check_map | |||
| def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, | |||
| num_parallel_workers=None, python_multiprocessing=False, cache=None): | |||
| num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None): | |||
| """ | |||
| Apply each operation in operations to this dataset. | |||
| @@ -438,6 +438,8 @@ class Dataset: | |||
| option could be beneficial if the python operation is computationally heavy (default=False). | |||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| callbacks (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (default=None). | |||
| Returns: | |||
| MapDataset, dataset after mapping operation. | |||
| @@ -552,7 +554,7 @@ class Dataset: | |||
| >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order) | |||
| """ | |||
| return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers, | |||
| python_multiprocessing, cache) | |||
| python_multiprocessing, cache, callbacks) | |||
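| As a usage sketch of the new callbacks parameter (PrintCB is a hypothetical subclass, not shipped code): | |||
| from mindspore.dataset.callback import DSCallback | |||
| import mindspore.dataset as ds | |||
| class PrintCB(DSCallback): | |||
|     # only overridden methods are forwarded to the C++ runtime via create_runtime_obj() | |||
|     def ds_step_begin(self, ds_run_context): | |||
|         print("step", ds_run_context.cur_step_num) | |||
| data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=PrintCB()) | |||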
| @check_filter | |||
| def filter(self, predicate, input_columns=None, num_parallel_workers=1): | |||
| @@ -1548,6 +1550,7 @@ class DatasetOp(Dataset): | |||
| return self.children[0].get_class_indexing() | |||
| raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self))) | |||
| class BucketBatchByLengthDataset(DatasetOp): | |||
| """ | |||
| The result of applying BucketBatchByLength operator to the input dataset. | |||
| @@ -1964,14 +1967,14 @@ class MapDataset(DatasetOp): | |||
| option could be beneficial if the python operation is computationally heavy (default=False). | |||
| cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used). | |||
| The cache feature is under development and is not recommended. | |||
| callbacks (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (default=None) | |||
| Raises: | |||
| ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified. | |||
| """ | |||
| def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, | |||
| num_parallel_workers=None, python_multiprocessing=False, cache=None): | |||
| num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None): | |||
| super().__init__(num_parallel_workers) | |||
| self.children.append(input_dataset) | |||
| if input_columns is not None and not isinstance(input_columns, list): | |||
| @@ -1996,6 +1999,11 @@ class MapDataset(DatasetOp): | |||
| self.python_multiprocessing = python_multiprocessing | |||
| self.process_pool = None | |||
| if callbacks is not None and not isinstance(callbacks, list): | |||
| callbacks = [callbacks] | |||
| self.callbacks = callbacks | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| args["input_columns"] = self.input_columns | |||
| @@ -2003,6 +2011,9 @@ class MapDataset(DatasetOp): | |||
| args["output_columns"] = self.output_columns | |||
| args["columns_order"] = self.columns_order | |||
| args["cache"] = self.cache.cache_client if self.cache is not None else None | |||
| if self.callbacks is not None: | |||
| args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks] | |||
| return args | |||
| def get_dataset_size(self): | |||
| @@ -2034,6 +2045,7 @@ class MapDataset(DatasetOp): | |||
| new_op.cache = copy.deepcopy(self.cache, memodict) | |||
| new_op.operations = self.operations | |||
| new_op.dataset_size = self.dataset_size | |||
| new_op.callbacks = self.callbacks | |||
| return new_op | |||
| # Iterator bootstrap will be called on iterator construction. | |||
| @@ -2393,7 +2405,6 @@ class ConcatDataset(DatasetOp): | |||
| self._children_start_end_index_[index][0] = cumulative_samples_nums | |||
| self._children_start_end_index_[index][1] = tem_value % sampler.num_shards | |||
| tem_sampler = copy.deepcopy(sampler) | |||
| tem_sampler.set_offset(cumulative_samples_nums) | |||
| child.sampler = tem_sampler | |||
| @@ -2556,7 +2567,7 @@ class RangeDataset(MappableDataset): | |||
| def get_dataset_size(self): | |||
| if self.dataset_size is None: | |||
| self.dataset_size = math.ceil((self.stop - self.start)/self.step) | |||
| self.dataset_size = math.ceil((self.stop - self.start) / self.step) | |||
| return self.dataset_size | |||
| @@ -3423,7 +3434,7 @@ class GeneratorDataset(MappableDataset): | |||
| if not self.num_shards: | |||
| self.dataset_size = len(self.source) | |||
| else: | |||
| self.dataset_size = math.ceil(len(self.source)/self.num_shards) | |||
| self.dataset_size = math.ceil(len(self.source) / self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: | |||
| @@ -5428,6 +5439,7 @@ class NumpySlicesDataset(GeneratorDataset): | |||
| num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, | |||
| num_shards=num_shards, shard_id=shard_id) | |||
| class _PaddedDataset: | |||
| """ | |||
| Mainly for combining fake samples provided by the user into a dataset. | |||
| @@ -5435,6 +5447,7 @@ class _PaddedDataset: | |||
| Args: | |||
| padded_samples (list(dict)): the data provided by the user to be added to the initial Dataset | |||
| """ | |||
| def __init__(self, padded_samples): | |||
| self.column_names = list(padded_samples[0].keys()) | |||
| self.padded_samples = padded_samples | |||
| @@ -5445,6 +5458,7 @@ class _PaddedDataset: | |||
| def __len__(self): | |||
| return len(self.padded_samples) | |||
| class PaddedDataset(GeneratorDataset): | |||
| """ | |||
| Create a dataset with fake data provided by the user. Mainly used to add to the original dataset | |||
| @@ -5463,6 +5477,7 @@ class PaddedDataset(GeneratorDataset): | |||
| >>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}] | |||
| >>> ds1 = ds.PaddedDataset(data1) | |||
| """ | |||
| @check_paddeddataset | |||
| def __init__(self, padded_samples): | |||
| dataset = _PaddedDataset(padded_samples) | |||
| @@ -23,6 +23,7 @@ from functools import wraps | |||
| import numpy as np | |||
| from mindspore._c_expression import typing | |||
| from mindspore.dataset.callback import DSCallback | |||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | |||
| INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | |||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ | |||
| @@ -31,6 +32,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis | |||
| from . import datasets | |||
| from . import samplers | |||
| from . import cache_client | |||
| from .. import callback | |||
| def check_imagefolderdatasetv2(method): | |||
| @@ -247,6 +249,7 @@ def check_celebadataset(method): | |||
| return new_method | |||
| def check_save(method): | |||
| """A wrapper that wrap a parameter checker to the save op.""" | |||
| @@ -257,7 +260,7 @@ def check_save(method): | |||
| nreq_param_int = ['num_files'] | |||
| nreq_param_str = ['file_name', 'file_type'] | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): | |||
| if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000): | |||
| raise ValueError("num_files should between {} and {}.".format(1, 1000)) | |||
| validate_dataset_param_value(nreq_param_str, param_dict, str) | |||
| if param_dict.get('file_type') != 'mindrecord': | |||
| @@ -265,6 +268,8 @@ def check_save(method): | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_minddataset(method): | |||
| """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" | |||
| @@ -362,6 +367,7 @@ def check_generatordataset(method): | |||
| return new_method | |||
| def check_random_dataset(method): | |||
| """A wrapper that wraps a parameter checker to the original Dataset(RandomDataset).""" | |||
| @@ -545,7 +551,8 @@ def check_map(method): | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \ | |||
| [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache, | |||
| callbacks], _ = \ | |||
| parse_user_args(method, *args, **kwargs) | |||
| nreq_param_columns = ['input_columns', 'output_columns'] | |||
| @@ -558,9 +565,17 @@ def check_map(method): | |||
| if cache is not None: | |||
| type_check(cache, (cache_client.DatasetCache,), "cache") | |||
| if callbacks is not None: | |||
| if isinstance(callbacks, (list, tuple)): | |||
| type_check_list(callbacks, (callback.DSCallback,), "callbacks") | |||
| else: | |||
| type_check(callbacks, (callback.DSCallback,), "callbacks") | |||
| for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]): | |||
| if param is not None: | |||
| check_columns(param, param_name) | |||
| return method(self, *args, **kwargs) | |||
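| The net effect of the callbacks check above is roughly the following standalone logic (a sketch; the real type_check/type_check_list helpers raise a TypeError with a formatted message): | |||
| def _check_callbacks(callbacks):  # hypothetical helper, for illustration only | |||
|     if callbacks is None: | |||
|         return | |||
|     if isinstance(callbacks, (list, tuple)): | |||
|         for i, cb in enumerate(callbacks): | |||
|             if not isinstance(cb, callback.DSCallback): | |||
|                 raise TypeError("Argument callbacks[{0}] with value {1} is not of type DSCallback.".format(i, cb)) | |||
|     elif not isinstance(callbacks, callback.DSCallback): | |||
|         raise TypeError("Argument callbacks with value {0} is not of type DSCallback.".format(callbacks)) | |||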
| @@ -15,6 +15,7 @@ SET(DE_UT_SRCS | |||
| bounding_box_augment_op_test.cc | |||
| arena_test.cc | |||
| btree_test.cc | |||
| callback_test.cc | |||
| center_crop_op_test.cc | |||
| channel_swap_test.cc | |||
| circular_pool_test.cc | |||
| @@ -0,0 +1,301 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include <list> | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/callback/ds_callback.h" | |||
| #include "minddata/dataset/core/client.h" | |||
| #include "minddata/dataset/engine/datasetops/source/random_data_op.h" | |||
| #include "minddata/dataset/kernels/data/no_op.h" | |||
| #include "utils/log_adapter.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::LogStream; | |||
| using mindspore::MsLogLevel::INFO; | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace test { | |||
| std::shared_ptr<ExecutionTree> BuildTree(std::vector<std::shared_ptr<DatasetOp>> ops) { | |||
| std::shared_ptr<ExecutionTree> tree = std::make_shared<ExecutionTree>(); | |||
| Status rc; | |||
| for (size_t i = 0; i < ops.size(); i++) { | |||
| rc = tree->AssociateNode(ops[i]); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| if (i > 0) { | |||
| rc = ops[i]->AddChild(ops[i - 1]); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| if (i == ops.size() - 1) { | |||
| rc = tree->AssignRoot(ops[i]); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| } | |||
| return tree; | |||
| } | |||
| class TestCallback : public DSCallback { | |||
| public: | |||
| TestCallback(int32_t step_size) | |||
| : DSCallback(step_size), | |||
| begin_(true), | |||
| epoch_begin_(true), | |||
| step_begin_(true), | |||
| end_(true), | |||
| epoch_end_(true), | |||
| step_end_(true) { | |||
| all_names_.reserve(32); | |||
| all_step_nums_.reserve(32); | |||
| all_ep_nums_.reserve(32); | |||
| } | |||
| Status DSBegin(const CallbackParam &cb_param) override { | |||
| all_names_.push_back("BGN"); | |||
| all_step_nums_.push_back(cb_param.cur_step_num_); | |||
| all_ep_nums_.push_back(cb_param.cur_epoch_num_); | |||
| return Status::OK(); | |||
| } | |||
| Status DSEpochBegin(const CallbackParam &cb_param) override { | |||
| all_names_.push_back("EPBGN"); | |||
| all_step_nums_.push_back(cb_param.cur_step_num_); | |||
| all_ep_nums_.push_back(cb_param.cur_epoch_num_); | |||
| return Status::OK(); | |||
| } | |||
| Status DSNStepBegin(const CallbackParam &cb_param) override { | |||
| all_names_.push_back("SPBGN"); | |||
| all_step_nums_.push_back(cb_param.cur_step_num_); | |||
| all_ep_nums_.push_back(cb_param.cur_epoch_num_); | |||
| return Status::OK(); | |||
| } | |||
| Status DSEnd(const CallbackParam &cb_param) override { | |||
| all_names_.push_back("END"); | |||
| all_step_nums_.push_back(cb_param.cur_step_num_); | |||
| all_ep_nums_.push_back(cb_param.cur_epoch_num_); | |||
| return Status::OK(); | |||
| } | |||
| Status DSEpochEnd(const CallbackParam &cb_param) override { | |||
| all_names_.push_back("EPEND"); | |||
| all_step_nums_.push_back(cb_param.cur_step_num_); | |||
| all_ep_nums_.push_back(cb_param.cur_epoch_num_); | |||
| return Status::OK(); | |||
| } | |||
| Status DSNStepEnd(const CallbackParam &cb_param) override { | |||
| all_names_.push_back("SPEND"); | |||
| all_step_nums_.push_back(cb_param.cur_step_num_); | |||
| all_ep_nums_.push_back(cb_param.cur_epoch_num_); | |||
| return Status::OK(); | |||
| } | |||
| bool IsBeginNeeded() override { return begin_; } | |||
| bool IsEpochBeginNeeded() override { return epoch_begin_; } | |||
| bool IsNStepBeginNeeded() override { return step_begin_; } | |||
| bool IsEndNeeded() override { return end_; } | |||
| bool IsEpochEndNeeded() override { return epoch_end_; } | |||
| bool IsNStepEndNeeded() override { return step_end_; } | |||
| std::vector<std::string> all_names(size_t len) { | |||
| return std::vector<std::string>(all_names_.begin(), all_names_.begin() + len); | |||
| } | |||
| std::vector<int64_t> all_step_nums(size_t len) { | |||
| return std::vector<int64_t>(all_step_nums_.begin(), all_step_nums_.begin() + len); | |||
| } | |||
| std::vector<int64_t> all_ep_nums(size_t len) { | |||
| return std::vector<int64_t>(all_ep_nums_.begin(), all_ep_nums_.begin() + len); | |||
| } | |||
| // flags for turning individual callbacks on and off | |||
| bool begin_, epoch_begin_, step_begin_, end_, epoch_end_, step_end_; | |||
| // names of the callback functions in sequence: BGN, EPBGN, SPBGN, END, EPEND, SPEND | |||
| std::vector<std::string> all_names_; | |||
| std::vector<int64_t> all_step_nums_, all_ep_nums_; | |||
| }; | |||
| } // namespace test | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| class MindDataTestCallback : public UT::DatasetOpTesting { | |||
| public: | |||
| void SetUp() override { | |||
| DatasetOpTesting::SetUp(); | |||
| GlobalInit(); | |||
| } | |||
| }; | |||
| TEST_F(MindDataTestCallback, TestBasicCallback) { | |||
| // config callback | |||
| Status rc; | |||
| std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64); | |||
| std::shared_ptr<DSCallback> cb1 = tst_cb; | |||
| tst_cb->end_ = false; // don't do the end for now due to a timing issue | |||
| // config leaf_op, use random_data to avoid I/O | |||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | |||
| schema->AddColumn(col); | |||
| std::shared_ptr<RandomDataOp> leaf; | |||
| rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(44).Build(&leaf); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // config mapOp | |||
| std::shared_ptr<MapOp> map_op; | |||
| auto map_b = MapOp::Builder(); | |||
| rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // config RepeatOp | |||
| std::shared_ptr<RepeatOp> repeat_op; | |||
| rc = RepeatOp::Builder(2).Build(&repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // start build then launch tree | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op}); | |||
| rc = tree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = tree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator di(tree); | |||
| TensorMap tensor_map; | |||
| rc = di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| while (!tensor_map.empty()) { | |||
| rc = di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"}; | |||
| std::vector<int64_t> all_steps = {0, 0, 1, 1, 65, 65, 88}; | |||
| std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1}; | |||
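| // 44 rows repeated twice give 88 steps in the single epoch; with step_size = 64 the | |||
| // step callbacks fire only on steps 1 and 1 + 64 = 65, and EPEND reports step 88 | |||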
| // compare only the first len entries to make sure no unexpected epoch_end or extra epoch_begin was called | |||
| size_t len = 7; | |||
| EXPECT_EQ(tst_cb->all_names(len), callback_names); | |||
| EXPECT_EQ(tst_cb->all_step_nums(len), all_steps); | |||
| EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs); | |||
| } | |||
| TEST_F(MindDataTestCallback, TestMultiEpochCallback) { | |||
| // config callback | |||
| Status rc; | |||
| std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4); | |||
| std::shared_ptr<DSCallback> cb1 = tst_cb; | |||
| tst_cb->end_ = false; // don't do the end for now due to a timing issue | |||
| // config leaf_op, use random_data to avoid I/O | |||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | |||
| schema->AddColumn(col); | |||
| std::shared_ptr<RandomDataOp> leaf; | |||
| rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // config mapOp | |||
| std::shared_ptr<MapOp> map_op; | |||
| auto map_b = MapOp::Builder(); | |||
| rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // config RepeatOp | |||
| std::shared_ptr<RepeatOp> repeat_op; | |||
| rc = RepeatOp::Builder(2).Build(&repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // start build then launch tree | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op}); | |||
| rc = tree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = tree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator di(tree); | |||
| TensorMap tensor_map; | |||
| size_t num_epochs = 2; | |||
| for (size_t ep_num = 0; ep_num < num_epochs; ++ep_num) { | |||
| rc = di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| while (tensor_map.size() != 0) { | |||
| rc = di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| } | |||
| std::vector<std::string> callback_names = {"BGN", "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND", | |||
| "EPBGN", "SPBGN", "SPEND", "SPBGN", "SPEND", "EPEND"}; | |||
| std::vector<int64_t> all_steps = {0, 0, 1, 1, 5, 5, 8, 8, 9, 9, 13, 13, 16}; | |||
| std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2}; | |||
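| // 4 rows repeated twice give 8 steps per epoch; with step_size = 4 the step callbacks | |||
| // fire on steps 1 and 5 of epoch 1 and steps 9 and 13 of epoch 2; EPEND reports 8 and 16 | |||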
| size_t len = 13; | |||
| EXPECT_EQ(tst_cb->all_names(len), callback_names); | |||
| EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs); | |||
| EXPECT_EQ(tst_cb->all_step_nums(len), all_steps); | |||
| } | |||
| TEST_F(MindDataTestCallback, TestSelectedCallback) { | |||
| // config callback | |||
| Status rc; | |||
| std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4); | |||
| std::shared_ptr<DSCallback> cb1 = tst_cb; | |||
| tst_cb->end_ = false; | |||
| // turn off the epochs | |||
| tst_cb->epoch_begin_ = false; | |||
| tst_cb->epoch_end_ = false; | |||
| // config leaf_op, use random_data to avoid I/O | |||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | |||
| ColDescriptor col("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &shape); | |||
| schema->AddColumn(col); | |||
| std::shared_ptr<RandomDataOp> leaf; | |||
| rc = RandomDataOp::Builder().SetRowsPerBuffer(1).SetDataSchema(std::move(schema)).SetTotalRows(4).Build(&leaf); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // config mapOp | |||
| std::shared_ptr<MapOp> map_op; | |||
| auto map_b = MapOp::Builder(); | |||
| rc = map_b.SetInColNames({"label"}).SetTensorFuncs({std::make_shared<NoOp>()}).AddCallbacks({cb1}).Build(&map_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // config RepeatOp | |||
| std::shared_ptr<RepeatOp> repeat_op; | |||
| rc = RepeatOp::Builder(2).Build(&repeat_op); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // start build then launch tree | |||
| std::shared_ptr<ExecutionTree> tree = test::BuildTree({leaf, map_op, repeat_op}); | |||
| rc = tree->Prepare(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| rc = tree->Launch(); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| // Start the loop of reading tensors from our pipeline | |||
| DatasetIterator di(tree); | |||
| TensorMap tensor_map; | |||
| size_t num_epochs = 2; | |||
| for (size_t ep_num = 0; ep_num < num_epochs; ++ep_num) { | |||
| rc = di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| while (tensor_map.size() != 0) { | |||
| rc = di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| } | |||
| } | |||
| std::vector<std::string> callback_names = {"BGN", "SPBGN", "SPEND", "SPBGN", "SPEND", | |||
| "SPBGN", "SPEND", "SPBGN", "SPEND"}; | |||
| std::vector<int64_t> all_steps = {0, 1, 1, 5, 5, 9, 9, 13, 13}; | |||
| std::vector<int64_t> all_epochs = {0, 1, 1, 1, 1, 2, 2, 2, 2}; | |||
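| // epoch_begin_/epoch_end_ are disabled above, so only BGN and the step callbacks remain | |||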
| size_t len = 9; | |||
| EXPECT_EQ(tst_cb->all_names(len), callback_names); | |||
| EXPECT_EQ(tst_cb->all_ep_nums(len), all_epochs); | |||
| EXPECT_EQ(tst_cb->all_step_nums(len), all_steps); | |||
| } | |||
| @@ -0,0 +1,365 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| from builtins import range, super | |||
| import time | |||
| import pytest | |||
| from mindspore import context | |||
| from mindspore import log as logger | |||
| from mindspore.dataset.callback import DSCallback, WaitedDSCallback | |||
| from mindspore.train import Model | |||
| from mindspore.train.callback import Callback | |||
| import mindspore.dataset as ds | |||
| import mindspore.nn as nn | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| class MyDSCallback(DSCallback): | |||
| def __init__(self, step_size=1, events=None, cb_id=0): | |||
| super().__init__(step_size) | |||
| self.events = events | |||
| self.cb_id = cb_id | |||
| def append(self, event_name, ds_run_context): | |||
| event = [event_name, ds_run_context.cur_epoch_num, | |||
| ds_run_context.cur_step_num_in_epoch, ds_run_context.cur_step_num] | |||
| event = '_'.join([str(e) for e in event]) | |||
| index = -1 | |||
| for i, e in enumerate(self.events): | |||
| if e[0] == event: | |||
| index = i | |||
| break | |||
| if index != -1: | |||
| self.events[index][1].append(self.cb_id) | |||
| else: | |||
| self.events.append((event, [self.cb_id])) | |||
| def ds_begin(self, ds_run_context): | |||
| self.append("begin", ds_run_context) | |||
| def ds_end(self, ds_run_context): | |||
| self.append("end", ds_run_context) | |||
| def ds_epoch_begin(self, ds_run_context): | |||
| self.append("epoch_begin", ds_run_context) | |||
| def ds_epoch_end(self, ds_run_context): | |||
| self.append("epoch_end", ds_run_context) | |||
| def ds_step_begin(self, ds_run_context): | |||
| self.append("step_begin", ds_run_context) | |||
| def ds_step_end(self, ds_run_context): | |||
| self.append("step_end", ds_run_context) | |||
| def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1): | |||
| events = [] | |||
| cb_id = list(range(map_num)) | |||
| def append(name, e, s): | |||
| event = [name, e + 1, s + 1, e * step_num * repeat + s + 1] | |||
| event = '_'.join([str(ev) for ev in event]) | |||
| events.append((event, cb_id)) | |||
| events.append(("begin_0_0_0", cb_id)) | |||
| for e in range(epoch_num): | |||
| append("epoch_begin", e, -1) | |||
| for s in range(step_num * repeat): | |||
| if s % step_size == 0: | |||
| append("step_begin", e, s) | |||
| append("step_end", e, s) | |||
| append("epoch_end", e, step_num * repeat - 1) | |||
| return events | |||
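| For instance, tracing generate_expected(1, 2) with the defaults gives (a derived example, shown here as a check): | |||
| assert generate_expected(1, 2) == [ | |||
|     ("begin_0_0_0", [0]), | |||
|     ("epoch_begin_1_0_0", [0]), | |||
|     ("step_begin_1_1_1", [0]), ("step_end_1_1_1", [0]), | |||
|     ("step_begin_1_2_2", [0]), ("step_end_1_2_2", [0]), | |||
|     ("epoch_end_1_2_2", [0]), | |||
| ] | |||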
| def build_test_case_1cb(epochs, steps, step_size=1, repeat=1): | |||
| events = [] | |||
| arr = list(range(1, steps + 1)) | |||
| data = ds.NumpySlicesDataset(arr, shuffle=False) | |||
| my_cb = MyDSCallback(step_size=step_size, events=events) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb) | |||
| if repeat != 1: | |||
| data = data.repeat(repeat) | |||
| itr = data.create_tuple_iterator(num_epochs=epochs) | |||
| for _ in range(epochs): | |||
| for _ in itr: | |||
| pass | |||
| expected_events = generate_expected(epochs, steps, step_size, 1, repeat) | |||
| assert expected_events == events | |||
| def build_test_case_2cbs(epochs, steps): | |||
| events1 = [] | |||
| events2 = [] | |||
| my_cb1 = MyDSCallback(events=events1) | |||
| my_cb2 = MyDSCallback(events=events2) | |||
| arr = list(range(1, steps + 1)) | |||
| data = ds.NumpySlicesDataset(arr, shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=[my_cb1, my_cb2]) | |||
| itr = data.create_tuple_iterator(num_epochs=epochs) | |||
| for _ in range(epochs): | |||
| for _ in itr: | |||
| pass | |||
| expected_events = generate_expected(epochs, steps) | |||
| assert expected_events == events1 | |||
| assert expected_events == events2 | |||
| def build_test_case_2maps(epochs, steps): | |||
| events = [] | |||
| my_cb1 = MyDSCallback(events=events, cb_id=0) | |||
| my_cb2 = MyDSCallback(events=events, cb_id=1) | |||
| arr = list(range(1, steps + 1)) | |||
| data = ds.NumpySlicesDataset(arr, shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb1) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb2) | |||
| itr = data.create_tuple_iterator(num_epochs=epochs) | |||
| for _ in range(epochs): | |||
| for _ in itr: | |||
| pass | |||
| expected_events = generate_expected(epochs, steps, map_num=2) | |||
| assert expected_events[1:] == events[1:] | |||
| for event in events: | |||
| assert len(event) == 2 | |||
| event, cb_ids = event | |||
| if event != "begin_0_0_0": | |||
| assert cb_ids[0] == 0 | |||
| assert cb_ids[1] == 1 | |||
| def test_callbacks_all_methods(): | |||
| logger.info("test_callbacks_all_methods") | |||
| build_test_case_1cb(1, 1) | |||
| build_test_case_1cb(1, 2) | |||
| build_test_case_1cb(1, 3) | |||
| build_test_case_1cb(1, 4) | |||
| build_test_case_1cb(2, 1) | |||
| build_test_case_1cb(2, 2) | |||
| build_test_case_1cb(2, 3) | |||
| build_test_case_1cb(2, 4) | |||
| build_test_case_1cb(3, 1) | |||
| build_test_case_1cb(3, 2) | |||
| build_test_case_1cb(3, 3) | |||
| build_test_case_1cb(3, 4) | |||
| def test_callbacks_var_step_size(): | |||
| logger.info("test_callbacks_var_step_size") | |||
| build_test_case_1cb(1, 2, 2) | |||
| build_test_case_1cb(1, 3, 2) | |||
| build_test_case_1cb(1, 4, 2) | |||
| build_test_case_1cb(2, 2, 2) | |||
| build_test_case_1cb(2, 3, 2) | |||
| build_test_case_1cb(2, 4, 2) | |||
| build_test_case_1cb(3, 2, 2) | |||
| build_test_case_1cb(3, 3, 2) | |||
| build_test_case_1cb(3, 4, 2) | |||
| def test_callbacks_all_2cbs(): | |||
| logger.info("test_callbacks_all_2cbs") | |||
| build_test_case_2cbs(4, 1) | |||
| build_test_case_2cbs(4, 2) | |||
| build_test_case_2cbs(4, 3) | |||
| build_test_case_2cbs(4, 4) | |||
| def test_callbacks_2maps(): | |||
| logger.info("test_callbacks_2maps") | |||
| build_test_case_2maps(5, 10) | |||
| build_test_case_2maps(6, 9) | |||
| class MyWaitedCallback(WaitedDSCallback): | |||
| def __init__(self, events, step_size=1): | |||
| super().__init__(step_size) | |||
| self.events = events | |||
| def sync_epoch_begin(self, train_run_context, ds_run_context): | |||
| event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}" | |||
| self.events.append(event) | |||
| def sync_step_begin(self, train_run_context, ds_run_context): | |||
| event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}" | |||
| self.events.append(event) | |||
| class MyMSCallback(Callback): | |||
| def __init__(self, events): | |||
| self.events = events | |||
| def epoch_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}" | |||
| self.events.append(event) | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}" | |||
| self.events.append(event) | |||
| class Net(nn.Cell): | |||
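| # identity network: training it merely drives the data pipeline so the dataset callbacks fire | |||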
| def construct(self, x, y): | |||
| return x | |||
| def test_train_non_sink(): | |||
| logger.info("test_train_non_sink") | |||
| events = [] | |||
| my_cb1 = MyWaitedCallback(events, 1) | |||
| my_cb2 = MyMSCallback(events) | |||
| arr = [1, 2, 3, 4] | |||
| data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb1) | |||
| net = Net() | |||
| model = Model(net) | |||
| model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1]) | |||
| expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3', | |||
| 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4', | |||
| 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4', | |||
| 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6', | |||
| 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8', | |||
| 'ms_step_end_2_8', 'ms_epoch_end_2_8'] | |||
| assert events == expected_synced_events | |||
| def test_train_batch_size2(): | |||
| logger.info("test_train_batch_size2") | |||
| events = [] | |||
| my_cb1 = MyWaitedCallback(events, 2) | |||
| my_cb2 = MyMSCallback(events) | |||
| arr = [1, 2, 3, 4] | |||
| data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb1) | |||
| data = data.batch(2) | |||
| net = Net() | |||
| model = Model(net) | |||
| model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1]) | |||
| expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_3', | |||
| 'ms_step_end_1_2', | |||
| 'ms_epoch_end_1_2', 'ds_epoch_begin_2_4', | |||
| 'ds_step_begin_2_5', 'ms_step_end_2_3', 'ds_step_begin_2_7', | |||
| 'ms_step_end_2_4', 'ms_epoch_end_2_4'] | |||
| assert events == expected_synced_events | |||
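| The mixed indices above come from two different counters (reasoning from this test's setup): the dataset-side step number counts rows before batching, while the model-side step number counts batches, two per epoch: | |||
| # epoch 1: rows 1..4 -> sync step callback fires at row 3 (the first interval is skipped); model batches 1, 2 | |||
| # epoch 2: rows 5..8 -> sync step callbacks fire at rows 5 and 7; model batches 3, 4 | |||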
| def test_callbacks_validations(): | |||
| logger.info("test_callbacks_validations") | |||
| with pytest.raises(Exception) as err: | |||
| data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) | |||
| data.map(operations=(lambda x: x), callbacks=0) | |||
| assert "Argument callbacks with value 0 is not " in str(err.value) | |||
| with pytest.raises(Exception) as err: | |||
| my_cb1 = MyDSCallback() | |||
| data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) | |||
| data.map(operations=(lambda x: x), callbacks=[my_cb1, 0]) | |||
| assert "Argument callbacks[1] with value 0 is not " in str(err.value) | |||
| with pytest.raises(Exception) as err: | |||
| class BadCB(DSCallback): | |||
| pass | |||
| my_cb = BadCB() | |||
| data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb) | |||
| for _ in data: | |||
| pass | |||
| assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value) | |||
| def test_callback_sink_simulation(): | |||
| logger.info("test_callback_sink_simulation") | |||
| events = [] | |||
| epochs = 2 | |||
| my_cb = MyWaitedCallback(events, 1) | |||
| data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb) | |||
| data = data.to_device() | |||
| data.send(num_epochs=epochs) | |||
| for e in range(epochs): | |||
| for s in range(4): | |||
| time.sleep(0.5) | |||
| events.append(f"ms_step_end_{e + 1}_{e * 4 + s + 1}") | |||
| my_cb.step_end(run_context=0) | |||
| events.append(f"ms_epoch_end_{e + 1}_{(e + 1) * 4}") | |||
| my_cb.epoch_end(run_context=0) | |||
| expected_synced_events = ['ms_step_end_1_1', 'ds_step_begin_1_2', 'ms_step_end_1_2', 'ds_step_begin_1_3', | |||
| 'ms_step_end_1_3', 'ds_step_begin_1_4', 'ms_step_end_1_4', | |||
| 'ms_epoch_end_1_4', 'ds_epoch_begin_2_4', | |||
| 'ds_step_begin_2_5', 'ms_step_end_2_5', 'ds_step_begin_2_6', | |||
| 'ms_step_end_2_6', 'ds_step_begin_2_7', 'ms_step_end_2_7', 'ds_step_begin_2_8', | |||
| 'ms_step_end_2_8', 'ms_epoch_end_2_8'] | |||
| assert events == expected_synced_events | |||
| def test_callbacks_repeat(): | |||
| logger.info("test_callbacks_repeat") | |||
| build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2) | |||
| build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=3) | |||
| build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3) | |||
| build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3) | |||
| if __name__ == '__main__': | |||
| test_callbacks_all_methods() | |||
| test_callbacks_all_2cbs() | |||
| test_callbacks_2maps() | |||
| test_callbacks_validations() | |||
| test_callbacks_var_step_size() | |||
| test_train_batch_size2() | |||
| test_callback_sink_simulation() | |||
| test_callbacks_repeat() | |||