Browse Source

!6221 stop sending data to device after end of sequence

Merge pull request !6221 from anzhengqi/stop-send-at-eos
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f6ac30ef29
10 changed files with 38 additions and 1 deletion
  1. +1
    -0
      mindspore/ccsrc/minddata/dataset/api/python/bindings.cc
  2. +10
    -0
      mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc
  3. +3
    -0
      mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h
  4. +6
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
  5. +5
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h
  6. +3
    -0
      mindspore/dataset/engine/datasets.py
  7. +5
    -0
      mindspore/train/dataset_helper.py
  8. +1
    -0
      mindspore/train/model.py
  9. +3
    -0
      tests/dataset_mock.py
  10. +1
    -1
      tests/st/networks/models/deeplabv3/test_deeplabv3.py

+ 1
- 0
mindspore/ccsrc/minddata/dataset/api/python/bindings.cc View File

@@ -81,6 +81,7 @@ PYBIND_REGISTER(
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("ContinueSend", [](DEPipeline &de) { THROW_IF_ERROR(de.ContinueSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;


+ 10
- 0
mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.cc View File

@@ -291,6 +291,16 @@ Status DEPipeline::StopSend() {
return Status::OK();
}

// Resume sending rows to the device after a prior StopSend(): looks up the
// execution tree's root op, which must be the DeviceQueueOp, and clears its
// stop flag so the send loop starts pushing data again.
// Returns Status::OK() on success, kUnexpectedError if the root op is not a
// DeviceQueueOp (ContinueSend is only meaningful for device-queue pipelines).
Status DEPipeline::ContinueSend() {
  // tree_->root() must be a DeviceQueueOp; any other op type cannot resume sending.
  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
  if (op == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ContinueSend only supported by DeviceQueueOp");
  }
  op->ContinueSend();
  return Status::OK();
}

int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }

bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }


+ 3
- 0
mindspore/ccsrc/minddata/dataset/api/python/de_pipeline.h View File

@@ -203,6 +203,9 @@ class DEPipeline {
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

// Ask the DeviceQueueOp at the root of the tree to stop sending data to the device.
Status StopSend();

// Ask the DeviceQueueOp at the root of the tree to resume sending after StopSend().
Status ContinueSend();

Status ParseBuildSentencePieceVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);



+ 6
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc View File

@@ -153,6 +153,10 @@ Status DeviceQueueOp::SendDataToAscend() {
}
if (current_buffer->eoe() && send_epoch_end_) {
TensorRow currRow;
while (stop_send_) {
MS_LOG(DEBUG) << "stop_send flag is set, waiting for continue signal...";
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
auto status =
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
if (status == TdtStatus::FAILED) {
@@ -163,6 +167,8 @@ Status DeviceQueueOp::SendDataToAscend() {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
}
MS_LOG(INFO) << "an epoch has already sent, now stop send data.";
stop_send_ = true;
}
if (isProfilingEnable) {
connector_size = ChildOpConnectorSize();


+ 5
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h View File

@@ -123,6 +123,11 @@ class DeviceQueueOp : public PipelineOp {

// Raise the stop flag; the Ascend send loop polls stop_send_ and pauses pushing data.
void StopSend() { stop_send_ = true; }

// Clear the stop flag so the send loop resumes pushing data for the next epoch.
// NOTE(review): this is set from a different thread than the send loop that
// polls stop_send_ — confirm stop_send_ is declared atomic (or volatile).
void ContinueSend() {
  MS_LOG(INFO) << "continue send at the beginning of the epoch";
  stop_send_ = false;
}

// Name: Print()
// Description: A function that prints info about the node
void Print(std::ostream &out, // In: The output stream to print to


+ 3
- 0
mindspore/dataset/engine/datasets.py View File

@@ -2585,6 +2585,9 @@ class TransferDataset(DatasetOp):
def stop_send(self):
    """Tell the underlying C++ pipeline to stop sending data to the device."""
    self.iterator.depipeline.StopSend()

def continue_send(self):
    """Tell the underlying C++ pipeline to resume sending data to the device."""
    self.iterator.depipeline.ContinueSend()


class RangeDataset(MappableDataset):
"""


+ 5
- 0
mindspore/train/dataset_helper.py View File

@@ -163,6 +163,10 @@ class DatasetHelper:
"""Free up resources about data sink."""
self.iter.stop_send()

def continue_send(self):
    """Resume sending data to the device at the beginning of the next epoch."""
    self.iter.continue_send()


class _DatasetIter:
"""Base iter for dataset helper"""
@@ -182,6 +186,7 @@ class _DatasetIter:
_send_data_no_flag(dataset, epoch_num)

self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

def __iter__(self):


+ 1
- 0
mindspore/train/model.py View File

@@ -442,6 +442,7 @@ class Model:
cb_params.net_outputs = outputs
list_callback.step_end(run_context)

dataset_helper.continue_send()
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:


+ 3
- 0
tests/dataset_mock.py View File

@@ -64,6 +64,9 @@ class MindData:
def stop_send(self):
    # Mock hook matching the dataset stop_send interface; intentionally a no-op.
    pass

def continue_send(self):
    # Mock hook matching the dataset continue_send interface; intentionally a no-op.
    pass

def __len__(self):
return self._size



+ 1
- 1
tests/st/networks/models/deeplabv3/test_deeplabv3.py View File

@@ -98,6 +98,6 @@ def test_deeplabv3_1p():
print("expect loss: ", callback.loss)
print("expect time: ", callback.time)
expect_loss = 0.92
expect_time = 40
expect_time = 43
assert callback.loss.asnumpy() <= expect_loss
assert callback.time <= expect_time

Loading…
Cancel
Save