Browse Source

!15440 fix opencl ut bugs

From: @yeyunpeng2020
Reviewed-by: @ddwsky,@HilbertDavid
Signed-off-by: @ddwsky
pull/15440/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
8b9f84447d
31 changed files with 358 additions and 286 deletions
  1. +1
    -0
      mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl
  2. +2
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc
  3. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc
  4. +3
    -3
      mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc
  5. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h
  6. +11
    -0
      mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc
  7. +1
    -0
      mindspore/lite/src/runtime/kernel/opencl/kernel/gather.h
  8. +0
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc
  9. +0
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/stack.h
  10. +72
    -0
      mindspore/lite/test/run_ut_gpu.sh
  11. +1
    -0
      mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc
  12. +0
    -38
      mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
  13. +1
    -1
      mindspore/lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc
  14. +12
    -2
      mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc
  15. +5
    -0
      mindspore/lite/test/ut/src/runtime/kernel/opencl/common.h
  16. +1
    -0
      mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc
  17. +1
    -1
      mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
  18. +1
    -0
      mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
  19. +4
    -4
      mindspore/lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc
  20. +0
    -19
      mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc
  21. +21
    -19
      mindspore/lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc
  22. +1
    -0
      mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc
  23. +3
    -2
      mindspore/lite/test/ut/src/runtime/kernel/opencl/shape_tests.cc
  24. +23
    -12
      mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc
  25. +1
    -1
      mindspore/lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc
  26. +2
    -0
      mindspore/lite/test/ut/src/runtime/kernel/opencl/split_tests.cc
  27. +12
    -52
      mindspore/lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc
  28. +112
    -15
      mindspore/lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc
  29. +0
    -105
      mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc
  30. +15
    -5
      mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc
  31. +50
    -0
      mindspore/lite/test/ut_gpu.cfg

+ 1
- 0
mindspore/lite/src/runtime/kernel/opencl/cl/to_format.cl View File

@@ -79,4 +79,5 @@ IMG_to_BUF(float16, float32, half, float, read_imageh);
IMG_to_BUF(float16, float16, half, half, read_imageh);
IMG_to_BUF(int32, int32, int, int, read_imagei);
IMG_to_BUF(uint32, uint32, int, int, read_imagei);
IMG_to_BUF(int32, float32, int, float, read_imagei);
IMG_to_BUF(int8, int8, char, char, read_imagei);

+ 2
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/argminmax.cc View File

