Merge pull request !6221 from anzhengqi/stop-send-at-eos (tag: v1.0.0)
@@ -81,6 +81,7 @@ PYBIND_REGISTER(
       .def("GetNumClasses", &DEPipeline::GetNumClasses)
       .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
       .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
+      .def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
       .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
         THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
         return true;
@@ -291,6 +291,16 @@ Status DEPipeline::StopSend() {
   return Status::OK();
 }
 
+Status DEPipeline::ContinueSend() {
+  // tree_.root() must be DeviceQueueOp
+  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
+  if (op == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ContinueSend only supported by DeviceQueueOp");
+  }
+  op->ContinueSend();
+  return Status::OK();
+}
+
 int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
 
 bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
@@ -203,6 +203,9 @@ class DEPipeline {
   Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
 
   Status StopSend();
 
+  Status ContinueSend();
+
   Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                                         std::shared_ptr<DatasetOp> *bottom);
@@ -153,6 +153,10 @@ Status DeviceQueueOp::SendDataToAscend() {
     }
     if (current_buffer->eoe() && send_epoch_end_) {
       TensorRow currRow;
+      while (stop_send_) {
+        MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
+        std::this_thread::sleep_for(std::chrono::microseconds(100));
+      }
       auto status =
         tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
       if (status == TdtStatus::FAILED) {
@@ -163,6 +167,8 @@ Status DeviceQueueOp::SendDataToAscend() {
           return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
         }
       }
+      MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
+      stop_send_ = true;
     }
     if (isProfilingEnable) {
       connector_size = ChildOpConnectorSize();
@@ -123,6 +123,11 @@ class DeviceQueueOp : public PipelineOp {
   void StopSend() { stop_send_ = true; }
 
+  void ContinueSend() {
+    MS_LOG(INFO) << "continue send at the beginning of the epoch";
+    stop_send_ = false;
+  }
+
   // Name: Print()
   // Description: A function that prints info about the node
   void Print(std::ostream &out,  // In: The output stream to print to
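
Taken together, the two DeviceQueueOp hunks above form a simple flag handshake: after an epoch's data and end-of-epoch marker are pushed, SendDataToAscend sets stop_send_ and, at the next epoch boundary, spins in 100 µs sleeps until the host calls ContinueSend() to clear the flag. The sketch below is only an illustration of that handshake in Python, not MindSpore code; the class and method names are hypothetical.

    # Illustrative analogue (not MindSpore code) of the stop_send_/ContinueSend
    # handshake added above in DeviceQueueOp; all names here are hypothetical.
    import threading
    import time

    class EpochGate:
        """Producer pauses at each epoch boundary until the host re-arms sending."""

        def __init__(self):
            self._stop_send = threading.Event()  # set => sending is paused

        def wait_until_continued(self):
            # Producer side: mirrors the `while (stop_send_)` polling loop and its 100 us sleep.
            while self._stop_send.is_set():
                time.sleep(0.0001)

        def stop_send(self):
            # Producer side: pause once an epoch's data has been pushed.
            self._stop_send.set()

        def continue_send(self):
            # Host side: clear the flag at the start of the next epoch
            # (the DeviceQueueOp::ContinueSend analogue).
            self._stop_send.clear()
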
@@ -2585,6 +2585,9 @@ class TransferDataset(DatasetOp):
     def stop_send(self):
         self.iterator.depipeline.StopSend()
 
+    def continue_send(self):
+        self.iterator.depipeline.ContinueSend()
+
 
 class RangeDataset(MappableDataset):
     """
@@ -163,6 +163,10 @@ class DatasetHelper:
         """Free up resources about data sink."""
         self.iter.stop_send()
 
+    def continue_send(self):
+        """continue send data to device at the beginning of epoch."""
+        self.iter.continue_send()
+
 
 class _DatasetIter:
     """Base iter for dataset helper"""
@@ -182,6 +186,7 @@ class _DatasetIter:
             _send_data_no_flag(dataset, epoch_num)
 
         self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
+        self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
 
     def __iter__(self):
@@ -442,6 +442,7 @@ class Model:
                 cb_params.net_outputs = outputs
                 list_callback.step_end(run_context)
 
+            dataset_helper.continue_send()
             list_callback.epoch_end(run_context)
             should_stop = should_stop or run_context.get_stop_requested()
             if should_stop:
@@ -64,6 +64,9 @@ class MindData:
     def stop_send(self):
         pass
 
+    def continue_send(self):
+        pass
+
     def __len__(self):
         return self._size
@@ -98,6 +98,6 @@ def test_deeplabv3_1p():
     print("expect loss: ", callback.loss)
     print("expect time: ", callback.time)
     expect_loss = 0.92
-    expect_time = 40
+    expect_time = 43
     assert callback.loss.asnumpy() <= expect_loss
     assert callback.time <= expect_time