Merge pull request !4776 from ZiruiWu/map_callback_follow_uptags/v0.7.0-beta
| @@ -45,6 +45,8 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) { | |||
| .def("get_op_connector_size", &ConfigManager::op_connector_size) | |||
| .def("get_seed", &ConfigManager::seed) | |||
| .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) | |||
| .def("get_callback_timeout", &ConfigManager::callback_timeout) | |||
| .def("set_callback_timeout", &ConfigManager::set_callback_timeout) | |||
| .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); | |||
| })); | |||
| @@ -50,7 +50,7 @@ Status CallbackManager::Begin(const CallbackParam &cb_param) { | |||
| // return Status::OK() if no begin is needed | |||
| RETURN_OK_IF_TRUE(callback_inds.empty()); | |||
| RETURN_IF_NOT_OK(op_->PauseFromMaster()); | |||
| RETURN_IF_NOT_OK(op_->WaitForWorkers()); | |||
| // Now do the actual callback | |||
| for (size_t ind : callback_inds) { | |||
| @@ -69,7 +69,7 @@ Status CallbackManager::EpochBegin(const CallbackParam &cb_param) { | |||
| // return Status::OK() if no epoch_begin is needed | |||
| RETURN_OK_IF_TRUE(callback_inds.empty()); | |||
| RETURN_IF_NOT_OK(op_->PauseFromMaster()); | |||
| RETURN_IF_NOT_OK(op_->WaitForWorkers()); | |||
| // Now do the actual callback | |||
| for (size_t ind : callback_inds) { | |||
| @@ -89,7 +89,7 @@ Status CallbackManager::StepBegin(const CallbackParam &cb_param) { | |||
| // return Status::OK() if no step_begin is needed | |||
| RETURN_OK_IF_TRUE(callback_inds.empty()); | |||
| RETURN_IF_NOT_OK(op_->PauseFromMaster()); | |||
| RETURN_IF_NOT_OK(op_->WaitForWorkers()); | |||
| // Now do the actual callback | |||
| for (size_t ind : callback_inds) { | |||
| @@ -108,7 +108,7 @@ Status CallbackManager::End(const CallbackParam &cb_param) { | |||
| // return Status::OK() if no end is needed | |||
| RETURN_OK_IF_TRUE(callback_inds.empty()); | |||
| RETURN_IF_NOT_OK(op_->PauseFromMaster()); | |||
| RETURN_IF_NOT_OK(op_->WaitForWorkers()); | |||
| // Now do the actual callback | |||
| for (size_t ind : callback_inds) { | |||
| @@ -127,7 +127,7 @@ Status CallbackManager::EpochEnd(const CallbackParam &cb_param) { | |||
| // return Status::OK() if no epoch_end is needed | |||
| RETURN_OK_IF_TRUE(callback_inds.empty()); | |||
| RETURN_IF_NOT_OK(op_->PauseFromMaster()); | |||
| RETURN_IF_NOT_OK(op_->WaitForWorkers()); | |||
| // Now do the actual callback | |||
| for (size_t ind : callback_inds) { | |||
| @@ -147,7 +147,7 @@ Status CallbackManager::StepEnd(const CallbackParam &cb_param) { | |||
| // return Status::OK() if no step_end is needed | |||
| RETURN_OK_IF_TRUE(callback_inds.empty()); | |||
| RETURN_IF_NOT_OK(op_->PauseFromMaster()); | |||
| RETURN_IF_NOT_OK(op_->WaitForWorkers()); | |||
| // Now do the actual callback | |||
| for (size_t ind : callback_inds) { | |||
| @@ -32,7 +32,7 @@ class DatasetOp; | |||
| /// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this. | |||
| class CallbackManager { | |||
| public: | |||
| /// CallbackManager default constructor. Init needs to be called before using the created instance. | |||
| /// \brief CallbackManager default constructor. Init needs to be called before using the created instance. | |||
| CallbackManager() : enabled_(false) {} | |||
| /// \brief | |||
| @@ -88,5 +88,8 @@ uint32_t ConfigManager::seed() const { return seed_; } | |||
| void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; } | |||
| void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; } | |||
| void ConfigManager::set_callback_timeout(uint32_t timeout) { callback_timout_ = timeout; } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -116,9 +116,17 @@ class ConfigManager { | |||
| void set_monitor_sampling_interval(uint32_t interval); | |||
| // getter function | |||
| // @return The iterval of monitor sampling | |||
| // @return The interval of monitor sampling | |||
| int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; } | |||
| // setter function | |||
| // @param timeout - The setting to apply to the config | |||
| void set_callback_timeout(uint32_t timeout); | |||
| // getter function | |||
| // @return The timeout DSWaitedCallback would wait for before raising an error | |||
| int32_t callback_timeout() const { return callback_timout_; } | |||
| private: | |||
| int32_t rows_per_buffer_{kCfgRowsPerBuffer}; | |||
| int32_t num_parallel_workers_{kCfgParallelWorkers}; | |||
| @@ -126,8 +134,9 @@ class ConfigManager { | |||
| int32_t op_connector_size_{kCfgOpConnectorSize}; | |||
| uint32_t seed_{kCfgDefaultSeed}; | |||
| uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; | |||
| uint32_t callback_timout_{kCfgCallbackTimeout}; | |||
| // Private helper function that taks a nlohmann json format and populates the settings | |||
| // Private helper function that takes a nlohmann json format and populates the settings | |||
| // @param j - The json nlohmann json info | |||
| Status FromJson(const nlohmann::json &j); | |||
| }; | |||
| @@ -68,6 +68,7 @@ constexpr uint32_t kCfgWorkerConnectorSize = 16; | |||
| constexpr uint32_t kCfgOpConnectorSize = 16; | |||
| constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; | |||
| constexpr uint32_t kCfgMonitorSamplingInterval = 10; | |||
| constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds | |||
| // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) | |||
| constexpr uint8_t kCVInvalidType = 255; | |||
| @@ -37,8 +37,10 @@ class DataBuffer { | |||
| // Buffer flags | |||
| enum BufferFlags : uint32_t { | |||
| kDeBFlagNone = 0, | |||
| kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg | |||
| kDeBFlagEOE = 1u << 1 // The buffer is an eoe end-of-epoch msg | |||
| kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg | |||
| kDeBFlagEOE = 1u << 1, // The buffer is an eoe end-of-epoch msg | |||
| kDeBFlagWait = 1u << 2, // The buffer is an control signal for workers to suspend operations | |||
| kDeBFlagQuit = 1u << 3 // The buffer is a control signal for workers to quit | |||
| }; | |||
| // Name: Constructor #1 | |||
| @@ -64,6 +66,10 @@ class DataBuffer { | |||
| bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); } | |||
| bool wait() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagWait)); } | |||
| bool quit() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagQuit)); } | |||
| // Simple getter funcs | |||
| int32_t id() const { return buffer_id_; } | |||
| @@ -363,10 +363,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> { | |||
| /// This function is only intended to be called by CallbackManager within the master thread of ParallelOp | |||
| /// The expected behavior is this, when this function is invoked, this function will block until all the workers | |||
| /// have finished their remaining work and go to sleep. Since all ParallelOps use a QueueList to sync with master. | |||
| /// They would automatically wait on the QueueList when they are done. Hence, for now, a Unpause() function is not | |||
| /// needed. Only parallelOp needs to override this function. | |||
| /// They would automatically wait on the QueueList when they are done. | |||
| /// \return Status | |||
| virtual Status PauseFromMaster() { return Status::OK(); } | |||
| virtual Status WaitForWorkers() { return Status::OK(); } | |||
| protected: | |||
| /// \brief Removes a parent operator from this operator | |||
| @@ -166,7 +166,7 @@ Status MapOp::operator()() { | |||
| // init callback | |||
| RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this())); | |||
| Status rc = local_queues_.Register(tree_->AllTasks()); | |||
| RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| if (rc.IsError()) { | |||
| TaskManager::FindMe()->Post(); | |||
| return rc; | |||
| @@ -205,23 +205,29 @@ Status MapOp::operator()() { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | |||
| } | |||
| // send the eoe buffer to worker | |||
| // reset epoch_step when a new epoch is about to start | |||
| // check whether this is the end of a real epoch (not all eoe signals end of epoch) | |||
| if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) { | |||
| RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step))); | |||
| ep_step = 0; | |||
| } | |||
| // Propagate the eoe buffer to worker | |||
| std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff)); | |||
| RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); | |||
| UpdateRepeatAndEpochCounter(); | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); | |||
| } | |||
| // the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback | |||
| // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step))); | |||
| // handle eof logic | |||
| // End() is commented out because it might never be called due to the lack of EOF when EpochCtrl is -1 | |||
| // RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step))); | |||
| // Handle eof logic, this code might never be reached if epoch_ctrl = -1. | |||
| std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff)); | |||
| RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job))); | |||
| // Quit all workers, this code might never be reached if EpochCtrl is -1. | |||
| for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) { | |||
| auto quit = std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagQuit)); | |||
| RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(quit))); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| @@ -242,26 +248,27 @@ Status MapOp::WorkerEntry(int32_t worker_id) { | |||
| // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself | |||
| // rather than use the base-class defaults. | |||
| while (true) { | |||
| // handle the pause logic. Pause is triggered when an buffer id of -1 with no special flag and no row is received | |||
| if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) { | |||
| // when worker receives the signal from master thread, it increments a atomic int | |||
| // the last guy who increments the counter, wakes up master thread | |||
| if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set(); | |||
| // this will block the worker until master thread gives it a new work | |||
| RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); | |||
| continue; | |||
| } else if (in_buffer->eoe()) { | |||
| // Calling base class EoeReceived to forward eoe buffer. | |||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | |||
| // Fetch next data buffer and map job list | |||
| // Handle special logic where buffer carries a ctrl flag. | |||
| if (in_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) { | |||
| if (in_buffer->wait()) { | |||
| // When worker receives the signal from master thread, it increments a atomic int | |||
| // The last guy who increments the counter, wakes up master thread | |||
| if (++num_workers_paused_ == num_workers_) { | |||
| wait_for_workers_post_.Set(); | |||
| } | |||
| // This will block the worker until master thread gives it a new work | |||
| } else if (in_buffer->eoe()) { | |||
| // Calling base class EoeReceived to forward eoe buffer. | |||
| RETURN_IF_NOT_OK(EoeReceived(worker_id)); | |||
| } else if (in_buffer->eof()) { | |||
| // Calling base class EofReceived to forward eof buffer. | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| } else if (in_buffer->quit()) { | |||
| break; | |||
| } | |||
| RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list)); | |||
| continue; | |||
| } else if (in_buffer->eof()) { | |||
| // Calling base class EofReceived to forward eof buffer. | |||
| RETURN_IF_NOT_OK(EofReceived(worker_id)); | |||
| break; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer."); | |||
| std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>()); | |||
| // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. | |||
| @@ -299,9 +306,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl | |||
| // Variable to keep the result after executing the job. | |||
| std::vector<TensorRow> result_table; | |||
| // Executing the list of jobs | |||
| // Executing the list of jobs. | |||
| for (size_t i = 0; i < job_list.size(); i++) { | |||
| // Execute MapJob. | |||
| // Execute MapWorkerJob. | |||
| RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table)); | |||
| // Assign the processed data as an input for the next job processing, except for the last TensorOp in the list. | |||
| if (i + 1 < job_list.size()) { | |||
| @@ -311,8 +318,7 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl | |||
| // Sanity check a row in result_table | |||
| if (!result_table.empty() && out_columns_.size() != result_table[0].size()) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, | |||
| "Result of a tensorOp doesn't match output column names"); | |||
| RETURN_STATUS_UNEXPECTED("Result of a tensorOp doesn't match output column names"); | |||
| } | |||
| // Merging the data processed by job (result_table) with the data that are not used. | |||
| @@ -386,7 +392,7 @@ Status MapOp::InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_ | |||
| // columns from child are correct | |||
| RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map)); | |||
| // initialize keep_input_columns, true means to keep the column. | |||
| // Initialize keep_input_columns, true means to keep the column. | |||
| keep_input_columns_.resize(col_name_id_map->size(), true); | |||
| for (const auto &col_name : in_columns_) { | |||
| int32_t missed = (*col_name_id_map)[col_name]; | |||
| @@ -449,18 +455,18 @@ Status MapOp::Accept(NodePass *p, bool *modified) { | |||
| return p->RunOnNode(shared_from_base<MapOp>(), modified); | |||
| } | |||
| Status MapOp::PauseFromMaster() { | |||
| Status MapOp::WaitForWorkers() { | |||
| // reset num_paused workers to 0 | |||
| num_workers_paused_ = 0; | |||
| for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) { | |||
| // a special buffer (id=-1, empty, none flag) is used to signal that worker needs to pause. | |||
| RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add( | |||
| std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone)))); | |||
| std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagWait)))); | |||
| } | |||
| // wait until all workers are done processing their work in local_queue_ | |||
| RETURN_IF_NOT_OK(master_pause_wp_.Wait()); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Wait()); | |||
| // clear the WaitPost for the next Wait() | |||
| master_pause_wp_.Clear(); | |||
| wait_for_workers_post_.Clear(); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -228,10 +228,10 @@ class MapOp : public ParallelOp { | |||
| // Indices of the columns to process. | |||
| std::vector<size_t> to_process_indices_; | |||
| // wait post used to perform the pausing logic in MapOp | |||
| WaitPost master_pause_wp_; | |||
| // Wait post used to perform the pausing logic in MapOp | |||
| WaitPost wait_for_workers_post_; | |||
| // count number of workers that have signaled master | |||
| // Count number of workers that have signaled master | |||
| std::atomic_int num_workers_paused_; | |||
| // Private function for worker/thread to loop continuously. It comprises the main | |||
| @@ -272,7 +272,7 @@ class MapOp : public ParallelOp { | |||
| // Workers upon receiving the suspension token from master thread, increment an atomic count, the last worker | |||
| // who does the increment wakes up the master. | |||
| // @return - Status | |||
| Status PauseFromMaster() override; | |||
| Status WaitForWorkers() override; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -34,7 +34,7 @@ class Semaphore { | |||
| /// \brief Decrement the internal counter. Will be blocked if the value is 0. | |||
| /// \return Error code. Can get interrupt. | |||
| Status P(); | |||
| /// \brief Increment the internal counter. Wakeup on of the waiters if any. | |||
| /// \brief Increment the internal counter. Wake up on of the waiters if any. | |||
| void V(); | |||
| /// \brief Peek the internal value | |||
| /// \return The internal value | |||
| @@ -18,6 +18,7 @@ Python callback class | |||
| import threading | |||
| from mindspore._c_dataengine import PyDSCallback | |||
| from mindspore.train.callback import Callback | |||
| import mindspore.dataset as ds | |||
| from .validators import check_callback | |||
| @@ -170,7 +171,6 @@ class WaitedDSCallback(Callback, DSCallback): | |||
| """ | |||
| self.epoch_run_context = run_context | |||
| self.epoch_event.set() | |||
| self.epoch_event.clear() | |||
| def ds_epoch_begin(self, ds_run_context): | |||
| """ | |||
| @@ -180,10 +180,12 @@ class WaitedDSCallback(Callback, DSCallback): | |||
| ds_run_context: Include some information of the pipeline. | |||
| """ | |||
| if ds_run_context.cur_epoch_num > 1: | |||
| if self.epoch_run_context is None: | |||
| self.epoch_event.wait() | |||
| success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout()) | |||
| self.epoch_event.clear() | |||
| if not success: | |||
| raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)") | |||
| # by the time this thread wakes up, self.epoch_run_context is already available | |||
| self.sync_epoch_begin(self.epoch_run_context, ds_run_context) | |||
| self.epoch_run_context = None | |||
| def step_end(self, run_context): | |||
| """ | |||
| @@ -194,7 +196,6 @@ class WaitedDSCallback(Callback, DSCallback): | |||
| """ | |||
| self.step_run_context = run_context | |||
| self.step_event.set() | |||
| self.step_event.clear() | |||
| def ds_step_begin(self, ds_run_context): | |||
| """ | |||
| @@ -204,10 +205,12 @@ class WaitedDSCallback(Callback, DSCallback): | |||
| ds_run_context: Include some information of the pipeline. | |||
| """ | |||
| if ds_run_context.cur_step_num > self.step_size: | |||
| if self.step_run_context is None: | |||
| self.step_event.wait() | |||
| success = self.step_event.wait(timeout=ds.config.get_callback_timeout()) | |||
| self.step_event.clear() | |||
| if not success: | |||
| raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)") | |||
| # by the time this thread wakes up, self.epoch_run_context is already available | |||
| self.sync_step_begin(self.step_run_context, ds_run_context) | |||
| self.step_run_context = None | |||
| def create_runtime_obj(self): | |||
| """ | |||
| @@ -157,6 +157,38 @@ def get_monitor_sampling_interval(): | |||
| return _config.get_monitor_sampling_interval() | |||
| def set_callback_timeout(timeout): | |||
| """ | |||
| Set the default timeout (in seconds) for DSWaitedCallback. | |||
| In case of a deadlock, the wait function will exit after the timeout period. | |||
| Args: | |||
| timeout (int): timeout(s) to be used to end teh wait in DSWaitedCallback in case of a deadlock. | |||
| Raises: | |||
| ValueError: If timeout is invalid (<= 0 or > MAX_INT_32). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> # sets the new timout value. | |||
| >>> ds.config.set_callback_timeout(100) | |||
| """ | |||
| if timeout <= 0 or timeout > INT32_MAX: | |||
| raise ValueError("timeout given is not within the required range.") | |||
| _config.set_callback_timeout(timeout) | |||
| def get_callback_timeout(): | |||
| """ | |||
| Get the default timeout for DSWaitedCallback. | |||
| In case of a deadlock, the wait function will exit after the timeout period. | |||
| Returns: | |||
| Int, the duration in seconds | |||
| """ | |||
| return _config.get_callback_timeout() | |||
| def __str__(): | |||
| """ | |||
| String representation of the configurations. | |||
| @@ -57,7 +57,7 @@ class TestCallback : public DSCallback { | |||
| begin_(true), | |||
| epoch_begin_(true), | |||
| step_begin_(true), | |||
| end_(true), | |||
| end_(false), | |||
| epoch_end_(true), | |||
| step_end_(true) { | |||
| all_names_.reserve(32); | |||
| @@ -145,7 +145,6 @@ TEST_F(MindDataTestCallback, TestBasicCallback) { | |||
| Status rc; | |||
| std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(64); | |||
| std::shared_ptr<DSCallback> cb1 = tst_cb; | |||
| tst_cb->end_ = false; // don't do the end for now due to a timing issue | |||
| // config leaf_op, use random_data to avoid I/O | |||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | |||
| @@ -193,7 +192,6 @@ TEST_F(MindDataTestCallback, TestMutiEpochCallback) { | |||
| Status rc; | |||
| std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4); | |||
| std::shared_ptr<DSCallback> cb1 = tst_cb; | |||
| tst_cb->end_ = false; // don't do the end for now due to a timing issue | |||
| // config leaf_op, use random_data to avoid I/O | |||
| std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); | |||
| TensorShape shape({}); // empty shape is a 1-value scalar Tensor | |||
| @@ -247,7 +245,6 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) { | |||
| Status rc; | |||
| std::shared_ptr<test::TestCallback> tst_cb = std::make_shared<test::TestCallback>(4); | |||
| std::shared_ptr<DSCallback> cb1 = tst_cb; | |||
| tst_cb->end_ = false; | |||
| // turn off the epochs | |||
| tst_cb->epoch_begin_ = false; | |||
| tst_cb->epoch_end_ = false; | |||
| @@ -29,7 +29,7 @@ import mindspore.nn as nn | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| class MyDSCallback(DSCallback): | |||
| class BaseCallback(DSCallback): | |||
| def __init__(self, step_size=1, events=None, cb_id=0): | |||
| super().__init__(step_size) | |||
| self.events = events | |||
| @@ -49,25 +49,36 @@ class MyDSCallback(DSCallback): | |||
| else: | |||
| self.events.append((event, [self.cb_id])) | |||
| class Begin(BaseCallback): | |||
| def ds_begin(self, ds_run_context): | |||
| self.append("begin", ds_run_context) | |||
| def ds_end(self, ds_run_context): | |||
| self.append("end", ds_run_context) | |||
| class EpochBegin(BaseCallback): | |||
| def ds_epoch_begin(self, ds_run_context): | |||
| self.append("epoch_begin", ds_run_context) | |||
| class EpochEnd(BaseCallback): | |||
| def ds_epoch_end(self, ds_run_context): | |||
| self.append("epoch_end", ds_run_context) | |||
| class StepBegin(BaseCallback): | |||
| def ds_step_begin(self, ds_run_context): | |||
| self.append("step_begin", ds_run_context) | |||
| class StepEnd(BaseCallback): | |||
| def ds_step_end(self, ds_run_context): | |||
| self.append("step_end", ds_run_context) | |||
| class MyDSCallback(Begin, EpochBegin, EpochEnd, StepBegin, StepEnd): | |||
| pass | |||
| def generate_expected(epoch_num, step_num, step_size=1, map_num=1, repeat=1): | |||
| events = [] | |||
| cb_id = list(range(map_num)) | |||
| @@ -98,7 +109,12 @@ def build_test_case_1cb(epochs, steps, step_size=1, repeat=1): | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb) | |||
| if repeat != 1: | |||
| data = data.repeat(repeat) | |||
| if repeat % 2 == 0 and repeat != 2: | |||
| data = data.repeat(2) | |||
| data = data.map(operations=(lambda x: x)) | |||
| data = data.repeat(repeat // 2) | |||
| else: | |||
| data = data.repeat(repeat) | |||
| itr = data.create_tuple_iterator(num_epochs=epochs) | |||
| for _ in range(epochs): | |||
| for _ in itr: | |||
| @@ -201,11 +217,10 @@ def test_callbacks_all_2cbs(): | |||
| build_test_case_2cbs(4, 4) | |||
| def test_callbacks_2maps(): | |||
| def skip_test_callbacks_2maps(): | |||
| logger.info("test_callbacks_2maps") | |||
| # This test case is skipped because in rare cases (25 out 1000) it might fail | |||
| build_test_case_2maps(5, 10) | |||
| build_test_case_2maps(6, 9) | |||
| @@ -243,8 +258,8 @@ class Net(nn.Cell): | |||
| return x | |||
| def test_train_non_sink(): | |||
| logger.info("test_train_non_sink") | |||
| def test_callbacks_non_sink(): | |||
| logger.info("test_callbacks_non_sink") | |||
| events = [] | |||
| my_cb1 = MyWaitedCallback(events, 1) | |||
| @@ -267,8 +282,8 @@ def test_train_non_sink(): | |||
| assert events == expected_synced_events | |||
| def test_train_batch_size2(): | |||
| logger.info("test_train_batch_size2") | |||
| def test_callbacks_non_sink_batch_size2(): | |||
| logger.info("test_callbacks_non_sink_batch_size2") | |||
| events = [] | |||
| my_cb1 = MyWaitedCallback(events, 2) | |||
| @@ -291,6 +306,27 @@ def test_train_batch_size2(): | |||
| assert events == expected_synced_events | |||
| def test_callbacks_non_sink_mismatch_size(): | |||
| logger.info("test_callbacks_non_sink_mismatch_size") | |||
| default_timeout = ds.config.get_callback_timeout() | |||
| ds.config.set_callback_timeout(1) | |||
| events = [] | |||
| my_cb1 = MyWaitedCallback(events, 2) | |||
| my_cb2 = MyMSCallback(events) | |||
| arr = [1, 2, 3, 4] | |||
| data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_cb1) | |||
| data = data.batch(3) | |||
| net = Net() | |||
| model = Model(net) | |||
| with pytest.raises(Exception) as err: | |||
| model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1]) | |||
| assert "RuntimeError: ds_step_begin timed out after 1 second(s)" in str(err.value) | |||
| ds.config.set_callback_timeout(default_timeout) | |||
| def test_callbacks_validations(): | |||
| logger.info("test_callbacks_validations") | |||
| @@ -318,7 +354,7 @@ def test_callbacks_validations(): | |||
| assert "Provided Callback class did not override any of the 6 callback methods." in str(err.value) | |||
| def test_callback_sink_simulation(): | |||
| def test_callbacks_sink_simulation(): | |||
| logger.info("test_callback_sink_simulation") | |||
| events = [] | |||
| @@ -353,13 +389,72 @@ def test_callbacks_repeat(): | |||
| build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=3) | |||
| build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=3) | |||
| build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=2) | |||
| build_test_case_1cb(epochs=2, steps=2, step_size=1, repeat=4) | |||
| build_test_case_1cb(epochs=2, steps=2, step_size=2, repeat=8) | |||
| build_test_case_1cb(epochs=3, steps=2, step_size=4, repeat=16) | |||
| def test_callbacks_exceptions(): | |||
| logger.info("test_callbacks_exceptions") | |||
| class BadCB(DSCallback): | |||
| def ds_begin(self, ds_run_context): | |||
| raise RuntimeError("Bad begin") | |||
| with pytest.raises(Exception) as err: | |||
| data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) | |||
| data = data.map(operations=(lambda x: x), callbacks=BadCB()) | |||
| for _ in data: | |||
| pass | |||
| assert "RuntimeError: Bad begin" in str(err.value) | |||
| def test_callbacks_one_cb(): | |||
| logger.info("test_callbacks_one_cb") | |||
| data = ds.NumpySlicesDataset([1, 2, 3, 4], shuffle=False) | |||
| events1 = [] | |||
| events2 = [] | |||
| events3 = [] | |||
| my_begin = Begin(events=events1, cb_id=1) | |||
| my_epoch_begin = EpochBegin(events=events2, cb_id=2) | |||
| my_epoch_end = EpochEnd(events=events3, cb_id=3) | |||
| my_step_begin = StepBegin(events=events3, cb_id=3) | |||
| my_step_end = StepEnd(events=events2, cb_id=2) | |||
| data = data.map(operations=(lambda x: x), callbacks=my_begin) | |||
| data = data.map(operations=(lambda x: x), callbacks=[my_epoch_begin, my_step_end]) | |||
| data = data.map(operations=(lambda x: x), callbacks=[my_epoch_end, my_step_begin]) | |||
| itr = data.create_tuple_iterator() | |||
| for _ in range(2): | |||
| for _ in itr: | |||
| pass | |||
| expected_events1 = [('begin_0_0_0', [1])] | |||
| expected_events2 = [('epoch_begin_1_0_0', [2]), ('step_end_1_1_1', [2]), ('step_end_1_2_2', [2]), | |||
| ('step_end_1_3_3', [2]), ('step_end_1_4_4', [2]), ('epoch_begin_2_0_4', [2]), | |||
| ('step_end_2_1_5', [2]), ('step_end_2_2_6', [2]), ('step_end_2_3_7', [2]), | |||
| ('step_end_2_4_8', [2])] | |||
| expected_events3 = [('step_begin_1_1_1', [3]), ('step_begin_1_2_2', [3]), ('step_begin_1_3_3', [3]), | |||
| ('step_begin_1_4_4', [3]), ('epoch_end_1_4_4', [3]), ('step_begin_2_1_5', [3]), | |||
| ('step_begin_2_2_6', [3]), ('step_begin_2_3_7', [3]), ('step_begin_2_4_8', [3]), | |||
| ('epoch_end_2_4_8', [3])] | |||
| assert events1 == expected_events1 | |||
| assert events2 == expected_events2 | |||
| assert events3 == expected_events3 | |||
| if __name__ == '__main__': | |||
| test_callbacks_all_methods() | |||
| skip_test_callbacks_2maps() | |||
| test_callbacks_all_2cbs() | |||
| test_callbacks_2maps() | |||
| test_callbacks_all_methods() | |||
| test_callbacks_exceptions() | |||
| test_callbacks_repeat() | |||
| test_callbacks_sink_simulation() | |||
| test_callbacks_validations() | |||
| test_callbacks_var_step_size() | |||
| test_train_batch_size2() | |||
| test_callback_sink_simulation() | |||
| test_callbacks_repeat() | |||
| test_callbacks_non_sink_batch_size2() | |||
| test_callbacks_non_sink() | |||
| test_callbacks_one_cb() | |||
| test_callbacks_non_sink_mismatch_size() | |||