@@ -40,7 +40,8 @@ int ArgMinMaxOpenCLKernel::CheckSpecs() {
}
if ((in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16) ||
(out_tensors_[0]->data_type() != kNumberTypeFloat32 && out_tensors_[0]->data_type() != kNumberTypeFloat16)) {
MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[0]->data_type();
MS_LOG(ERROR) << "Unsupported input/output data type. input data type is " << in_tensors_[0]->data_type()
<< " output data type is " << out_tensors_[0]->data_type();
return RET_ERROR;
}
if (in_tensors_[0]->shape().size() > 4 && in_tensors_[0]->shape().size() == 0) {


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc View File

@@ -35,7 +35,7 @@ int BatchNormOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
return RET_ERROR;
}
if (in_tensors_.at(0)->shape().size() == 4) {
if (in_tensors_.at(0)->shape().size() != 4) {
MS_LOG(ERROR) << "The dim of in_tensors->shape must be 4 but your dim is : " << in_tensors_.at(0)->shape().size();
return RET_ERROR;
}


+ 3
- 3
mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc View File

@@ -37,7 +37,7 @@ int FillOpenCLKernel::RunFill() {
auto param = reinterpret_cast<FillParameter *>(this->op_parameter_);
default_ = param->num_dims_;
ImageSize img_size;
cl_float4 fill_value = {};
cl_int4 fill_value = {};
fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
auto src_data = out_tensors_[0]->data_c();
allocator_->GetImageSize(src_data, &img_size);
@@ -51,11 +51,11 @@ int FillOpenCLKernel::RunFill() {
int FillOpenCLKernel::RunShape() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto src_data = out_tensors_[0]->data_c();
cl_float4 fill_value = {default_, default_, default_, default_};
cl_int4 fill_value = {default_, default_, default_, default_};
auto tensor_shape = in_tensors_[0]->shape();
void *tensor_shape_data = tensor_shape.data();
for (int i = 0; i < tensor_shape.size(); ++i) {
fill_value.s[i] = reinterpret_cast<float *>(tensor_shape_data)[i];
fill_value.s[i] = reinterpret_cast<int *>(tensor_shape_data)[i];
}
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h View File

@@ -39,7 +39,7 @@ class FillOpenCLKernel : public OpenCLKernel {
private:
int RunFill();
int RunShape();
float default_{0.0f};
int default_{0};
};

} // namespace mindspore::kernel


+ 11
- 0
mindspore/lite/src/runtime/kernel/opencl/kernel/gather.cc View File

@@ -193,6 +193,17 @@ int GatherOpenCLKernel::InitWeights() {
return RET_OK;
}

int GatherOpenCLKernel::PreProcess() {
if (!op_parameter_->infer_flag_) {
auto indices_tensor = in_tensors_[1];
if (!indices_tensor->IsConst()) {
ocl_runtime_->SyncCommandQueue();
indices_tensor->MutableData();
}
}
return OpenCLKernel::PreProcess();
}

int GatherOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
if (intensor1_is_tensor) {


+ 1
- 0
mindspore/lite/src/runtime/kernel/opencl/kernel/gather.h View File

@@ -32,6 +32,7 @@ class GatherOpenCLKernel : public OpenCLKernel {
int Run() override;
int InitWeights() override;
int Prepare() override;
int PreProcess() override;

int CheckSpecs() override;
void SetConstArgs() override;


+ 0
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc View File

@@ -48,8 +48,6 @@ int StackOpenCLKernel::RunAxis0() {
return RET_OK;
}

int StackOpenCLKernel::ReSize() { return RET_OK; }

void StackGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
const int max_divider = 8;
const int max_x = 4, max_y = 8;


+ 0
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/stack.h View File

@@ -33,8 +33,6 @@ class StackOpenCLKernel : public OpenCLKernel {
void SetConstArgs() override;
void SetGlobalLocal() override;

int ReSize() override;

int Run() override;

private:


+ 72
- 0
mindspore/lite/test/run_ut_gpu.sh View File

@@ -0,0 +1,72 @@
#!/bin/bash

basepath=$(pwd)
echo ${basepath}

# Example:sh run_benchmark_nets.sh -r /home/temp_test -d "8KE5T19620002408"
while getopts "r:d:" opt; do
case ${opt} in
r)
release_path=${OPTARG}
echo "release_path is ${OPTARG}"
;;
d)
device_id=${OPTARG}
echo "device_id is ${OPTARG}"
;;
?)
echo "unknown para"
exit 1;;
esac
done

ut_test_path=${basepath}/ut_test
rm -rf ${ut_test_path}
mkdir -p ${ut_test_path}

run_ut_result_file=${basepath}/run_benchmark_result.txt
echo ' ' > ${run_ut_result_file}
run_gpu_ut_log_file=${basepath}/run_gpu_ut_log.txt
echo 'run gpu ut logs: ' > ${run_gpu_ut_log_file}

ut_gpu_config=${basepath}/ut_gpu.cfg

function Run_gpu_ut() {
cd ${release_path} || exit 1

cp -a ${release_path}/lite-test ${ut_test_path}/lite-test || exit 1
cp -r ${basepath}/ut/src/runtime/kernel/opencl/test_data ${ut_test_path} || exit 1

# adb push all needed files to the phone
adb -s ${device_id} push ${ut_test_path} /data/local/tmp/ > adb_push_log.txt

# run adb ,run session ,check the result:
echo 'rm -rf /data/local/tmp/ut_test' > adb_cmd.txt
echo 'cd /data/local/tmp/ut_test' > adb_cmd.txt
echo 'cp /data/local/tmp/libc++_shared.so ./' >> adb_cmd.txt
echo 'cp /data/local/tmp/libgtest.so ./' >> adb_cmd.txt
echo 'chmod 777 lite-test' >> adb_cmd.txt

adb -s ${device_id} shell < adb_cmd.txt

# Run npu converted models:
while read line; do
echo 'cd /data/local/tmp/ut_test' > adb_run_cmd.txt
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/ut_test;./lite-test --gtest_filter='${line} >> "${run_gpu_ut_log_file}"
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/ut_test;./lite-test --gtest_filter='${line} >> adb_run_cmd.txt
adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_gpu_ut_log_file}"
if [ $? = 0 ]; then
run_result='arm64_gpu_ut: '${line}' pass'; echo ${run_result} >> ${run_ut_result_file}
else
run_result='arm64_gpu_ut: '${line}' failed'; echo ${run_result} >> ${run_ut_result_file}; return 1
fi
done < ${ut_gpu_config}
}

Run_gpu_ut
Run_gpu_ut_status=$?

if [[ $Run_gpu_ut_status == 1 ]]; then
exit 1
fi
exit 0

+ 1
- 0
mindspore/lite/test/ut/src/runtime/kernel/opencl/argminmax_tests.cc View File

@@ -31,6 +31,7 @@ OpParameter *CreateParameter(schema::PrimitiveType type, int axis, int topk, boo
param->axis_type_ = axis_type;
param->out_value_ = out_value;
param->keep_dims_ = keep_dims;
reinterpret_cast<OpParameter *>(param)->infer_flag_ = true;
return reinterpret_cast<OpParameter *>(param);
}
} // namespace


+ 0
- 38
mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc View File

@@ -157,44 +157,6 @@ TEST_F(TestOpenCL_Arithmetic, FloorMod) {
}
}

TEST_F(TestOpenCL_Arithmetic, FloorModFile) {
std::vector<int> input0_shape = {1, 3, 4, 5};
std::vector<int> input1_shape = {1, 3, 4, 5};
std::vector<int> output_shape = {1, 3, 4, 5};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/FloodModfp32_input1.bin";
std::string input2Ppath = "./test_data/FloodModfp32_input2.bin";
std::string correctOutputPath = "./test_data/FloodModfp32_output.bin";
auto input0_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input1_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));

for (auto fp16_enable : {true}) {
auto *param = CreateParameter(schema::PrimitiveType_FloorMod, input0_shape, input1_shape);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-7);
}
}

TEST_F(TestOpenCL_Arithmetic, SquaredDifference) {
std::vector<int> input0_shape = {1, 512, 1, 5};
std::vector<int> input1_shape = {1, 1, 1, 5};
std::vector<int> output_shape = {1, 512, 1, 5};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/SquaredDifferencefp32_input1.bin";
std::string input2Ppath = "./test_data/SquaredDifferencefp32_input2.bin";
std::string correctOutputPath = "./test_data/SquaredDifferencefp32_output.bin";
auto input0_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input1_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));

for (auto fp16_enable : {true}) {
auto *param = CreateParameter(schema::PrimitiveType_SquaredDifference, input0_shape, input1_shape);
TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-2 : 1e-9);
}
}

