diff --git a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc index 45d65288a2..89e20b8d11 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc @@ -1067,6 +1067,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptrSetDeviceId(ToInt(value)); } else if (key == "send_epoch_end") { (void)builder->SetSendEpochEnd(ToBool(value)); + } else if (key == "total_batch") { + (void)builder->SetTotalBatch(ToInt(value)); } } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index 69ed55309c..9b3305256f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -33,14 +33,15 @@ namespace mindspore { namespace dataset { DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - bool send_epoch_end) + bool send_epoch_end, int total_batch) : PipelineOp(1), channel_name_(channel_name), device_type_(device_type), device_id_(device_id), prefetch_size_(prefetch_size), send_epoch_end_(send_epoch_end), - stop_send_(false) { + stop_send_(false), + total_batch_(total_batch) { #ifdef ENABLE_TDTQUE ascend_keep_waiting_ = true; #endif @@ -60,7 +61,8 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size) : builder_prefetch_size_(prefetch_size), builder_device_id_(0), builder_device_type_(DeviceType::CPU), - builder_channel_name_("") {} + builder_channel_name_(""), + builder_total_batch_(0) {} Status DeviceQueueOp::EoeReceived(int32_t worker_id) { state_ = OpState::kDeOpIdle; @@ -102,11 +104,13 @@ Status DeviceQueueOp::operator()() { #ifdef ENABLE_TDTQUE Status DeviceQueueOp::SendDataToAscend() { MS_LOG(INFO) << "Device queue, sending data to 
Ascend."; - int64_t total_batch = 0; + int64_t send_batch = 0; double batch_start_time, end_time; int32_t batch_cost, tdt_cost; int32_t connector_size = 0; int32_t connector_capacity; + bool is_break_loop = false; + std::shared_ptr<DeviceQueueTracing> profiling_node; bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable(); if (isProfilingEnable) { @@ -119,8 +123,8 @@ Status DeviceQueueOp::SendDataToAscend() { std::unique_ptr<DataBuffer> current_buffer; RETURN_IF_NOT_OK(GetNextInput(&current_buffer)); - while (!current_buffer->eof()) { - while (!current_buffer->eoe()) { + while (!current_buffer->eof() && !is_break_loop) { + while (!current_buffer->eoe() && !is_break_loop) { RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); TensorRow currRow; for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) { @@ -142,17 +146,21 @@ Status DeviceQueueOp::SendDataToAscend() { if (isProfilingEnable) { end_time = ProfilingTime::GetCurMilliSecond(); // record push tdt time - profiling_node->Record(TIME, TDT_PUSH_TIME, total_batch + 1, tdt_cost); + profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch + 1, tdt_cost); batch_cost = (int32_t)(end_time - batch_start_time); // record batch time - profiling_node->Record(TIME, BATCH_TIME, total_batch + 1, batch_cost); + profiling_node->Record(TIME, BATCH_TIME, send_batch + 1, batch_cost); // record pipeline time - profiling_node->Record(TIME, PIPELINE_TIME, total_batch + 1, batch_cost - tdt_cost); + profiling_node->Record(TIME, PIPELINE_TIME, send_batch + 1, batch_cost - tdt_cost); batch_start_time = end_time; // record connector depth - profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); + profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, send_batch + 1, connector_size); + } + send_batch++; + if (total_batch_ > 0 && send_batch >= total_batch_) { + is_break_loop = true; + break; } - total_batch++; } if (isProfilingEnable) { connector_size = ChildOpConnectorSize(); @@ -184,7 +192,7 @@ Status
DeviceQueueOp::SendDataToAscend() { } tree_->SetFinished(); - MS_LOG(INFO) << "Device queue total batch is " << total_batch; + MS_LOG(INFO) << "Device queue total batch is " << send_batch; return Status::OK(); } @@ -193,7 +201,7 @@ Status DeviceQueueOp::SendDataToAscend() { #ifdef ENABLE_GPUQUE Status DeviceQueueOp::SendDataToGPU() { MS_LOG(INFO) << "Device queue, sending data to GPU."; - int64_t total_batch = 0; + int64_t send_batch = 0; bool is_break_loop = false; bool is_open = false; uint32_t handle = INVALID_HANDLE; @@ -235,19 +243,23 @@ Status DeviceQueueOp::SendDataToGPU() { is_open = true; } RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle, isProfilingEnable, &push_cost)); - total_batch++; + send_batch++; if (isProfilingEnable) { end_time = ProfilingTime::GetCurMilliSecond(); // record push data time - profiling_node->Record(TIME, TDT_PUSH_TIME, total_batch, push_cost); + profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost); batch_cost = (int32_t)(end_time - batch_start_time); // record batch time - profiling_node->Record(TIME, BATCH_TIME, total_batch, batch_cost); + profiling_node->Record(TIME, BATCH_TIME, send_batch, batch_cost); // record pipeline time - profiling_node->Record(TIME, PIPELINE_TIME, total_batch, batch_cost - push_cost); + profiling_node->Record(TIME, PIPELINE_TIME, send_batch, batch_cost - push_cost); batch_start_time = end_time; // record connector depth - profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch, connector_size); + profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, send_batch, connector_size); + } + if (total_batch_ > 0 && send_batch >= total_batch_) { + is_break_loop = true; + break; } } if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed()) { @@ -272,7 +284,7 @@ Status DeviceQueueOp::SendDataToGPU() { } tree_->SetFinished(); - MS_LOG(INFO) << "Device queue total batch is " << total_batch << "."; + MS_LOG(INFO) << "Device queue total batch 
is " << send_batch << "."; GpuBufferMgr::GetInstance().Close(handle); GpuBufferMgr::GetInstance().CloseConfirm(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h index 7dc999dfa5..385dfa47a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h @@ -86,13 +86,18 @@ class DeviceQueueOp : public PipelineOp { return *this; } + Builder &SetTotalBatch(int total_batch) { + builder_total_batch_ = total_batch; + return *this; + } + // Name: Build() // Description: The final step for building a DeviceQueueOp via the Builder is // to call this Build() method. It will instantiate the DeviceQueueOp // and return it to caller as a shared pointer. Status Build(std::shared_ptr<DeviceQueueOp> *ptr) { *ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_, - builder_prefetch_size_, builder_send_epoch_end_); + builder_prefetch_size_, builder_send_epoch_end_, builder_total_batch_); return Status::OK(); } @@ -102,12 +107,13 @@ class DeviceQueueOp : public PipelineOp { DeviceType builder_device_type_; std::string builder_channel_name_; bool builder_send_epoch_end_; + int builder_total_batch_; }; // Name: constructor // Description DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - bool send_epoch_end); + bool send_epoch_end, int total_batch); // Name: destructor // Description @@ -183,6 +189,7 @@ class DeviceQueueOp : public PipelineOp { const int32_t prefetch_size_; const bool send_epoch_end_; bool stop_send_; + int total_batch_; #ifdef ENABLE_TDTQUE std::shared_ptr<TdtPlugin> tdtInstancePtr; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index b2bae22f9c..8bbe8a2daa 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2623,6 +2623,8
@@ class TransferDataset(DatasetOp): args["device_type"] = self._device_type args["device_id"] = self._device_id args["send_epoch_end"] = self._send_epoch_end + if hasattr(self.children[0], "__total_batch__"): + args["total_batch"] = self.children[0].__total_batch__ return args def create_dict_iterator(self, num_epochs=-1, output_numpy=False): diff --git a/mindspore/train/model.py b/mindspore/train/model.py index bba8a26c4e..6537f0e9fe 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -403,6 +403,7 @@ class Model: epoch_num = epoch else: epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size()) + train_dataset.__total_batch__ = epoch * sink_size dataset_helper, train_network = self._exec_preprocess(self._train_network, is_train=True,