Browse Source

fix ut

tags/v1.2.0-rc1
Eric 4 years ago
parent
commit
4e5b174f8f
1 changed files with 33 additions and 6 deletions
  1. +33
    -6
      mindspore/ccsrc/minddata/dataset/api/datasets.cc

+ 33
- 6
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -991,16 +991,24 @@ MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
const std::shared_ptr<Sampler> &sampler, nlohmann::json *padded_sample, const std::shared_ptr<Sampler> &sampler, nlohmann::json *padded_sample,
int64_t num_padded) { int64_t num_padded) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr; auto sampler_obj = sampler ? sampler->Parse() : nullptr;
nlohmann::json sample = nullptr;
if (padded_sample) {
sample = *padded_sample;
};
auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj, auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
*padded_sample, num_padded);
sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file, MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
const std::vector<std::vector<char>> &columns_list, Sampler *sampler, const std::vector<std::vector<char>> &columns_list, Sampler *sampler,
nlohmann::json *padded_sample, int64_t num_padded) { nlohmann::json *padded_sample, int64_t num_padded) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr; auto sampler_obj = sampler ? sampler->Parse() : nullptr;
nlohmann::json sample = nullptr;
if (padded_sample) {
sample = *padded_sample;
};
auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj, auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
*padded_sample, num_padded);
sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file, MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
@@ -1008,8 +1016,13 @@ MindDataDataset::MindDataDataset(const std::vector<char> &dataset_file,
const std::reference_wrapper<Sampler> sampler, nlohmann::json *padded_sample, const std::reference_wrapper<Sampler> sampler, nlohmann::json *padded_sample,
int64_t num_padded) { int64_t num_padded) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
nlohmann::json sample = nullptr;
if (padded_sample) {
sample = *padded_sample;
};

auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj, auto ds = std::make_shared<MindDataNode>(CharToString(dataset_file), VectorCharToString(columns_list), sampler_obj,
*padded_sample, num_padded);
sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files, MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
@@ -1017,16 +1030,26 @@ MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_f
const std::shared_ptr<Sampler> &sampler, nlohmann::json *padded_sample, const std::shared_ptr<Sampler> &sampler, nlohmann::json *padded_sample,
int64_t num_padded) { int64_t num_padded) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr; auto sampler_obj = sampler ? sampler->Parse() : nullptr;
nlohmann::json sample = nullptr;
if (padded_sample) {
sample = *padded_sample;
};

auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list), auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
sampler_obj, *padded_sample, num_padded);
sampler_obj, sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files, MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
const std::vector<std::vector<char>> &columns_list, Sampler *sampler, const std::vector<std::vector<char>> &columns_list, Sampler *sampler,
nlohmann::json *padded_sample, int64_t num_padded) { nlohmann::json *padded_sample, int64_t num_padded) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr; auto sampler_obj = sampler ? sampler->Parse() : nullptr;
nlohmann::json sample = nullptr;
if (padded_sample) {
sample = *padded_sample;
};

auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list), auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
sampler_obj, *padded_sample, num_padded);
sampler_obj, sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files, MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_files,
@@ -1034,8 +1057,12 @@ MindDataDataset::MindDataDataset(const std::vector<std::vector<char>> &dataset_f
const std::reference_wrapper<Sampler> sampler, nlohmann::json *padded_sample, const std::reference_wrapper<Sampler> sampler, nlohmann::json *padded_sample,
int64_t num_padded) { int64_t num_padded) {
auto sampler_obj = sampler.get().Parse(); auto sampler_obj = sampler.get().Parse();
nlohmann::json sample = nullptr;
if (padded_sample) {
sample = *padded_sample;
};
auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list), auto ds = std::make_shared<MindDataNode>(VectorCharToString(dataset_files), VectorCharToString(columns_list),
sampler_obj, *padded_sample, num_padded);
sampler_obj, sample, num_padded);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds); ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
} }
#endif #endif


Loading…
Cancel
Save