TEST_F(TestOpenCL_Arithmetic, ElementwiseDiv) {
std::vector<int> input0_shape = {1, 2, 2, 3};
std::vector<int> input1_shape = {1, 2, 2, 3};


+ 1
- 1
mindspore/lite/test/ut/src/runtime/kernel/opencl/batch_to_space_nd_tests.cc View File

@@ -25,7 +25,7 @@ namespace {
// PrimitiveType_BatchToSpaceND: src/ops/populate/batch_to_space_populate.cc
OpParameter *CreateParameter(int block_shape[], int crops[], const std::vector<int> &input_shape,
std::vector<int> *output_shape) {
auto *param = test::CreateParameter<BatchToSpaceParameter>(schema::PrimitiveType_BatchToSpaceND);
auto *param = test::CreateParameter<BatchToSpaceParameter>(schema::PrimitiveType_BatchToSpace);
memcpy(param->block_shape_, block_shape, sizeof(param->block_shape_));
memcpy(param->crops_, crops, sizeof(param->crops_));
*output_shape = {input_shape[0] / param->block_shape_[0] / param->block_shape_[1],


+ 12
- 2
mindspore/lite/test/ut/src/runtime/kernel/opencl/common.cc View File

@@ -38,7 +38,7 @@ void TestMain(const std::vector<ArgsTuple> &input_infos, const std::vector<ArgsT
TestMain(input_infos_new, output_info, op_parameter, fp16_enable, atol, rtol, print_data);
}

void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOut> &output_info,
void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOutWithDType> &output_info,
OpParameter *op_parameter, bool fp16_enable, float atol, float rtol, bool print_data) {
auto primitive_type = static_cast<schema::PrimitiveType>(op_parameter->type_);
#ifdef ENABLE_V0
@@ -71,7 +71,7 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec
}
for (auto outout_info : output_info) {
const std::vector<int> &output_shape = std::get<0>(outout_info);
out_tensors.emplace_back(std::make_shared<Tensor>(kNumberTypeFloat32, output_shape, Format_NHWC, VAR));
out_tensors.emplace_back(std::make_shared<Tensor>(std::get<2>(outout_info), output_shape, Format_NHWC, VAR));
}
// secondly, init weight Tensor's data
std::vector<Tensor *> kernel_inputs;
@@ -180,6 +180,16 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec
}
delete sub_graph;
}
void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOut> &output_info,
OpParameter *op_parameter, bool fp16_enable, float atol, float rtol, bool print_data) {
std::vector<ArgsTupleOutWithDType> output_info_new;
auto transform_fun = [](ArgsTupleOut in) -> ArgsTupleOutWithDType {
return ArgsTupleOutWithDType(std::get<0>(in), std::get<1>(in), kNumberTypeFloat32);
};
std::transform(output_info.begin(), output_info.end(), std::back_inserter(output_info_new), transform_fun);

TestMain(input_infos, output_info_new, op_parameter, fp16_enable, atol, rtol, print_data);
}

// single-output
void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std::vector<int>, float *> output_info,


+ 5
- 0
mindspore/lite/test/ut/src/runtime/kernel/opencl/common.h View File

@@ -32,6 +32,7 @@
using Tensor = mindspore::lite::Tensor;
using ArgsTuple = std::tuple<std::vector<int>, void *, Tensor::Category>;
using ArgsTupleOut = std::tuple<std::vector<int>, void *>;
using ArgsTupleOutWithDType = std::tuple<std::vector<int>, void *, mindspore::TypeId>;
using ArgsTupleWithDtype = std::tuple<std::vector<int>, void *, Tensor::Category, mindspore::TypeId>;
constexpr Tensor::Category VAR = Tensor::VAR;
constexpr Tensor::Category CONST_TENSOR = Tensor::Category::CONST_TENSOR;
@@ -94,6 +95,10 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec
OpParameter *op_parameter, bool fp16_enable = false, float atol = 1e-9, float rtol = 1e-9,
bool print_output = false);

void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOutWithDType> &output_info,
OpParameter *op_parameter, bool fp16_enable = false, float atol = 1e-9, float rtol = 1e-9,
bool print_output = false);

void TestMain(const std::vector<ArgsTuple> &input_infos, const std::vector<ArgsTupleOut> &output_info,
OpParameter *op_parameter, bool fp16_enable = false, float atol = 1e-9, float rtol = 1e-9,
bool print_output = false);


+ 1
- 0
mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_tests.cc View File

@@ -127,6 +127,7 @@ TEST_F(TestOpenCL_Conv2D, test3_batch2) {
TestMain_Conv2D(attr, input_data, weight_data, bias_data, output_data, ActType_No, true, 1e-6f);
}

