Browse Source

!6464 [MS][LITE][Develop] GPU ops ArithmeticSelf , Slice , BN and concat add CI tests

Merge pull request !6464 from pengyongrong/op_format_toNC4HW4
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
e29a7bb33e
10 changed files with 504 additions and 42 deletions
  1. +4
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc
  2. +2
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h
  3. +6
    -0
      mindspore/lite/test/run_test.sh
  4. +7
    -7
      mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc
  5. +103
    -5
      mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc
  6. +132
    -6
      mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
  7. +8
    -5
      mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc
  8. +116
    -5
      mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc
  9. +0
    -1
      mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
  10. +126
    -9
      mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc

+ 4
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/cast.cc View File

@@ -16,7 +16,7 @@
#include <cstring>
#include <algorithm>
#include <set>
#include<string>
#include <string>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/kernel/cast.h"
#include "src/runtime/kernel/opencl/utils.h"
@@ -49,14 +49,16 @@ int CastOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
return RET_OK;
}

void CastOpenCLKernel::GetKernelName(std::string *kernel_name, CastParameter *param) {
int CastOpenCLKernel::GetKernelName(std::string *kernel_name, CastParameter *param) {
if (param->src_type_ == kNumberTypeFloat32 && param->dst_type_ == kNumberTypeFloat16) {
kernel_name[0] += "_Fp32ToFp16";
} else if (param->src_type_ == kNumberTypeFloat16 && param->dst_type_ == kNumberTypeFloat32) {
kernel_name[0] += "_Fp16ToFp32";
} else {
MS_LOG(ERROR) << "unsupported convert format from : " << param->src_type_ << "to " << param->dst_type_;
return RET_ERROR;
}
return RET_OK;
}

int CastOpenCLKernel::Init() {


+ 2
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/cast.h View File

@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_CAST_H_

#include <vector>
#include<string>
#include <string>
#include "ir/anf.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "nnacl/fp32/cast.h"
@@ -39,7 +39,7 @@ class CastOpenCLKernel : public OpenCLKernel {

int Run() override;

void GetKernelName(std::string *kernel_name, CastParameter *param);
int GetKernelName(std::string *kernel_name, CastParameter *param);

int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;



+ 6
- 0
mindspore/lite/test/run_test.sh View File

@@ -31,3 +31,9 @@ cp -fr $TEST_DATA_DIR/testPK ./data

# for GPU OpenCL
./lite-test --gtest_filter="TestConvolutionOpenCL.simple_test*"

./lite-test --gtest_filter="TestArithmeticSelfOpenCLCI.ArithmeticSelfRound*"
./lite-test --gtest_filter="TestConcatOpenCLCI.ConcatFp32_2inputforCI*"
./lite-test --gtest_filter="TestSliceOpenCLfp32.Slicefp32CI*"
./lite-test --gtest_filter="TestBatchnormOpenCLCI.Batchnormfp32CI*"


+ 7
- 7
mindspore/lite/test/ut/src/runtime/kernel/opencl/activation_tests.cc View File

@@ -519,17 +519,16 @@ TEST_F(TestActivationOpenCL, LeakyReluFp_dim4) {
delete param;
delete input_tensor;
delete output_tensor;
delete sub_graph;
lite::opencl::OpenCLRuntime::DeleteInstance();
}

TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) {
std::string in_file = "/data/local/tmp/test_data/in_tanh.bin";
std::string out_file = "/data/local/tmp/test_data/out_tanh.bin";
std::string in_file = "/data/local/tmp/test_data/in_tanhfp16.bin";
std::string out_file = "/data/local/tmp/test_data/out_tanhfp16.bin";
MS_LOG(INFO) << "Tanh Begin test!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto data_type = kNumberTypeFloat32;
auto data_type = kNumberTypeFloat16;
ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16);
bool enable_fp16 = ocl_runtime->GetFp16Enable();

@@ -561,7 +560,7 @@ TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) {
printf_tensor<float>("Tanh:FP32--input data--", inputs[0]);
}

auto *param = new (std::nothrow) ActivationParameter();
auto param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "New ActivationParameter fail.";
delete input_tensor;
@@ -628,10 +627,11 @@ TEST_F(TestActivationOpenCLTanh, TanhFp_dim4) {
printf_tensor<float>("Tanh:FP32--output data---", outputs[0]);
CompareRes<float>(output_tensor, out_file);
}
delete param;
lite::opencl::OpenCLRuntime::DeleteInstance();
input_tensor->SetData(nullptr);
delete input_tensor;
output_tensor->SetData(nullptr);
delete output_tensor;
delete sub_graph;
lite::opencl::OpenCLRuntime::DeleteInstance();
}
} // namespace mindspore

