Fix review 1 Address comments 2 Fix review 3 tags/v1.1.0
| @@ -781,7 +781,6 @@ inline Status Tensor::CreateFromVector<std::string>(const std::vector<std::strin | |||||
| num_bytes -= str.length() + 1; | num_bytes -= str.length() + 1; | ||||
| } | } | ||||
| // store one more offset value so we can get the length of the last string | // store one more offset value so we can get the length of the last string | ||||
| // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] | |||||
| offset_arr[i] = offset; | offset_arr[i] = offset; | ||||
| (*out)->data_end_ = (*out)->data_ + offset_arr[i]; | (*out)->data_end_ = (*out)->data_ + offset_arr[i]; | ||||
| @@ -546,9 +546,9 @@ Status BatchOp::GetDatasetSize(int64_t *dataset_size) { | |||||
| RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); | RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows)); | ||||
| if (num_rows > 0 && start_batch_size_ > 0) { | if (num_rows > 0 && start_batch_size_ > 0) { | ||||
| if (drop_) { | if (drop_) { | ||||
| num_rows = floor(num_rows / (1.0 * start_batch_size_)); | |||||
| num_rows = static_cast<int64_t>(floor(num_rows / (1.0 * start_batch_size_))); | |||||
| } else { | } else { | ||||
| num_rows = ceil(num_rows / (1.0 * start_batch_size_)); | |||||
| num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * start_batch_size_))); | |||||
| } | } | ||||
| } | } | ||||
| *dataset_size = num_rows; | *dataset_size = num_rows; | ||||
| @@ -52,16 +52,8 @@ Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| void ParallelOp::Print(std::ostream &out, bool show_all) const { | void ParallelOp::Print(std::ostream &out, bool show_all) const { | ||||
| // Summary 1-liner print | |||||
| if (!show_all) { | |||||
| // Call super class printer | |||||
| DatasetOp::Print(out, show_all); | |||||
| out << " [workers: " << num_workers_ << "]"; | |||||
| } else { | |||||
| // Detailed print | |||||
| DatasetOp::Print(out, show_all); | |||||
| out << "\nNum workers: " << num_workers_; | |||||
| } | |||||
| DatasetOp::Print(out, show_all); | |||||
| out << " [workers: " << num_workers_ << "]"; | |||||
| } | } | ||||
| // Override base class reset to provide reset actions specific to the ParallelOp class. | // Override base class reset to provide reset actions specific to the ParallelOp class. | ||||
| @@ -220,7 +220,7 @@ Status ShuffleOp::operator()() { | |||||
| } | } | ||||
| // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the | // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the | ||||
| // pipepline manually now that we are done draining the shuffle buffer | |||||
| // pipeline manually now that we are done draining the shuffle buffer | |||||
| MS_LOG(DEBUG) << "Shuffle operator sending EOE."; | MS_LOG(DEBUG) << "Shuffle operator sending EOE."; | ||||
| auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | ||||
| RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); | ||||
| @@ -283,7 +283,7 @@ Status ShuffleOp::InitShuffleBuffer() { | |||||
| shuffle_buffer_state_ = kShuffleStateDrain; | shuffle_buffer_state_ = kShuffleStateDrain; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "Shuffle operator finished intializing the shuffle buffer."; | |||||
| MS_LOG(DEBUG) << "Shuffle operator finished initializing the shuffle buffer."; | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -111,7 +111,9 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou | |||||
| *output = out; | *output = out; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } catch (const std::exception &e) { | } catch (const std::exception &e) { | ||||
| RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp"); | |||||
| std::string err_msg = "Unexpected error in OneHotOp: "; | |||||
| err_msg += e.what(); | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||||
| } | } | ||||
| } | } | ||||
| @@ -89,7 +89,8 @@ RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops | |||||
| : ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) { | : ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) { | ||||
| if (ops_.empty()) { | if (ops_.empty()) { | ||||
| MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty."; | MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty."; | ||||
| } else if (ops_.size() == 1) { | |||||
| } | |||||
| if (ops_.size() == 1) { | |||||
| MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; | MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time."; | ||||
| } | } | ||||
| is_deterministic_ = false; | is_deterministic_ = false; | ||||
| @@ -343,8 +343,8 @@ class Dataset: | |||||
| Add a blocking condition to the input Dataset. | Add a blocking condition to the input Dataset. | ||||
| Args: | Args: | ||||
| num_batch (int): the number of batches without blocking at the start of each epoch. | |||||
| condition_name (str): The condition name that is used to toggle sending next row. | condition_name (str): The condition name that is used to toggle sending next row. | ||||
| num_batch (int): the number of batches without blocking at the start of each epoch. | |||||
| callback (function): The callback function that will be invoked when sync_update is called. | callback (function): The callback function that will be invoked when sync_update is called. | ||||
| Raises: | Raises: | ||||
| @@ -1452,9 +1452,10 @@ class Dataset: | |||||
| num_batch (Union[int, None]): The number of batches (rows) that are released. | num_batch (Union[int, None]): The number of batches (rows) that are released. | ||||
| When num_batch is None, it will default to the number specified by the | When num_batch is None, it will default to the number specified by the | ||||
| sync_wait operator (default=None). | sync_wait operator (default=None). | ||||
| data (Union[dict, None]): The data passed to the callback (default=None). | |||||
| data (Any): The data passed to the callback, user defined (default=None). | |||||
| """ | """ | ||||
| if isinstance(num_batch, int) and num_batch <= 0: | |||||
| if (not isinstance(num_batch, int) and num_batch is not None) or \ | |||||
| (isinstance(num_batch, int) and num_batch <= 0): | |||||
| # throwing exception, disable all sync_wait in pipeline | # throwing exception, disable all sync_wait in pipeline | ||||
| self.disable_sync() | self.disable_sync() | ||||
| raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch)) | raise RuntimeError("Sync_update batch size can only be positive, got : {}.".format(num_batch)) | ||||