// Check and optimize
TEST_F(TestOpenCL_Conv2D, test4) {
std::vector<std::tuple<std::string, std::string, std::vector<float>, std::vector<float>, std::vector<float>,
std::vector<float>, ActType>>


+ 1
- 1
mindspore/lite/test/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc View File

@@ -19,7 +19,7 @@
namespace mindspore::lite::opencl::test {

class TestOpenCL_Conv2dTranspose : public CommonTest {};
// Check and optimize
namespace {
// PrimitiveType_DeConv2D: src/ops/populate/deconv2d_populate.cc
OpParameter *CreateParameter(int n, int h, int w, int ci, int co, int kh, int kw, std::vector<int> pad, int oh, int ow,


+ 1
- 0
mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc View File

@@ -21,6 +21,7 @@ namespace mindspore::lite::opencl::test {
class TestOpenCL_DepthwiseConv2d : public CommonTest {};
namespace {
// Check and optimize
// PrimitiveType_DepthwiseConv2D: src/ops/populate/depthwise_conv2d_populate.cc
OpParameter *CreateParameter(int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_u, int pad_d, int pad_l,
int pad_r, int dilation_h, int dilation_w, ActType act_type, int input_channel) {


+ 4
- 4
mindspore/lite/test/ut/src/runtime/kernel/opencl/layer_norm_tests.cc View File

@@ -41,10 +41,10 @@ TEST_F(TestOpenCL_LayerNorm, test1) {
std::vector<int> beta_shape = {1, 1, 1, 5};
std::vector<int> output_shape = {2, 3, 4, 5};
size_t input_size, gamma_size, beta_size, output_size;
std::string inputPpath = "./test_data/layernormfp32_input.bin";
std::string gammaPpath = "./test_data/gammafp32_input.bin";
std::string betaPpath = "./test_data/betafp32_input.bin";
std::string correctOutputPath = "./test_data/layernormfp32_output.bin";
std::string inputPpath = "./test_data/layer_norm/test1/layernormfp32_input.bin";
std::string gammaPpath = "./test_data/layer_norm/test1/gammafp32_input.bin";
std::string betaPpath = "./test_data/layer_norm/test1/betafp32_input.bin";
std::string correctOutputPath = "./test_data/layer_norm/test1/layernormfp32_output.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(inputPpath.c_str(), &input_size));
auto gamma_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(gammaPpath.c_str(), &gamma_size));
auto beta_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(betaPpath.c_str(), &beta_size));


+ 0
- 19
mindspore/lite/test/ut/src/runtime/kernel/opencl/matmul_tests.cc View File

@@ -32,25 +32,6 @@ OpParameter *CreateParameter(bool a_transpose = false, bool b_transpose = true)
}
} // namespace

TEST_F(TestOpenCL_MatMul, 2Dfile) {
std::vector<int> input_shape = {64, 64};
std::vector<int> output_shape = {64, 64};
std::vector<int> weight_shape = {64, 64};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/matmulfp32_input1.bin";
std::string input2Ppath = "./test_data/matmulfp32_input2.bin";
std::string correctOutputPath = "./test_data/matmulfp32_output.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));

for (auto fp16_enable : {false}) {
auto *param = CreateParameter(false, false);
TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, CONST_TENSOR}}, {output_shape, output_data},
param, fp16_enable, fp16_enable ? 1e-3 : 1e-3);
}
}

TEST_F(TestOpenCL_MatMul, 2D) {
int ci = 5;
int co = 3;


+ 21
- 19
mindspore/lite/test/ut/src/runtime/kernel/opencl/pad_tests.cc View File

@@ -22,18 +22,10 @@ class TestOpenCL_Pad : public CommonTest {};

namespace {
// PrimitiveType_Pad: src/ops/populate/pad_populate.cc
OpParameter *CreateParameter(const std::vector<int> &paddings, float constant_value) {
OpParameter *CreateParameter(float constant_value) {
auto *param = test::CreateParameter<PadParameter>(schema::PrimitiveType_PadFusion);
param->pad_mode_ = schema::PaddingMode_CONSTANT;
param->constant_value_ = constant_value;
param->padding_length = MAX_PAD_SIZE;
int size = paddings.size();
for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
param->paddings_[i] = 0;
}
for (size_t i = 0; i < size; i++) {
param->paddings_[MAX_PAD_SIZE - size + i] = paddings[i];
}
return reinterpret_cast<OpParameter *>(param);
}
} // namespace
@@ -42,8 +34,10 @@ TEST_F(TestOpenCL_Pad, 1D) {
float input_data[] = {1, 1, 1, 1};
float output_data[] = {2, 2, 2, 1, 1, 1, 1, 2, 2};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({3, 2}, 2);
TestMain({{{4}, input_data, VAR}}, {{9}, output_data}, param, fp16_enable);
auto *param = CreateParameter(2);
int padding[] = {3, 2};
TestMain({{{4}, input_data, VAR, kNumberTypeFloat32}, {{1, 2}, padding, CONST_TENSOR, kNumberTypeInt32}},
{{9}, output_data}, param, fp16_enable);
}
}

@@ -52,8 +46,10 @@ TEST_F(TestOpenCL_Pad, 2D) {
float output_data[] = {10, 10, 10, 10, 10, 10, 10, 10, 10, 1, 1, 1, 1, 1, 10, 10,
10, 2, 2, 2, 2, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({1, 1, 1, 2}, 10);
TestMain({{{2, 5}, input_data, VAR}}, {{4, 8}, output_data}, param, fp16_enable);
int padding[] = {1, 1, 1, 2};
auto *param = CreateParameter(10);
TestMain({{{2, 5}, input_data, VAR, kNumberTypeFloat32}, {{2, 2}, padding, CONST_TENSOR, kNumberTypeInt32}},
{{4, 8}, output_data}, param, fp16_enable);
}
}

@@ -73,8 +69,10 @@ TEST_F(TestOpenCL_Pad, 4D) {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({0, 0, 3, 3, 3, 3, 0, 0}, 0);
TestMain({{{1, 4, 4, 3}, input_data, VAR}}, {{1, 10, 10, 3}, output_data}, param, fp16_enable);
auto *param = CreateParameter(0);
int padding[] = {0, 0, 3, 3, 3, 3, 0, 0};
TestMain({{{1, 4, 4, 3}, input_data, VAR, kNumberTypeFloat32}, {{4, 2}, padding, CONST_TENSOR, kNumberTypeInt32}},
{{1, 10, 10, 3}, output_data}, param, fp16_enable);
}

float output_data1[] = {
@@ -89,8 +87,10 @@ TEST_F(TestOpenCL_Pad, 4D) {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({0, 0, 3, 3, 3, 3, 0, 0}, 1);
TestMain({{{1, 4, 4, 3}, input_data, VAR}}, {{1, 10, 10, 3}, output_data1}, param, fp16_enable);
auto *param = CreateParameter(1);
int padding[] = {0, 0, 3, 3, 3, 3, 0, 0};
TestMain({{{1, 4, 4, 3}, input_data, VAR, kNumberTypeFloat32}, {{4, 2}, padding, CONST_TENSOR, kNumberTypeInt32}},
{{1, 10, 10, 3}, output_data1}, param, fp16_enable);
}
}

@@ -224,8 +224,10 @@ TEST_F(TestOpenCL_Pad, test0) {
auto constant_value = std::get<6>(case_);
std::cout << name << std::endl;
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(paddings, constant_value);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
auto *param = CreateParameter(constant_value);
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(paddings.size() / 2), 2}, paddings.data(), CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data}, param, fp16_enable);
}
}
}


