|
|
|
@@ -20,6 +20,7 @@ |
|
|
|
#include "minddata/dataset/engine/datasetops/cache_op.h" |
|
|
|
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" |
|
|
|
#include "minddata/dataset/engine/datasetops/cache_merge_op.h" |
|
|
|
#include "minddata/dataset/engine/datasetops/device_queue_op.h" |
|
|
|
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
@@ -258,6 +259,13 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
Status RepeatPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { |
|
|
|
// Set total repeats and total epochs for the DeviceQueueOp |
|
|
|
node->set_total_repeats(num_epochs_); |
|
|
|
node->set_num_repeats_per_epoch(1); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Adds an operator to the eoe operator stack save area |
|
|
|
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { |
|
|
|
op_stack *current_stack = eoe_op_stacks_.top().get(); |
|
|
|
|