|
|
|
@@ -49,14 +49,16 @@ Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) { |
|
|
|
|
|
|
|
GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names, |
|
|
|
std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, |
|
|
|
int32_t connector_size) |
|
|
|
int32_t connector_size, int64_t pre_counter_size) |
|
|
|
: PipelineOp(connector_size), |
|
|
|
generator_function_(generator_function), |
|
|
|
column_names_(column_names), |
|
|
|
column_types_(column_types), |
|
|
|
prefetch_size_(prefetch_size), |
|
|
|
buffer_size_(buffer_size), |
|
|
|
buffer_id_(0) {} |
|
|
|
pre_counter_size_(pre_counter_size), |
|
|
|
buffer_id_(0), |
|
|
|
generator_counter_(0) {} |
|
|
|
|
|
|
|
GeneratorOp::~GeneratorOp() { this->Dealloc(); } |
|
|
|
|
|
|
|
@@ -146,6 +148,7 @@ Status GeneratorOp::FillBuffer(TensorQTable *tt) { |
|
|
|
TensorRow row; |
|
|
|
RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row)); |
|
|
|
tt->push_back(std::move(row)); |
|
|
|
generator_counter_++; |
|
|
|
} |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
@@ -209,6 +212,13 @@ Status GeneratorOp::operator()() { |
|
|
|
if (!eoe) { |
|
|
|
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what()); |
|
|
|
} |
|
|
|
if (pre_counter_size_ != -1 && pre_counter_size_ != generator_counter_) { |
|
|
|
std::stringstream ss; |
|
|
|
ss << "The actual amount of data read from generator " << generator_counter_ |
|
|
|
<< " is different from generator.len " << pre_counter_size_ |
|
|
|
<< ", you should adjust generator.len to make them match."; |
|
|
|
return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, ss.str()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (fetched_table->size() > 0) { |
|
|
|
@@ -254,6 +264,7 @@ Status GeneratorOp::Reset() { |
|
|
|
// Wake up master thread |
|
|
|
wp_.Set(); |
|
|
|
} |
|
|
|
generator_counter_ = 0; |
|
|
|
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); |
|
|
|
} |
|
|
|
|
|
|
|
|