+ 1
- 0
mindspore/lite/test/ut/src/runtime/kernel/opencl/prelu_tests.cc View File

@@ -21,6 +21,7 @@ namespace mindspore::lite::opencl::test {
class TestOpenCL_PRrelu : public CommonTest {};

namespace {
// Check and optimize
// PrimitiveType_PReLU: src/ops/populate/p_relu_populate.cc
OpParameter *CreateParameter() {
auto *param = test::CreateParameter<PReluParameter>(schema::PrimitiveType_PReLUFusion);


+ 3
- 2
mindspore/lite/test/ut/src/runtime/kernel/opencl/shape_tests.cc View File

@@ -30,10 +30,11 @@ TEST_F(TestOpenCL_Shape, test0) {
std::vector<int> input_shape = {2, 4};
std::vector<int> output_shape = {2};
float input_data[] = {-0.4045, -0.0924, -0.617, -0.10114, -0.9893, 0.3342, 2.445, -2.182};
float output_data[] = {2, 4};
int output_data[] = {2, 4};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter();
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}}, {{output_shape, output_data, kNumberTypeInt32}},
param, fp16_enable);
}
}



+ 23
- 12
mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc View File

@@ -22,12 +22,10 @@ class TestOpenCL_Slice : public CommonTest {};

namespace {
// PrimitiveType_Slice: src/ops/populate/slice_populate.cc
OpParameter *CreateParameter(const std::vector<int> &begin, const std::vector<int> &size) {
OpParameter *CreateParameter(const std::vector<int> &axis) {
auto *param = test::CreateParameter<SliceParameter>(schema::PrimitiveType_SliceFusion);
param->param_length_ = begin.size();
for (int i = 0; i < begin.size(); ++i) {
param->begin_[i] = begin[i];
param->size_[i] = size[i];
for (int i = 0; i < axis.size(); ++i) {
param->axis_[i] = axis[i];
}
return reinterpret_cast<OpParameter *>(param);
}
@@ -42,10 +40,16 @@ TEST_F(TestOpenCL_Slice, 4D) {
float output_data[] = {-0.9135602, -1.4002057, 1.1080881, 0.40712625, -0.28128958, -1.2808367, 0.1470597,
0.03393711, -0.33282498, -1.0433807, 0.28965706, 0.5343769, 0.75480366, -1.9328151,
-0.48714373, -0.14000037, -0.080552, 0.95056856, -0.06886655, 0.5316237};
auto param = CreateParameter({0, 0, 0, 2}, {1, 2, 2, 5});
TestMain({{{1, 2, 2, 8}, input_data, VAR}}, {{1, 2, 2, 5}, output_data}, param, false);
auto param = CreateParameter({0, 1, 2, 3});
std::vector<int> begin = {0, 0, 0, 2};
std::vector<int> size = {1, 2, 2, 5};
TestMain({{{1, 2, 2, 8}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(size.size())}, size.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 2, 2, 5}, output_data}, param, false);
}

// Check and optimize(fp16)
TEST_F(TestOpenCL_Slice, test0) {
std::vector<std::tuple<std::string, std::vector<int>, std::vector<int>, std::vector<float>, std::vector<float>,
std::vector<int>, std::vector<int>>>
@@ -148,11 +152,18 @@ TEST_F(TestOpenCL_Slice, test0) {
auto &size = std::get<6>(case_);

std::cout << name << std::endl;
auto *param = CreateParameter(begin, size);
TestMain({{input_shape, input_data.data(), VAR}}, {output_shape, output_data.data()}, param, false);
param = CreateParameter(begin, size);
TestMain({{input_shape, input_data.data(), VAR}}, {output_shape, output_data.data()}, param, true);
std::vector<int> axis(input_shape.size());
for (int i = 0; i < input_shape.size(); ++i) {
axis[i] = i;
}
auto *param = CreateParameter(axis);
for (auto fp16_enable : {false}) {
TestMain({{input_shape, input_data.data(), VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(size.size())}, size.data(), CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data.data()}, param, fp16_enable);
}
}
} // namespace mindspore
}

} // namespace mindspore::lite::opencl::test

+ 1
- 1
mindspore/lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc View File

