Browse Source

move common helper function

r1.7
zetongzhao 4 years ago
parent
commit
ab5f79f522
4 changed files with 61 additions and 66 deletions
  1. +5
    -59
      tests/ut/cpp/dataset/c_api_dataset_ops_test.cc
  2. +39
    -0
      tests/ut/cpp/dataset/common/common.cc
  3. +17
    -0
      tests/ut/cpp/dataset/common/common.h
  4. +0
    -7
      tests/ut/cpp/dataset/skip_pushdown_optimization_pass_test.cc

+ 5
- 59
tests/ut/cpp/dataset/c_api_dataset_ops_test.cc View File

@@ -16,7 +16,6 @@
#include "common/common.h"
#include "include/api/types.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "minddata/dataset/kernels/ir/data/transforms_ir.h"
@@ -142,68 +141,15 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};

TensorRow VecToRow(const MSTensorVec &v) {
TensorRow row;
for (const mindspore::MSTensor &t : v) {
std::shared_ptr<Tensor> rt;
(void)Tensor::CreateFromMemory(TensorShape(t.Shape()), MSTypeToDEType(static_cast<mindspore::TypeId>(t.DataType())),
(const uchar *)(t.Data().get()), t.DataSize(), &rt);
row.emplace_back(rt);
}
return row;
}
MSTensorVec RowToVec(const TensorRow &v) {
MSTensorVec rv; // std::make_shared<DETensor>(de_tensor)
std::transform(v.begin(), v.end(), std::back_inserter(rv), [](std::shared_ptr<Tensor> t) -> mindspore::MSTensor {
return mindspore::MSTensor(std::make_shared<DETensor>(t));
});
return rv;
}

MSTensorVec BucketBatchTestFunction(MSTensorVec input) {
mindspore::dataset::TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({1}),
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_INT32), &out);
(void)out->SetItemAt({0}, 2);
output.push_back(out);
return RowToVec(output);
}

MSTensorVec Predicate1(MSTensorVec in) {
// Return true if input is equal to 3
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(0)->GetItemAt(&input_value, {0});
bool result = (input_value == 3);

// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({}),
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
(void)Tensor::CreateEmpty(
TensorShape({1}), DataType(DataType::Type::DE_INT32),
&out);
constexpr int value = 2;
(void)out->SetItemAt({0}, value);
output.push_back(out);

return RowToVec(output);
}

MSTensorVec Predicate2(MSTensorVec in) {
// Return true if label is more than 1
// The index of label in input is 1
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(1)->GetItemAt(&input_value, {0});
bool result = (input_value > 1);

// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(mindspore::dataset::TensorShape({}),
mindspore::dataset::DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
output.push_back(out);

return RowToVec(output);
}



+ 39
- 0
tests/ut/cpp/dataset/common/common.cc View File

@@ -150,3 +150,42 @@ std::shared_ptr<mindspore::dataset::ExecutionTree> DatasetOpTesting::Build(
#endif
#endif
} // namespace UT

namespace mindspore {
namespace dataset {
MSTensorVec Predicate1(MSTensorVec in) {
// Return true if input is equal to 3
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(0)->GetItemAt(&input_value, {0});
bool result = (input_value == 3);

// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(TensorShape({}), DataType(DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
output.push_back(out);

return RowToVec(output);
}

MSTensorVec Predicate2(MSTensorVec in) {
// Return true if label is more than 1
// The index of label in input is 1
uint64_t input_value;
TensorRow input = VecToRow(in);
(void)input.at(1)->GetItemAt(&input_value, {0});
bool result = (input_value > 1);

// Convert from boolean to TensorRow
TensorRow output;
std::shared_ptr<Tensor> out;
(void)Tensor::CreateEmpty(TensorShape({}), DataType(mindspore::dataset::DataType::Type::DE_BOOL), &out);
(void)out->SetItemAt({}, result);
output.push_back(out);

return RowToVec(output);
}
} // namespace dataset
} // namespace mindspore

+ 17
- 0
tests/ut/cpp/dataset/common/common.h View File

@@ -27,6 +27,7 @@
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

using mindspore::Status;
using mindspore::StatusCode;
@@ -118,4 +119,20 @@ class DatasetOpTesting : public Common {
void SetUp() override;
};
} // namespace UT

namespace mindspore {
namespace dataset {
// defined in datasets.cc code, and function prototypes added here for UT purposes
// convert MSTensorVec to DE TensorRow, return empty if fails
TensorRow VecToRow(const MSTensorVec &v);

// defined in datasets.cc code, and function prototypes added here for UT purposes
// convert DE TensorRow to MSTensorVec, won't fail
MSTensorVec RowToVec(const TensorRow &v);

MSTensorVec Predicate1(MSTensorVec in);

MSTensorVec Predicate2(MSTensorVec in);
} // namespace dataset
} // namespace mindspore
#endif // TESTS_UT_CPP_DATASET_COMMON_COMMON_H_

+ 0
- 7
tests/ut/cpp/dataset/skip_pushdown_optimization_pass_test.cc View File

@@ -18,7 +18,6 @@
#include <string>

#include "common/common.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pre/skip_pushdown_pass.h"
#include "minddata/dataset/include/dataset/samplers.h"
#include "minddata/dataset/include/dataset/vision.h"
@@ -107,12 +106,6 @@ class MindDataSkipPushdownTestOptimizationPass : public UT::DatasetOpTesting {
}
};

TensorRow VecToRow(const MSTensorVec &v);

MSTensorVec RowToVec(const TensorRow &v);

MSTensorVec Predicate1(MSTensorVec in);

/// Feature: MindData Skip Pushdown Optimization Pass Test
/// Description: Test MindData Skip Pushdown Optimization Pass with Sampler in MappableSourceNode
/// Expectation: Skip node is pushed down and removed after optimization pass


Loading…
Cancel
Save