+ 103
- 5
mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_self_tests.cc View File

@@ -28,9 +28,14 @@ class TestArithmeticSelfOpenCLfp16 : public mindspore::CommonTest {
TestArithmeticSelfOpenCLfp16() {}
};

// Fixture for the ArithmeticSelf OpenCL tests that are run in CI; adds no state of its own.
class TestArithmeticSelfOpenCLCI : public mindspore::CommonTest {
 public:
  TestArithmeticSelfOpenCLCI() = default;
};

template <typename T>
void CompareOutputData1(T *input_data1, T *output_data, T *correct_data, int size, float err_bound) {
for (size_t i = 0; i < 100; i++) {
for (size_t i = 0; i < size; i++) {
T abs = fabs(output_data[i] - correct_data[i]);
ASSERT_LE(abs, err_bound);
}
@@ -53,7 +58,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) {

MS_LOG(INFO) << " init tensors ";

std::vector<int> shape = {1, 19, 19, 96};
std::vector<int> shape = {1, 2, 2, 144};
auto data_type = kNumberTypeFloat16;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
auto *input_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type);
@@ -66,7 +71,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) {
std::vector<lite::Tensor *> outputs{output_tensor};

MS_LOG(INFO) << " initialize param ";
auto param = new (std::nothrow) ArithmeticSelfParameter();
auto param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new ConcatParameter failed ";
for (auto tensor : inputs) {
@@ -77,7 +82,7 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) {
}
return;
}
param->op_parameter_.type_ = schema::PrimitiveType_Neg;
param->op_parameter_.type_ = schema::PrimitiveType_Round;
auto *arithmeticself_kernel =
new (std::nothrow) kernel::ArithmeticSelfOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (arithmeticself_kernel == nullptr) {
@@ -120,13 +125,106 @@ TEST_F(TestArithmeticSelfOpenCLfp16, ArithmeticSelfOpenCLFp16) {
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->data_c());
CompareOutputData1(input_data1, output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete sub_graph;
}

// CI test: runs the Round ArithmeticSelf OpenCL kernel on a 1x1x4x4 fp32 tensor built from inline data
// and compares every output element with the hand-written expectation.
TEST_F(TestArithmeticSelfOpenCLCI, ArithmeticSelfRound) {
  MS_LOG(INFO) << " begin test ";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  auto allocator = ocl_runtime->GetAllocator();
  float input_data1[] = {0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f,
                         0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f};
  float correctOutput[] = {1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f,
                           1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f};

  MS_LOG(INFO) << " init tensors ";
  std::vector<int> shape = {1, 1, 4, 4};
  auto data_type = kNumberTypeFloat32;
  auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
  auto *input_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type);
  auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type);
  if (input_tensor == nullptr || output_tensor == nullptr) {
    MS_LOG(INFO) << " new input_tensor or output_tensor failed ";
    // delete on nullptr is a no-op, so this releases whichever allocation succeeded.
    delete input_tensor;
    delete output_tensor;
    return;
  }
  std::vector<lite::Tensor *> inputs{input_tensor};
  std::vector<lite::Tensor *> outputs{output_tensor};
  // Shared cleanup for the early-exit paths below.
  auto release_tensors = [&inputs, &outputs]() {
    for (auto *tensor : inputs) {
      delete tensor;
    }
    for (auto *tensor : outputs) {
      delete tensor;
    }
  };

  MS_LOG(INFO) << " initialize param ";
  auto param = reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
  if (param == nullptr) {
    MS_LOG(INFO) << " new ArithmeticSelfParameter failed ";
    release_tensors();
    return;
  }
  // malloc leaves the struct indeterminate; zero it so the fields not set below are well defined.
  memset(param, 0, sizeof(ArithmeticSelfParameter));
  param->op_parameter_.type_ = schema::PrimitiveType_Round;
  auto *arithmeticself_kernel =
    new (std::nothrow) kernel::ArithmeticSelfOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  if (arithmeticself_kernel == nullptr) {
    MS_LOG(INFO) << " new kernel::ArithmeticSelfOpenCLKernel failed ";
    release_tensors();
    free(param);  // allocated with malloc, so it must be released with free, not delete
    return;
  }
  arithmeticself_kernel->SetFormatType(schema::Format_NC4HW4);
  arithmeticself_kernel->Init();
  // Allocate device-visible memory for the inputs through the OpenCL allocator.
  for (auto *tensor : inputs) {
    tensor->MallocData(allocator);
  }
  MS_LOG(INFO) << " initialize sub_graph ";
  std::vector<kernel::LiteKernel *> kernels{arithmeticself_kernel};
  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  if (sub_graph == nullptr) {
    MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
    release_tensors();
    // NOTE(review): the kernel is presumed to own param once constructed (the success path never
    // releases param separately), so only the kernel is destroyed here - confirm against LiteKernel.
    delete arithmeticself_kernel;
    return;
  }
  sub_graph->Init();
  MS_LOG(INFO) << " initialize input data ";
  std::cout << sizeof(input_data1) / sizeof(float) << std::endl;
  memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1));

  std::cout << "==================output data================" << std::endl;
  sub_graph->Run();
  auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
  CompareOutputData1(input_data1, output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
  lite::opencl::OpenCLRuntime::DeleteInstance();
  // Detach the allocator-backed buffers before destroying the tensors so they are not freed twice.
  for (auto *tensor : inputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  for (auto *tensor : outputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  // param is not released here: it came from malloc (so delete would be undefined behaviour) and is
  // presumed to be released together with the kernel when the sub graph is destroyed.
  delete sub_graph;
}
} // namespace mindspore

+ 132
- 6
mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc View File

@@ -31,6 +31,128 @@ class TestBatchnormOpenCLfp16 : public mindspore::CommonTest {
public:
TestBatchnormOpenCLfp16() {}
};
// Fixture for the BatchNorm OpenCL tests that are run in CI; adds no state of its own.
class TestBatchnormOpenCLCI : public mindspore::CommonTest {
 public:
  TestBatchnormOpenCLCI() = default;
};

// CI test: runs the BatchNorm OpenCL kernel on a 1x2x2x8 fp32 tensor with inline scale/offset/mean/variance
// data and compares the result with the inline reference output (tolerance 1e-4).
TEST_F(TestBatchnormOpenCLCI, Batchnormfp32CI) {
  MS_LOG(INFO) << " begin test ";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  auto allocator = ocl_runtime->GetAllocator();

  MS_LOG(INFO) << " Read tensors from .bin ";
  std::vector<int> input_shape = {1, 2, 2, 8};
  std::vector<int> output_shape = {1, 2, 2, 8};
  auto data_type = kNumberTypeFloat32;
  auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);

  float input_data[] = {2.471454,   -2.1379554, -0.0904604, 1.2928944,  -0.19215967, -0.8677279, -0.12759617,
                        1.2242758,  -0.06398406, -0.4041858, 0.20352598, -2.067808,   0.52113044, -1.567617,
                        0.28003863, 0.41367245, 0.77298605, 0.29908583, 1.4015813,   1.330567,   1.760135,
                        0.6320845,  0.6995399,  -1.208123,  -1.9738104, -1.3283046,  1.022744,   0.02741058,
                        0.84505165, -0.89434445, 1.983211,   -0.5485428};
  float correct_data[] = {0.7505676,  0.515882,   0.26147857, 1.6026789,  0.47575232, 0.50116986, 0.33589783,
                          1.4884706,  0.56019205, 0.7832671,  0.53893626, -0.5093127, 0.71395767, 0.18509413,
                          0.33990562, 0.891792,   0.6230367,  0.89172685, 1.6696336,  1.6263539,  1.1277269,
                          1.1784974,  0.34403008, -0.3019984, 0.4167911,  0.6407478,  1.3120956,  0.80740136,
                          0.8221321,  0.4891496,  0.3566509,  0.18351318};
  float mean_data[] = {0.3016613, -0.89284, 0.63434774, 0.145766, 0.73353934, -0.6744012, 0.7087985, -0.02967937};
  float var_data[] = {2.5604038, 0.84985304, 0.36261332, 1.9083935, 0.4920925, 0.6476224, 0.6269014, 0.8567283};
  float scale_data[] = {0.1201471, 0.142174, 0.5683258, 0.86815494, 0.23426804, 0.3634345, 0.0077846, 0.6813278};
  float offset_data[] = {0.58764684, 0.70790595, 0.945536, 0.8817803, 0.78489226, 0.5884778, 0.3441211, 0.5654443};

  MS_LOG(INFO) << " construct tensors ";
  lite::Tensor *tensor_data = new (std::nothrow) lite::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
  lite::Tensor *tensor_mean =
    new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  lite::Tensor *tensor_var =
    new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  lite::Tensor *tensor_scale =
    new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  lite::Tensor *tensor_offset =
    new (std::nothrow) lite::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
  // Input order used by this test: data, scale, offset, mean, variance.
  std::vector<lite::Tensor *> inputs = {tensor_data, tensor_scale, tensor_offset, tensor_mean, tensor_var};
  std::vector<lite::Tensor *> outputs{output_tensor};
  // Shared cleanup for the early-exit paths; delete on nullptr is a no-op, so it is safe to call
  // even when only some of the allocations above succeeded.
  auto release_tensors = [&inputs, &outputs]() {
    for (auto *tensor : inputs) {
      delete tensor;
    }
    for (auto *tensor : outputs) {
      delete tensor;
    }
  };
  if (tensor_data == nullptr || tensor_mean == nullptr || tensor_var == nullptr || tensor_scale == nullptr ||
      tensor_offset == nullptr || output_tensor == nullptr) {
    MS_LOG(INFO) << " init tensor failed ";
    release_tensors();
    return;
  }

  MS_LOG(INFO) << " initialize tensors ";
  auto param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
  if (param == nullptr) {
    MS_LOG(INFO) << " new BatchNormParameter failed ";
    release_tensors();
    return;
  }
  // malloc leaves the struct indeterminate; zero it so the fields not set below are well defined.
  memset(param, 0, sizeof(BatchNormParameter));
  param->epsilon_ = 1e-5f;
  auto *batchnorm_kernel =
    new (std::nothrow) kernel::BatchNormOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  if (batchnorm_kernel == nullptr) {
    MS_LOG(INFO) << " new kernel::BatchNorm_kernel failed ";
    release_tensors();
    free(param);  // allocated with malloc, so it must be released with free, not delete
    return;
  }
  batchnorm_kernel->Init();

  // Allocate device-visible memory for the inputs through the OpenCL allocator.
  for (auto *tensor : inputs) {
    tensor->MallocData(allocator);
  }

  MS_LOG(INFO) << " initialize sub_graph ";
  std::vector<kernel::LiteKernel *> kernels{batchnorm_kernel};
  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  if (sub_graph == nullptr) {
    MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
    release_tensors();
    // NOTE(review): the kernel is presumed to own param once constructed (the success path never
    // releases param separately), so only the kernel is destroyed here - confirm against LiteKernel.
    delete batchnorm_kernel;
    return;
  }
  sub_graph->Init();
  MS_LOG(INFO) << " init tensors ";
  memcpy(inputs[0]->data_c(), input_data, sizeof(input_data));
  memcpy(inputs[1]->data_c(), scale_data, sizeof(scale_data));
  memcpy(inputs[2]->data_c(), offset_data, sizeof(offset_data));
  memcpy(inputs[3]->data_c(), mean_data, sizeof(mean_data));
  memcpy(inputs[4]->data_c(), var_data, sizeof(var_data));
  std::cout << "==================output data================" << std::endl;
  sub_graph->Run();

  auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
  CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
  lite::opencl::OpenCLRuntime::DeleteInstance();
  // Detach the allocator-backed buffers before destroying the tensors so they are not freed twice.
  for (auto *tensor : inputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  for (auto *tensor : outputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  delete sub_graph;
}

TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) {
MS_LOG(INFO) << "begin test";
@@ -42,7 +164,7 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) {
MS_LOG(INFO) << " Read tensors from .bin ";
std::vector<int> input_shape = {1, 256, 256, 48};
std::vector<int> output_shape = {1, 256, 256, 48};
auto data_type = kNumberTypeFloat32;
auto data_type = kNumberTypeFloat16;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);

// get the input from .bin
@@ -90,7 +212,7 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) {
std::vector<lite::Tensor *> outputs{output_tensor};

MS_LOG(INFO) << " initialize tensors ";
auto param = new (std::nothrow) BatchNormParameter();
auto param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new BatchNormParameter failed ";
for (auto tensor : outputs) {
@@ -140,15 +262,18 @@ TEST_F(TestBatchnormOpenCLfp16, Batchnormfp16input_dim4) {

auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->data_c());
CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.01);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete param;
delete sub_graph;
}

TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
@@ -206,7 +331,7 @@ TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) {
std::vector<lite::Tensor *> outputs{output_tensor};

MS_LOG(INFO) << " initialize tensors ";
auto param = new (std::nothrow) BatchNormParameter();
auto param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new BatchNormParameter failed ";
for (auto tensor : outputs) {
@@ -256,14 +381,15 @@ TEST_F(TestBatchnormOpenCLfp32, Batchnormfp32input_dim4) {

auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
CompareOutputData(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete param;
delete batchnorm_kernel;
delete sub_graph;
}
} // namespace mindspore

+ 8
- 5
mindspore/lite/test/ut/src/runtime/kernel/opencl/cast_tests.cc View File

@@ -48,7 +48,7 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) {
std::string correctOutputPath = "./test_data/out_castfp16.bin";

MS_LOG(INFO) << " initialize param ";
auto param = new (std::nothrow) CastParameter();
auto param = reinterpret_cast<CastParameter *>(malloc(sizeof(CastParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new CastParameter failed ";
return;
@@ -113,14 +113,16 @@ TEST_F(TestCastSelfOpenCL, Castfp32tofp16) {
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->data_c());
CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete sub_graph;
lite::opencl::OpenCLRuntime::DeleteInstance();
}

TEST_F(TestCastSelfOpenCL, Castfp16tofp32) {
@@ -135,7 +137,7 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) {
std::string correctOutputPath = "./test_data/out_castfp32.bin";

MS_LOG(INFO) << " initialize param ";
auto param = new (std::nothrow) CastParameter();
auto param = reinterpret_cast<CastParameter *>(malloc(sizeof(CastParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new CastParameter failed ";
return;
@@ -199,14 +201,15 @@ TEST_F(TestCastSelfOpenCL, Castfp16tofp32) {
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete sub_graph;
lite::opencl::OpenCLRuntime::DeleteInstance();
}
} // namespace mindspore

+ 116
- 5
mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc View File

@@ -32,6 +32,11 @@ class TestConcatOpenCLfp16 : public mindspore::CommonTest {
TestConcatOpenCLfp16() {}
};

// Fixture for the Concat OpenCL tests that are run in CI; adds no state of its own.
class TestConcatOpenCLCI : public mindspore::CommonTest {
 public:
  TestConcatOpenCLCI() = default;
};

template <typename T>
void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) {
for (size_t i = 0; i < size; i++) {
@@ -39,7 +44,109 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou
ASSERT_LE(abs, err_bound);
}
}
TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {

// CI test: concatenates two 1x1x1x8 fp32 tensors along axis 0 with the Concat OpenCL kernel and compares
// the 2x1x1x8 result with the inline reference output (tolerance 1e-5).
TEST_F(TestConcatOpenCLCI, ConcatFp32_2inputforCI) {
  MS_LOG(INFO) << " begin test ";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  auto allocator = ocl_runtime->GetAllocator();

  MS_LOG(INFO) << " init tensors ";
  constexpr int INPUT_NUM = 2;
  std::array<std::vector<int>, INPUT_NUM> input_shapes = {std::vector<int>{1, 1, 1, 8}, std::vector<int>{1, 1, 1, 8}};
  std::vector<int> output_shape = {2, 1, 1, 8};
  auto data_type = kNumberTypeFloat32;
  auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
  float input_data1[] = {0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f};
  float input_data2[] = {0.5f, 0.6f, 0.74f, 0.23f, 0.46f, 0.69f, 0.13f, 0.47f};
  float correctOutput[] = {0.75f, 0.06f, 0.74f, 0.30f, 0.9f, 0.59f, 0.03f, 0.37f,
                           0.5f,  0.6f,  0.74f, 0.23f, 0.46f, 0.69f, 0.13f, 0.47f};
  auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
  if (output_tensor == nullptr) {
    MS_LOG(INFO) << " new output_tensor failed ";
    return;
  }
  std::vector<lite::Tensor *> inputs;
  std::vector<lite::Tensor *> outputs{output_tensor};
  // Shared cleanup for the early-exit paths below.
  auto release_tensors = [&inputs, &outputs]() {
    for (auto *tensor : inputs) {
      delete tensor;
    }
    for (auto *tensor : outputs) {
      delete tensor;
    }
  };
  for (auto &shape : input_shapes) {
    auto *input_temp = new (std::nothrow) lite::Tensor(data_type, shape, schema::Format_NHWC, tensor_type);
    if (input_temp == nullptr) {
      MS_LOG(INFO) << " new input_tensor failed ";
      // Release the output tensor and any inputs created so far instead of leaking them.
      release_tensors();
      return;
    }
    // Only store the tensor once it is known to be valid.
    inputs.push_back(input_temp);
  }

  MS_LOG(INFO) << " initialize tensors ";
  auto param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
  if (param == nullptr) {
    MS_LOG(INFO) << " new ConcatParameter failed ";
    release_tensors();
    return;
  }
  // malloc leaves the struct indeterminate; zero it so the fields not set below are well defined.
  memset(param, 0, sizeof(ConcatParameter));
  param->axis_ = 0;
  auto *concat_kernel =
    new (std::nothrow) kernel::ConcatOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  if (concat_kernel == nullptr) {
    MS_LOG(INFO) << " new kernel::ConcatOpenCLKernel failed ";
    release_tensors();
    free(param);  // allocated with malloc, so it must be released with free, not delete
    return;
  }
  concat_kernel->SetFormatType(schema::Format_NC4HW4);
  concat_kernel->Init();
  // Allocate device-visible memory for the inputs through the OpenCL allocator.
  for (auto *tensor : inputs) {
    tensor->MallocData(allocator);
  }

  MS_LOG(INFO) << " initialize sub_graph ";
  std::vector<kernel::LiteKernel *> kernels{concat_kernel};
  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  if (sub_graph == nullptr) {
    MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
    release_tensors();
    // NOTE(review): the kernel is presumed to own param once constructed (the success path never
    // releases param separately), so only the kernel is destroyed here - confirm against LiteKernel.
    delete concat_kernel;
    return;
  }
  sub_graph->Init();
  MS_LOG(INFO) << " initialize input data ";
  memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1));
  memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2));

  std::cout << "==================output data================" << std::endl;
  sub_graph->Run();
  auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
  CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.00001);
  lite::opencl::OpenCLRuntime::DeleteInstance();
  // Detach the allocator-backed buffers before destroying the tensors so they are not freed twice.
  for (auto *tensor : inputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  for (auto *tensor : outputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  delete sub_graph;
}

TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis1) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->SetFp16Enable(true);
@@ -89,7 +196,7 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
MS_LOG(INFO) << " input_shapes size =: " << input_shapes.size();

MS_LOG(INFO) << " initialize tensors ";
auto param = new (std::nothrow) ConcatParameter();
auto param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new ConcatParameter failed ";
for (auto tensor : inputs) {
@@ -157,13 +264,15 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_2input_dim4_axis3) {
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->data_c());
CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete param;
delete sub_graph;
}

@@ -212,7 +321,7 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
MS_LOG(INFO) << " input_shapes size=: " << input_shapes.size();

MS_LOG(INFO) << " initialize tensors ";
auto param = new (std::nothrow) ConcatParameter();
auto param = reinterpret_cast<ConcatParameter *>(malloc(sizeof(ConcatParameter)));
if (param == nullptr) {
MS_LOG(INFO) << " new ConcatParameter failed ";
for (auto tensor : inputs) {
@@ -276,13 +385,15 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_2input_dim4_axis3) {
sub_graph->Run();
auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.00001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete param;
delete sub_graph;
}
} // namespace mindspore

+ 0
- 1
mindspore/lite/test/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc View File

@@ -451,7 +451,6 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp16) {
float16_t gnd_data[] = {3.3848767, 1.4446403, 1.8428744, 1.3194335, 2.5873442, 2.1384869, 2.04022, 1.1872686,
2.2294958, 1.6570128, 2.465089, 1.4294086, 2.7941442, 1.7871612, 2.188921, 1.0601988};
lite::opencl::OpenCLRuntime::GetInstance()->SetFp16Enable(true);
DepthWiseTestMain<float16_t, float16_t>(conv_param, input_data, weight_data, gnd_data, schema::Format_NHWC4,
kNumberTypeFloat16, true, 1e-2);
}


+ 126
- 9
mindspore/lite/test/ut/src/runtime/kernel/opencl/slice_tests.cc View File

@@ -40,6 +40,117 @@ void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bou
}
}