@@ -19,7 +19,7 @@
namespace mindspore::lite::opencl::test {

class TestOpenCL_SparseToDense : public CommonTest {};
// Check and optimize
namespace {
// PrimitiveType_SparseToDense: src/ops/populate/sparse_to_dense_populate.cc
OpParameter *CreateParameter() {


+ 2
- 0
mindspore/lite/test/ut/src/runtime/kernel/opencl/split_tests.cc View File

@@ -26,6 +26,7 @@ OpParameter *CreateParameter(int split_dim_, int num_split_, std::vector<int> sp
auto *param = test::CreateParameter<SplitParameter>(schema::PrimitiveType_Split);
param->split_dim_ = split_dim_;
param->num_split_ = num_split_;
param->split_count_ = num_split_;
param->split_sizes_ = reinterpret_cast<int *>(malloc(param->num_split_ * sizeof(int)));
for (int i = 0; i < param->num_split_; ++i) {
param->split_sizes_[i] = split_sizes_[i];
@@ -34,6 +35,7 @@ OpParameter *CreateParameter(int split_dim_, int num_split_, std::vector<int> sp
}
} // namespace

// Check and optimize(No data file)
TEST_F(TestOpenCL_Split, input2_axis3) {
std::vector<int> input_shape = {2, 2, 2, 12};
std::vector<int> output_shape1 = {2, 2, 2, 6};


+ 12
- 52
mindspore/lite/test/ut/src/runtime/kernel/opencl/stack_tests.cc View File

@@ -72,9 +72,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis1) {
std::vector<int> input_shapes[INPUT_NUM] = {{3, 4, 5}, {3, 4, 5}};
std::vector<int> output_shape = {3, 2, 4, 5};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stackfp32_output.bin";
std::string input1Ppath = "./test_data/stack/input2_ndim3_axis1/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stack/input2_ndim3_axis1/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stack/input2_ndim3_axis1/stackfp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
@@ -91,9 +91,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis2) {
std::vector<int> input_shapes[INPUT_NUM] = {{3, 4, 5}, {3, 4, 5}};
std::vector<int> output_shape = {3, 4, 2, 5};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stackfp32_output.bin";
std::string input1Ppath = "./test_data/stack/input2_ndim3_axis2/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stack/input2_ndim3_axis2/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stack/input2_ndim3_axis2/stackfp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
@@ -110,9 +110,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim2_axis2) {
std::vector<int> input_shapes[INPUT_NUM] = {{1, 96}, {1, 96}};
std::vector<int> output_shape = {1, 96, 2};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stackfp32_output.bin";
std::string input1Ppath = "./test_data/stack/input2_ndim2_axis2/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stack/input2_ndim2_axis2/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stack/input2_ndim2_axis2/stackfp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
@@ -129,9 +129,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis3) {
std::vector<int> input_shapes[INPUT_NUM] = {{3, 4, 6}, {3, 4, 6}};
std::vector<int> output_shape = {3, 4, 6, 2};
size_t input1_size, input2_size, output_size;
std::string input1Ppath = "./test_data/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stackfp32_output.bin";
std::string input1Ppath = "./test_data/stack/input2_ndim3_axis3/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stack/input2_ndim3_axis3/stackfp32_input2.bin";
std::string correctOutputPath = "./test_data/stack/input2_ndim3_axis3/stackfp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
@@ -142,44 +142,4 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis3) {
}
}

TEST_F(TestOpenCL_Stack, input6_ndim3_axis0) {
constexpr int INPUT_NUM = 8;
int axis = 0;
std::vector<int> input_shapes[INPUT_NUM] = {{1, 17, 18}, {1, 17, 18}, {1, 17, 18}, {1, 17, 18},
{1, 17, 18}, {1, 17, 18}, {1, 17, 18}, {1, 17, 18}};
std::vector<int> output_shape = {8, 1, 17, 18};
size_t input1_size, input2_size, input3_size, input4_size, input5_size, input6_size, input7_size, input8_size,
output_size;
std::string input1Ppath = "./test_data/stackfp32_input1.bin";
std::string input2Ppath = "./test_data/stackfp32_input2.bin";
std::string input3Ppath = "./test_data/stackfp32_input3.bin";
std::string input4Ppath = "./test_data/stackfp32_input4.bin";
std::string input5Ppath = "./test_data/stackfp32_input5.bin";
std::string input6Ppath = "./test_data/stackfp32_input6.bin";
std::string input7Ppath = "./test_data/stackfp32_input7.bin";
std::string input8Ppath = "./test_data/stackfp32_input8.bin";
std::string correctOutputPath = "./test_data/stackfp32_output.bin";
auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size));
auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size));
auto input_data3 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size));
auto input_data4 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input4Ppath.c_str(), &input4_size));
auto input_data5 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input5Ppath.c_str(), &input5_size));
auto input_data6 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input6Ppath.c_str(), &input6_size));
auto input_data7 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input7Ppath.c_str(), &input7_size));
auto input_data8 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input8Ppath.c_str(), &input8_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
for (auto fp16_enable : {true}) {
auto *param = CreateParameter(axis);
TestMain({{input_shapes[0], input_data1, VAR},
{input_shapes[1], input_data2, VAR},
{input_shapes[2], input_data3, VAR},
{input_shapes[3], input_data4, VAR},
{input_shapes[4], input_data5, VAR},
{input_shapes[5], input_data6, VAR},
{input_shapes[6], input_data7, VAR},
{input_shapes[7], input_data8, VAR}},
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-3 : 1e-9);
}
}

} // namespace mindspore::lite::opencl::test

+ 112
- 15
mindspore/lite/test/ut/src/runtime/kernel/opencl/strided_slice_tests.cc View File

@@ -41,7 +41,14 @@ TEST_F(TestOpenCL_StridedSlice, 1D) {
float output_data[] = {3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({3}, {36}, {3});
TestMain({{{36}, input_data, VAR}}, {{11}, output_data}, param, fp16_enable);
std::vector<int> begin = {3};
std::vector<int> end = {36};
std::vector<int> stride = {3};
TestMain({{{36}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{11}, output_data}, param, fp16_enable);
}
}

@@ -50,8 +57,15 @@ TEST_F(TestOpenCL_StridedSlice, 2D) {
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
float output_data[] = {11, 14};
for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {1, 2};
std::vector<int> end = {3, 8};
std::vector<int> stride = {2, 3};
auto *param = CreateParameter({1, 2}, {3, 8}, {2, 3});
TestMain({{{4, 9}, input_data, VAR}}, {{1, 2}, output_data}, param, fp16_enable);
TestMain({{{4, 9}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 2}, output_data}, param, fp16_enable);
}
}

@@ -61,7 +75,14 @@ TEST_F(TestOpenCL_StridedSlice, 3D) {
float output_data[] = {11, 14};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({0, 1, 2}, {1, 3, 8}, {1, 2, 3});
TestMain({{{1, 4, 9}, input_data, VAR}}, {{1, 1, 2}, output_data}, param, fp16_enable);
std::vector<int> begin = {0, 1, 2};
std::vector<int> end = {1, 3, 8};
std::vector<int> stride = {1, 2, 3};
TestMain({{{1, 4, 9}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 1, 2}, output_data}, param, fp16_enable);
}
}

@@ -72,37 +93,79 @@ TEST_F(TestOpenCL_StridedSlice, 4D) {
float output_data0[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {0, 0, 0, 0};
std::vector<int> end = {2, 2, 3, 3};
std::vector<int> stride = {1, 1, 1, 1};
auto *param = CreateParameter({0, 0, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1});
TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{2, 2, 3, 3}, output_data0}, param, fp16_enable);
TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{2, 2, 3, 3}, output_data0}, param, fp16_enable);
}

