From: @lixiachen Reviewed-by: @nsyca,@mikef Signed-off-by: @nsycatags/v1.1.0
| @@ -89,6 +89,14 @@ Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| } | } | ||||
| Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); } | Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); } | ||||
| void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } | ||||
| void CacheLookupOp::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: CacheLookupOp"; | |||||
| if (show_all) { | |||||
| // Call the super class for displaying any common detailed info | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| // Then add our own info if any | |||||
| } | |||||
| } | |||||
| Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) { | ||||
| std::vector<row_id_type> cache_miss; | std::vector<row_id_type> cache_miss; | ||||
| RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | RETURN_IF_NOT_OK(keys_miss_->Pop(0, &cache_miss)); | ||||
| @@ -99,6 +99,7 @@ class CacheLookupOp : public CacheBase, public SamplerRT { | |||||
| Status InitSampler() override; | Status InitSampler() override; | ||||
| Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override; | ||||
| void Print(std::ostream &out, bool show_all) const override; | void Print(std::ostream &out, bool show_all) const override; | ||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| bool AllowCacheMiss() override { return true; } | bool AllowCacheMiss() override { return true; } | ||||
| std::string Name() const override { return kCacheLookupOp; } | std::string Name() const override { return kCacheLookupOp; } | ||||
| @@ -252,7 +252,7 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const { | |||||
| << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; | << "\nNumber repeats per epoch : " << op_num_repeats_per_epoch_; | ||||
| if (sampler_) { | if (sampler_) { | ||||
| out << "\nSampler:\n"; | out << "\nSampler:\n"; | ||||
| sampler_->Print(out, show_all); | |||||
| sampler_->SamplerPrint(out, show_all); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -170,10 +170,10 @@ int64_t DistributedSamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
| return std::ceil(num_samples * 1.0 / num_devices_); | return std::ceil(num_samples * 1.0 / num_devices_); | ||||
| } | } | ||||
| void DistributedSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void DistributedSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: DistributedSampler"; | out << "\nSampler: DistributedSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| SamplerRT::Print(out, show_all); | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ | out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ | ||||
| << "\nshuffle: " << shuffle_; | << "\nshuffle: " << shuffle_; | ||||
| } | } | ||||
| @@ -70,7 +70,7 @@ class DistributedSamplerRT : public SamplerRT { | |||||
| /// \return int64_t Calculated number of samples | /// \return int64_t Calculated number of samples | ||||
| int64_t CalculateNumSamples(int64_t num_rows) override; | int64_t CalculateNumSamples(int64_t num_rows) override; | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| int64_t cnt_; // number of samples that have already been filled in to buffer | int64_t cnt_; // number of samples that have already been filled in to buffer | ||||
| @@ -116,11 +116,11 @@ Status PKSamplerRT::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void PKSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void PKSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: PKSampler"; | out << "\nSampler: PKSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| SamplerRT::Print(out, show_all); | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -59,7 +59,7 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED | |||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| // @param out - output stream to write to | // @param out - output stream to write to | ||||
| // @param show_all - bool to show detailed vs summary | // @param show_all - bool to show detailed vs summary | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| bool shuffle_; | bool shuffle_; | ||||
| @@ -106,11 +106,11 @@ Status PythonSamplerRT::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void PythonSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void PythonSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: PythonSampler"; | out << "\nSampler: PythonSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| SamplerRT::Print(out, show_all); | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ class PythonSamplerRT : public SamplerRT { | |||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| // @param out - output stream to write to | // @param out - output stream to write to | ||||
| // @param show_all - bool to show detailed vs summary | // @param show_all - bool to show detailed vs summary | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() | ||||
| @@ -115,11 +115,11 @@ Status RandomSamplerRT::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void RandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void RandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: RandomSampler"; | out << "\nSampler: RandomSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| SamplerRT::Print(out, show_all); | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -50,7 +50,7 @@ class RandomSamplerRT : public SamplerRT { | |||||
| // @return - The error code return | // @return - The error code return | ||||
| Status ResetSampler() override; | Status ResetSampler() override; | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| uint32_t seed_; | uint32_t seed_; | ||||
| @@ -78,7 +78,7 @@ Status SamplerRT::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64 | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void SamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void SamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| // Sampler printing is usually only called in the show_all mode. | // Sampler printing is usually only called in the show_all mode. | ||||
| // Derived classes will display the name, then call back to this base | // Derived classes will display the name, then call back to this base | ||||
| // for common info. | // for common info. | ||||
| @@ -126,7 +126,7 @@ class SamplerRT { | |||||
| // A print method typically used for debugging | // A print method typically used for debugging | ||||
| // @param out - The output stream to write output to | // @param out - The output stream to write output to | ||||
| // @param show_all - A bool to control if you want to show all info or just a summary | // @param show_all - A bool to control if you want to show all info or just a summary | ||||
| virtual void Print(std::ostream &out, bool show_all) const; | |||||
| virtual void SamplerPrint(std::ostream &out, bool show_all) const; | |||||
| // << Stream output operator overload | // << Stream output operator overload | ||||
| // @notes This allows you to write the debug print info using stream operators | // @notes This allows you to write the debug print info using stream operators | ||||
| @@ -134,7 +134,7 @@ class SamplerRT { | |||||
| // @param sampler - reference to teh sampler to print | // @param sampler - reference to teh sampler to print | ||||
| // @return - the output stream must be returned | // @return - the output stream must be returned | ||||
| friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) { | friend std::ostream &operator<<(std::ostream &out, const SamplerRT &sampler) { | ||||
| sampler.Print(out, false); | |||||
| sampler.SamplerPrint(out, false); | |||||
| return out; | return out; | ||||
| } | } | ||||
| @@ -97,11 +97,11 @@ Status SequentialSamplerRT::ResetSampler() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void SequentialSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: SequentialSampler"; | out << "\nSampler: SequentialSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| SamplerRT::Print(out, show_all); | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| // Then add our own info | // Then add our own info | ||||
| out << "\nStart index: " << start_index_; | out << "\nStart index: " << start_index_; | ||||
| } | } | ||||
| @@ -52,7 +52,7 @@ class SequentialSamplerRT : public SamplerRT { | |||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| // @param out - output stream to write to | // @param out - output stream to write to | ||||
| // @param show_all - bool to show detailed vs summary | // @param show_all - bool to show detailed vs summary | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| int64_t current_id_; // The id sequencer. Each new id increments from this | int64_t current_id_; // The id sequencer. Each new id increments from this | ||||
| @@ -119,11 +119,11 @@ Status SubsetRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buf | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void SubsetRandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void SubsetRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: SubsetRandomSampler"; | out << "\nSampler: SubsetRandomSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| SamplerRT::Print(out, show_all); | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ class SubsetRandomSamplerRT : public SamplerRT { | |||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| // @param out - output stream to write to | // @param out - output stream to write to | ||||
| // @param show_all - bool to show detailed vs summary | // @param show_all - bool to show detailed vs summary | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| // A list of indices (already randomized in constructor). | // A list of indices (already randomized in constructor). | ||||
| @@ -181,11 +181,11 @@ Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_b | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| void WeightedRandomSamplerRT::Print(std::ostream &out, bool show_all) const { | |||||
| void WeightedRandomSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const { | |||||
| out << "\nSampler: WeightedRandomSampler"; | out << "\nSampler: WeightedRandomSampler"; | ||||
| if (show_all) { | if (show_all) { | ||||
| // Call the super class for displaying any common detailed info | // Call the super class for displaying any common detailed info | ||||
| SamplerRT::Print(out, show_all); | |||||
| SamplerRT::SamplerPrint(out, show_all); | |||||
| // Then add our own info if any | // Then add our own info if any | ||||
| } | } | ||||
| } | } | ||||
| @@ -56,7 +56,7 @@ class WeightedRandomSamplerRT : public SamplerRT { | |||||
| // Printer for debugging purposes. | // Printer for debugging purposes. | ||||
| // @param out - output stream to write to | // @param out - output stream to write to | ||||
| // @param show_all - bool to show detailed vs summary | // @param show_all - bool to show detailed vs summary | ||||
| void Print(std::ostream &out, bool show_all) const override; | |||||
| void SamplerPrint(std::ostream &out, bool show_all) const override; | |||||
| private: | private: | ||||
| // A list of weights for each sample. | // A list of weights for each sample. | ||||