// CI test: slices a 1x2x2x8 fp32 tensor with begin {0, 0, 0, 2} and size {1, 2, 2, 5} using the Slice
// OpenCL kernel and compares the 1x2x2x5 result with the inline reference output (tolerance 1e-4).
TEST_F(TestSliceOpenCLfp32, Slicefp32CI) {
  MS_LOG(INFO) << " begin test ";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  auto allocator = ocl_runtime->GetAllocator();

  MS_LOG(INFO) << " Read tensors from .bin ";
  std::vector<int> input_shape = {1, 2, 2, 8};
  std::vector<int> output_shape = {1, 2, 2, 5};
  std::vector<int> begin = {0, 0, 0, 2};
  std::vector<int> size = {1, 2, 2, 5};
  auto data_type = kNumberTypeFloat32;
  auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);

  float input_data[] = {-0.45816937, 0.92391545,  -0.9135602, -1.4002057, 1.1080881,   0.40712625,  -0.28128958,
                        0.09470133,  0.19801073,  0.04927751, -1.2808367, 0.1470597,   0.03393711,  -0.33282498,
                        -1.0433807,  -1.3678077,  -0.6423931, 0.5584889,  0.28965706,  0.5343769,   0.75480366,
                        -1.9328151,  -0.48714373, 1.711132,   -1.8871949, -0.2987629,  -0.14000037, -0.080552,
                        0.95056856,  -0.06886655, 0.5316237,  0.05787678};
  float correct_data[] = {-0.9135602,  -1.4002057,  1.1080881,  0.40712625, -0.28128958, -1.2808367, 0.1470597,
                          0.03393711,  -0.33282498, -1.0433807, 0.28965706, 0.5343769,   0.75480366, -1.9328151,
                          -0.48714373, -0.14000037, -0.080552,  0.95056856, -0.06886655, 0.5316237};
  MS_LOG(INFO) << " construct tensors ";
  lite::Tensor *tensor_data = new (std::nothrow) lite::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
  if (tensor_data == nullptr) {
    MS_LOG(INFO) << " init tensor failed ";
    return;
  }
  auto *output_tensor = new (std::nothrow) lite::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type);
  if (output_tensor == nullptr) {
    delete tensor_data;
    MS_LOG(INFO) << " init tensor failed ";
    return;
  }
  std::vector<lite::Tensor *> inputs = {tensor_data};
  std::vector<lite::Tensor *> outputs = {output_tensor};
  // Shared cleanup for the early-exit paths below.
  auto release_tensors = [&inputs, &outputs]() {
    for (auto *tensor : inputs) {
      delete tensor;
    }
    for (auto *tensor : outputs) {
      delete tensor;
    }
  };

  MS_LOG(INFO) << "setting SliceParameter ";
  auto param = reinterpret_cast<SliceParameter *>(malloc(sizeof(SliceParameter)));
  if (param == nullptr) {
    release_tensors();
    MS_LOG(INFO) << "new SliceParameter failed ";
    return;
  }
  // malloc leaves the struct indeterminate; zero it so the fields not set below are well defined.
  memset(param, 0, sizeof(SliceParameter));
  // One begin/size entry per input dimension; size_t avoids a signed/unsigned comparison.
  for (size_t i = 0; i < input_shape.size(); i++) {
    param->begin_[i] = begin[i];
    param->size_[i] = size[i];
  }

  auto *slice_kernel =
    new (std::nothrow) kernel::SliceOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  if (slice_kernel == nullptr) {
    release_tensors();
    free(param);  // allocated with malloc, so it must be released with free, not delete
    MS_LOG(INFO) << "new kernel::slice_kernel failed ";
    return;
  }
  slice_kernel->Init();

  // Allocate device-visible memory for the inputs through the OpenCL allocator.
  for (auto *tensor : inputs) {
    tensor->MallocData(allocator);
  }

  MS_LOG(INFO) << " initialize sub_graph ";
  std::vector<kernel::LiteKernel *> kernels{slice_kernel};
  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  if (sub_graph == nullptr) {
    release_tensors();
    // NOTE(review): the kernel is presumed to own param once constructed (the success path never
    // releases param separately), so only the kernel is destroyed here - confirm against LiteKernel.
    delete slice_kernel;
    MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
    return;
  }
  sub_graph->Init();

  MS_LOG(INFO) << " init tensors ";
  memcpy(inputs[0]->data_c(), input_data, sizeof(input_data));

  std::cout << "==================output data================" << std::endl;
  sub_graph->Run();

  auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
  CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
  lite::opencl::OpenCLRuntime::DeleteInstance();
  // Detach the allocator-backed buffers before destroying the tensors so they are not freed twice.
  for (auto *tensor : inputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  for (auto *tensor : outputs) {
    tensor->SetData(nullptr);
    delete tensor;
  }
  delete sub_graph;
}

TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
@@ -49,7 +160,7 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) {
MS_LOG(INFO) << " Read tensors from .bin ";
std::vector<int> input_shape = {1, 19, 19, 96};
std::vector<int> output_shape = {1, 10, 10, 13};
std::vector<int> begin = {0, 2, 3, 3};
std::vector<int> begin = {0, 2, 3, 4};
std::vector<int> size = {1, 10, 10, 13};
auto data_type = kNumberTypeFloat32;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);
@@ -76,7 +187,7 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) {
std::vector<lite::Tensor *> outputs = {output_tensor};

MS_LOG(INFO) << "setting SliceParameter ";
auto param = new (std::nothrow) SliceParameter();
auto param = reinterpret_cast<SliceParameter *>(malloc(sizeof(SliceParameter)));
if (param == nullptr) {
for (auto tensor : inputs) {
delete tensor;
@@ -137,14 +248,18 @@ TEST_F(TestSliceOpenCLfp32, Slicefp32input_dim4) {

auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->data_c());
CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete sub_graph;
}

TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
@@ -153,17 +268,17 @@ TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) {
auto allocator = ocl_runtime->GetAllocator();

MS_LOG(INFO) << " Read tensors from .bin ";
std::vector<int> input_shape = {1, 256, 256, 48};
std::vector<int> output_shape = {1, 255, 255, 15};
std::vector<int> input_shape = {1, 25, 25, 48};
std::vector<int> output_shape = {1, 24, 24, 15};
std::vector<int> begin = {0, 1, 1, 7};
std::vector<int> size = {1, 255, 255, 15};
std::vector<int> size = {1, 24, 24, 15};
auto data_type = kNumberTypeFloat16;
auto tensor_type = lite::TensorCategory(schema::NodeType_ValueNode);

// get the input from .bin
size_t input_size, output_size;
std::string input_path = "./test_data/in_data.bin";
std::string output_path = "./test_data/out_data.bin";
std::string input_path = "./test_data/in_slicefp16.bin";
std::string output_path = "./test_data/out_slicefp16.bin";
auto input_data = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
auto correct_data = reinterpret_cast<float16_t *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));

@@ -183,7 +298,7 @@ TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) {
std::vector<lite::Tensor *> outputs = {output_tensor};

MS_LOG(INFO) << " setting SliceParameter ";
auto param = new (std::nothrow) SliceParameter();
auto param = reinterpret_cast<SliceParameter *>(malloc(sizeof(SliceParameter)));
if (param == nullptr) {
for (auto tensor : inputs) {
delete tensor;
@@ -241,13 +356,15 @@ TEST_F(TestSliceOpenCLfp16, Slicefp16input_dim4) {

std::cout << "==================output data================" << std::endl;
sub_graph->Run();

auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->data_c());
CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
lite::opencl::OpenCLRuntime::DeleteInstance();
for (auto tensor : inputs) {
tensor->SetData(nullptr);
delete tensor;
}
for (auto tensor : outputs) {
tensor->SetData(nullptr);
delete tensor;
}
delete sub_graph;


Loading…
Cancel
Save