for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {0, 0, 0, 0};
std::vector<int> end = {2, 2, 3, 3};
std::vector<int> stride = {1, 1, 1, 1};
auto *param = CreateParameter({0, 0, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1});
TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{2, 2, 3, 3}, output_data0}, param, fp16_enable);
TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{2, 2, 3, 3}, output_data0}, param, fp16_enable);
}

float output_data1[] = {18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {1, 0, 0, 0};
std::vector<int> end = {2, 2, 3, 3};
std::vector<int> stride = {1, 1, 1, 1};
auto *param = CreateParameter({1, 0, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1});
TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 2, 3, 3}, output_data1}, param, fp16_enable);
TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 2, 3, 3}, output_data1}, param, fp16_enable);
}

float output_data2[] = {27, 28, 29, 30, 31, 32, 33, 34, 35};
for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {1, 1, 0, 0};
std::vector<int> end = {2, 2, 3, 3};
std::vector<int> stride = {1, 1, 1, 1};
auto *param = CreateParameter({1, 1, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1});
TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 1, 3, 3}, output_data2}, param, fp16_enable);
TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 1, 3, 3}, output_data2}, param, fp16_enable);
}

float output_data3[] = {33, 34, 35};
for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {1, 1, 2, 0};
std::vector<int> end = {2, 2, 3, 3};
std::vector<int> stride = {1, 1, 1, 1};
auto *param = CreateParameter({1, 1, 2, 0}, {2, 2, 3, 3}, {1, 1, 1, 1});
TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 1, 1, 3}, output_data3}, param, fp16_enable);
TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 1, 1, 3}, output_data3}, param, fp16_enable);
}

float output_data4[] = {34};
for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {1, 1, 2, 1};
std::vector<int> end = {2, 2, 3, 2};
std::vector<int> stride = {1, 1, 1, 1};
auto *param = CreateParameter({1, 1, 2, 1}, {2, 2, 3, 2}, {1, 1, 1, 1});
TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 1, 1, 1}, output_data4}, param, fp16_enable);
TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 1, 1, 1}, output_data4}, param, fp16_enable);
}
}

@@ -111,8 +174,15 @@ TEST_F(TestOpenCL_StridedSlice, 4D_stride2) {
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
float output_data[] = {13, 14, 31, 32};
for (auto fp16_enable : {false, true}) {
std::vector<int> begin = {0, 1, 1, 1};
std::vector<int> end = {1, 4, 3, 3};
std::vector<int> stride = {2, 2, 2, 1};
auto *param = CreateParameter({0, 1, 1, 1}, {1, 4, 3, 3}, {2, 2, 2, 1});
TestMain({{{1, 4, 3, 3}, input_data, VAR}}, {{1, 2, 1, 2}, output_data}, param, fp16_enable);
TestMain({{{1, 4, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 2, 1, 2}, output_data}, param, fp16_enable);
}
}

@@ -122,19 +192,35 @@ TEST_F(TestOpenCL_StridedSlice, 4D_to_3D) {
float output_data[] = {18, 20, 21, 23, 27, 29, 30, 32};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({1, 0, 0, 0}, {2, 2, 2, 3}, {1, 1, 1, 2});
TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{2, 2, 2}, output_data}, param, fp16_enable);
std::vector<int> begin = {1, 0, 0, 0};
std::vector<int> end = {2, 2, 2, 3};
std::vector<int> stride = {1, 1, 1, 2};
TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{2, 2, 2}, output_data}, param, fp16_enable);
}
}

// Check and optimize
TEST_F(TestOpenCL_StridedSlice, In1D_OutOfRangeBeginNegativeStride) {
float input_data[] = {1, 2, 3, 4};
float output_data[] = {4, 3, 2};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({5}, {0}, {-1});
TestMain({{{4}, input_data, VAR}}, {{3}, output_data}, param, fp16_enable);
std::vector<int> begin = {5};
std::vector<int> end = {0};
std::vector<int> stride = {-1};
TestMain({{{4}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{3}, output_data}, param, fp16_enable);
}
}

