|
|
|
@@ -30,8 +30,8 @@ |
|
|
|
#include "minddata/mindrecord/include/shard_writer.h" |
|
|
|
#endif |
|
|
|
|
|
|
|
namespace mindspore::dataset { |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace dataset { |
|
|
|
// TreeConsumer |
|
|
|
// Default constructor: allocates a new TreeAdapter and stores it in tree_adapter_. |
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); } |
|
|
|
|
|
|
|
@@ -440,7 +440,9 @@ Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape & |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
// Constructor: dataset_size_ starts at -1, init_flag_ at false, and a new TreeAdapter is stored in tree_adapter_. |
// NOTE(review): two definitions of this constructor appear back to back; this looks like the removed and added |
// sides of a flattened diff. Confirm that only the second one (which also clears first_row_obtained_) is kept. |
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); } |
|
|
|
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), first_row_obtained_(false) { |
|
|
|
tree_adapter_ = std::make_unique<TreeAdapter>(); |
|
|
|
} |
|
|
|
|
|
|
|
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) { |
|
|
|
root_ = std::move(d); |
|
|
|
@@ -473,20 +475,14 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) { |
|
|
|
} |
|
|
|
|
|
|
|
// Reports the data type of every tensor in the first fetched row through the out-parameter `types`. |
// Returns Status::OK() at the end; the RETURN_IF_NOT_OK macros presumably propagate any failing Status early. |
// NOTE(review): this body contains both a first_row_-based path (InternalInit + GetRow + std::transform) and a |
// cached path (GetFirstRowShapeAndType + copy of first_row_type_). That looks like the old and new sides of a |
// flattened diff -- confirm which statements belong to the current version. |
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) { |
|
|
|
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType))); |
|
|
|
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_)); |
|
|
|
|
|
|
|
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types), |
|
|
|
[](const TensorPtr &t) { return t->type(); }); |
|
|
|
// Populate the cached type/shape vectors (the helper returns early once they are filled). |
RETURN_IF_NOT_OK(GetFirstRowShapeAndType()); |
|
|
|
// Copy-assign the cached types; this replaces whatever *types held before this statement. |
*types = first_row_type_; |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Reports the shape of every tensor in the first fetched row through the out-parameter `shapes`. |
// Returns Status::OK() at the end; the RETURN_IF_NOT_OK macros presumably propagate any failing Status early. |
// NOTE(review): this body contains both a first_row_-based path (InternalInit + GetRow + std::transform) and a |
// cached path (GetFirstRowShapeAndType + copy of first_row_shape_). That looks like the old and new sides of a |
// flattened diff -- confirm which statements belong to the current version. |
Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) { |
|
|
|
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType))); |
|
|
|
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_)); |
|
|
|
|
|
|
|
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes), |
|
|
|
[](const TensorPtr &t) { return t->shape(); }); |
|
|
|
// Populate the cached type/shape vectors (the helper returns early once they are filled). |
RETURN_IF_NOT_OK(GetFirstRowShapeAndType()); |
|
|
|
// Copy-assign the cached shapes; this replaces whatever *shapes held before this statement. |
*shapes = first_row_shape_; |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -555,6 +551,18 @@ Status TreeGetters::InternalInit() { |
|
|
|
if (!s.IsError()) init_flag_ = true; |
|
|
|
return s; |
|
|
|
} |
|
|
|
// Fetches one row with GetRow and caches the type and shape of each of its tensors in first_row_type_ and |
// first_row_shape_, then marks the cache as filled via first_row_obtained_. |
// Returns Status::OK() on success; the RETURN_IF_NOT_OK macros presumably return the failing Status early. |
Status TreeGetters::GetFirstRowShapeAndType() { |
|
|
|
// Skip all work when the cache was filled by a previous call (presumably an early return of Status::OK()). |
RETURN_OK_IF_TRUE(first_row_obtained_); |
|
|
|
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType))); |
|
|
|
// Local row: only the derived types and shapes are kept after this function returns. |
TensorRow first_row; |
|
|
|
RETURN_IF_NOT_OK(GetRow(&first_row)); |
|
|
|
// back_inserter appends one entry per tensor, in row order, to the cached vectors. |
std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_), |
|
|
|
[](const TensorPtr &t) { return t->type(); }); |
|
|
|
std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_shape_), |
|
|
|
[](const TensorPtr &t) { return t->shape(); }); |
|
|
|
// Set only after both vectors are filled; the checks above are expected to return before this on failure. |
first_row_obtained_ = true; |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
// Hands the dataset node `d` to tree_adapter_->Compile and returns the resulting Status. |
// NOTE(review): the literal 1 is Compile's second argument -- presumably an epoch count; confirm against TreeAdapter. |
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); } |
|
|
|
|
|
|
|
Status BuildVocabConsumer::Start() { |
|
|
|
@@ -565,4 +573,5 @@ Status BuildVocabConsumer::Start() { |
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE."); |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
} // namespace mindspore::dataset |
|
|
|
} // namespace dataset |
|
|
|
} // namespace mindspore |