// Check and optimize
TEST_F(TestOpenCL_StridedSlice, test0) {
std::vector<float> values(32768);
for (int i = 0; i < values.size(); ++i) {
@@ -320,7 +406,12 @@ TEST_F(TestOpenCL_StridedSlice, test0) {

for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(begin, end, stride);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
param->infer_flag_ = true;
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data}, param, fp16_enable);
}
}
}
@@ -332,8 +423,14 @@ TEST_F(TestOpenCL_StridedSlice, test1) {

for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter({0, 1, 0, 1}, {1, 3, 2, 4}, {1, 1, 2, 2});
TestMain({{{1, 3, 2, 4}, input_data, VAR}}, {{1, 2, 1, 2}, output_data}, param, fp16_enable,
fp16_enable ? 1e-2 : 1e-9);
std::vector<int> begin = {0, 1, 0, 1};
std::vector<int> end = {1, 3, 2, 4};
std::vector<int> stride = {1, 1, 2, 2};
TestMain({{{1, 3, 2, 4}, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32},
{{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}},
{{1, 2, 1, 2}, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-9);
}
}



+ 0
- 105
mindspore/lite/test/ut/src/runtime/kernel/opencl/to_format_tests.cc View File

@@ -1,105 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h"

namespace mindspore::lite::opencl::test {
class TestToFormatOpenCL : public CommonTest {
public:
TestToFormatOpenCL() {}
};

TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) {
auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
int h = 64;
int w = 1;
int c = 7360;
size_t input_size;
std::string input_path = "./test_data/transpose/transpose_fp32_input.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
if (input_data == nullptr) {
MS_LOG(ERROR) << "input_data load error.";
return;
}
std::vector<int> input_shape = {1, h, w, c};
auto tensor_x_ptr = std::make_unique<lite::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC4);
auto tensor_x = tensor_x_ptr.get();
if (tensor_x == nullptr) {
MS_LOG(ERROR) << "tensor_x create error.";
return;
}
std::vector<int> out_shape = {1, c, h, w};
auto tensor_out_ptr = std::make_unique<lite::Tensor>(TypeId(kNumberTypeFloat32), out_shape);
auto tensor_out = tensor_out_ptr.get();
if (tensor_out == nullptr) {
MS_LOG(ERROR) << "tensor_out create error.";
return;
}
std::vector<lite::Tensor *> inputs{tensor_x};
std::vector<lite::Tensor *> outputs{tensor_out};
auto arith_kernel_ptr = std::make_unique<kernel::ToFormatOpenCLKernel>(nullptr, inputs, outputs, nullptr);
auto arith_kernel = arith_kernel_ptr.get();
if (arith_kernel == nullptr) {
MS_LOG(ERROR) << "arith_kernel create error.";
return;
}
arith_kernel->Init();

inputs[0]->MallocData(allocator);

std::vector<kernel::LiteKernel *> kernels{arith_kernel};
auto pGraph_ptr = std::make_unique<kernel::OpenCLSubGraph>(inputs, outputs, kernels, kernels, kernels);
auto pGraph = pGraph_ptr.get();
if (pGraph == nullptr) {
MS_LOG(ERROR) << "pGraph create error.";
return;
}
pGraph->Init();
memcpy(inputs[0]->data_c(), input_data, input_size);
pGraph->Run();

size_t output_size;
std::string output_path = "./test_data/transpose/transpose_fp32_output.bin";
auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
if (correct_data == nullptr) {
MS_LOG(ERROR) << "correct_data create error.";
return;
}
printf("==================output data=================\n");
float *output_data = reinterpret_cast<float *>(tensor_out->data_c());
std::cout << std::endl;
int size_n = h * w * c;
size_n = size_n > 100 ? 100 : size_n;
for (int i = 0; i < size_n; i++) {
std::cout << output_data[i] << " ";
if ((i + 1) % c == 0) {
std::cout << std::endl;
}
}
std::cout << std::endl;

// compare
ASSERT_EQ(0, CompareOutputData(output_data, correct_data, h * w * c, 0.00001));
MS_LOG(INFO) << "Test TransposeFp32 passed";
}
} // namespace mindspore::lite::opencl::test

+ 15
- 5
mindspore/lite/test/ut/src/runtime/kernel/opencl/transpose_tests.cc View File

@@ -46,7 +46,9 @@ TEST_F(TestOpenCL_Transpose, NHWC2NCHW) {

for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(perm);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data}, param, fp16_enable);
}
}

@@ -62,7 +64,9 @@ TEST_F(TestOpenCL_Transpose, NCHW2NHWC) {

for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(perm);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data}, param, fp16_enable);
}
}

@@ -78,7 +82,9 @@ TEST_F(TestOpenCL_Transpose, NHWC2NWHC) {

for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(perm);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data}, param, fp16_enable);
}
}

@@ -94,7 +100,9 @@ TEST_F(TestOpenCL_Transpose, NWC2CWN) {

for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(perm);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data}, param, fp16_enable);
}
}

@@ -112,7 +120,9 @@ TEST_F(TestOpenCL_Transpose, NWC2WNC) {

for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(perm);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32},
{{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}},
{output_shape, output_data}, param, fp16_enable);
}
}
} // namespace mindspore::lite::opencl::test

+ 50
- 0
mindspore/lite/test/ut_gpu.cfg View File

@@ -0,0 +1,50 @@
TestOpenCL_Transpose.*
TestOpenCL_StridedSlice.1D
TestOpenCL_StridedSlice.2D
TestOpenCL_StridedSlice.3D
TestOpenCL_StridedSlice.4D
TestOpenCL_StridedSlice.4D_stride2
TestOpenCL_StridedSlice.4D_to_3D
TestOpenCL_StridedSlice.test1
TestOpenCL_Stack.*
TestOpenCL_Split.input3_axis0
TestOpenCL_DepthToSpace.*
TestOpenCL_SpaceToDepth.*
TestOpenCL_SpaceToBatch.*
TestOpenCL_SoftMax.*
TestOpenCL_Slice.4D
TestOpenCL_Shape.*
TestOpenCL_Scale.*
TestOpenCL_Resize.*
TestOpenCL_Reshape.*
TestOpenCL_Reduce.*
TestOpenCL_Pooling.*
TestOpenCL_Pad.1D
TestOpenCL_Pad.2D
TestOpenCL_Pad.3D
TestOpenCL_Pad.4D
TestOpenCL_OneHot.*
TestOpenCL_MatMul.*
TestOpenCL_LayerNorm.*
TestOpenCL_Gather.*
TestOpenCL_FullConnection.*
TestOpenCL_Conv2D.test0
TestOpenCL_Conv2D.test0_no_bias
TestOpenCL_Conv2D.test1
TestOpenCL_Conv2D.test2
TestOpenCL_Conv2D.test3
TestOpenCL_Conv2D.test3_batch2
TestOpenCL_Concat.*
TestOpenCL_BatchNorm.*
TestOpenCL_BatchToSpaceND.*
TestOpenCL_Arithmetic.ElementwiseAdd
TestOpenCL_Arithmetic.ScalarMul
TestOpenCL_Arithmetic.BroadcastSubReLU6
TestOpenCL_Arithmetic.BroadcastSub2
TestOpenCL_Arithmetic.BroadcastSub3
TestOpenCL_Arithmetic.BroadcastFloorMod
TestOpenCL_Arithmetic.FloorMod
TestOpenCL_Arithmetic.ElementwiseDiv
TestOpenCL_ArithmeticSelf.*
TestOpenCL_ArgMinMax.*
TestOpenCL_Activation.*

Loading…
Cancel
Save