
!8122 [MS][LITE][GPU] fix bug: arm32 GPU build failed

Merge pull request !8122 from chenzupeng/master-lite

Tag: v1.1.0
Merged by mindspore-ci-bot (Gitee), 5 years ago
Commit: b74a1d0a60
6 changed files with 275 additions and 413 deletions:
  1. mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc (+1 -1)
  2. mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h (+1 -1)
  3. mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc (+1 -2)
  4. mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h (+1 -1)
  5. mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc (+139 -194)
  6. mindspore/lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc (+132 -214)
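
The substance of the fix is visible in the four kernel diffs below: ArithmeticOpenCLKernel and ScaleOpenCLKernel accepted a `lite::InnerContext *ctx` constructor argument they never used, and the unit tests were the only callers still threading an InnerContext through; dropping the parameter and rewriting the tests so they no longer construct an InnerContext appears to be what unblocks the arm32 GPU build. A minimal before/after sketch of the signature change, copied from the header diffs below:

  // Before (both kernels): the context argument was accepted but unused.
  ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                         const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx);

  // After: the unused parameter is gone, so kernel creators and tests
  // no longer need a lite::InnerContext at all.
  ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                         const std::vector<lite::Tensor *> &outputs);

The test rewrites account for the bulk of the diff: the template-based TestCase harnesses are replaced by RunTestCaseArithmetic/RunTestCaseScale drivers with hard-coded expected outputs.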

mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc (+1 -1)

@@ -295,7 +295,7 @@ kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::Tensor
                                                  const lite::InnerContext *ctx, const kernel::KernelKey &desc,
                                                  const mindspore::lite::PrimitiveC *primitive) {
  auto *kernel =
-    new (std::nothrow) ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx);
+    new (std::nothrow) ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!";
    free(opParameter);


mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h (+1 -1)

@@ -26,7 +26,7 @@ namespace mindspore::kernel {
class ArithmeticOpenCLKernel : public OpenCLKernel {
 public:
  ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                        const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+                        const std::vector<lite::Tensor *> &outputs)
    : OpenCLKernel(parameter, inputs, outputs) {}
  ~ArithmeticOpenCLKernel() override = default;



mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc (+1 -2)

@@ -251,8 +251,7 @@ kernel::LiteKernel *OpenCLScaleKernelCreator(const std::vector<lite::Tensor *> &
                                             const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                             const lite::InnerContext *ctx, const kernel::KernelKey &desc,
                                             const mindspore::lite::PrimitiveC *primitive) {
-  auto *kernel =
-    new (std::nothrow) ScaleOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) ScaleOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "Create OpenCL Scale kernel failed!";
    free(opParameter);


mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h (+1 -1)

@@ -26,7 +26,7 @@ namespace mindspore::kernel {
class ScaleOpenCLKernel : public OpenCLKernel {
 public:
  ScaleOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+                   const std::vector<lite::Tensor *> &outputs)
    : OpenCLKernel(parameter, inputs, outputs) {}
  ~ScaleOpenCLKernel() override;



mindspore/lite/test/ut/src/runtime/kernel/opencl/arithmetic_tests.cc (+139 -194)

@@ -13,233 +13,178 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"

namespace mindspore {
-class TestArithmeticOpenCL : public mindspore::CommonTest {
- public:
-  TestArithmeticOpenCL() {}
-};
-
-template <class T>
-static void BoardcaseAdd(const T *a, const T b, T *c, const int size) {
-  for (int i = 0; i < size; i++) {
-    c[i] = a[i] + b;
-  }
-}
-
-template <class T>
-static void ElementAdd(const T *a, const T *b, T *c, const int size) {
-  for (int i = 0; i < size; i++) {
-    c[i] = a[i] + b[i];
-  }
-}
-
-template <class T>
-static bool DataCompare(const T *a, const T *b, const int size, const float accuracy = 1e-4) {
-  for (int i = 0; i < size; i++) {
-    auto diff = fabs(a[i] - b[i]);
-    if (diff > accuracy) {
-      MS_LOG(ERROR) << "compare failed at " << i << " exp " << a[i] << " bug got " << b[i];
-      return false;
-    }
-  }
-  return true;
-}
-
-template <class T>
-static void InitData(void *data, const int size) {
-  T *data_float = reinterpret_cast<T *>(data);
-  static unsigned int seed = 123;
-  for (int i = 0; i < size; i++) {
-    data_float[i] = static_cast<int>(rand_r(&seed)) % 100;
-  }
-}
-
-template <class T>
-static void LogData(void *data, const int size, const std::string prefix) {
-  std::cout << prefix;
-  T *data_float = reinterpret_cast<T *>(data);
-  for (int i = 0; i < size; i++) {
-    std::cout << data_float[i] << ",";
-  }
-  std::cout << std::endl;
-}
-
-template <class T>
-static void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b) {
-  bool is_log_data = false;
+void RunTestCaseArithmetic(void *input_data0, const std::vector<int> &input_shape, void *input_data1,
+                           const std::vector<int> &weight_shape, void *output_data, const std::vector<int> &out_shape,
+                           bool enable_fp16, int op_type, int act_type = schema::ActivationType_NO_ACTIVATION) {
  auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
  ocl_runtime->Init();
+  size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float);
+  ocl_runtime->SetFp16Enable(enable_fp16);
  auto allocator = ocl_runtime->GetAllocator();
-
-  bool is_bias_add = shape_b.empty();
-  auto data_type = kNumberTypeFloat32;
-  if (sizeof(T) == 2) {
-    data_type = kNumberTypeFloat16;
-    ocl_runtime->SetFp16Enable(true);
-  }
-
-  lite::Tensor *tensor_a = new (std::nothrow) lite::Tensor(data_type, shape_a, schema::Format_NHWC4);
-  lite::Tensor *tensor_b = new (std::nothrow) lite::Tensor(data_type, shape_b, schema::Format_NHWC4);
-  lite::Tensor *tensor_c = new (std::nothrow) lite::Tensor(data_type, shape_a, schema::Format_NHWC4);
-  if (tensor_a == nullptr || tensor_b == nullptr || tensor_c == nullptr) {
-    MS_LOG(ERROR) << "Create tensor failed!";
-    delete tensor_a;
-    delete tensor_b;
-    delete tensor_c;
+  auto param = static_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "param_ptr create error.";
    return;
  }
-
-  int64_t element_num = tensor_a->ElementsC4Num();
-  int64_t element_num_b = is_bias_add ? 1 : tensor_b->ElementsC4Num();
-
-  T *data_a = new (std::nothrow) T[element_num];
-  T *data_b = new (std::nothrow) T[element_num_b];
-  T *data_c_cpu = new (std::nothrow) T[element_num];
-  T *data_c_ocl = new (std::nothrow) T[element_num];
-  if (data_a == nullptr || data_b == nullptr || data_c_cpu == nullptr || data_c_ocl == nullptr) {
-    MS_LOG(ERROR) << "Create buffer failed!";
-    delete tensor_a;
-    delete tensor_b;
-    delete tensor_c;
-    delete[] data_a;
-    delete[] data_b;
-    delete[] data_c_cpu;
-    delete[] data_c_ocl;
+  int input0_size = std::accumulate(input_shape.begin(), input_shape.end(), 1LL, std::multiplies<int>());
+  int input1_size = std::accumulate(weight_shape.begin(), weight_shape.end(), 1LL, std::multiplies<int>());
+  if (input0_size != input1_size) {
+    param->broadcasting_ = true;
+  }
+  param->op_parameter_.type_ = op_type;
+  param->activation_type_ = act_type;
+  auto tensor_x_ptr =
+    std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape);
+  auto tensor_x = tensor_x_ptr.get();
+  if (tensor_x == nullptr) {
+    MS_LOG(ERROR) << "tensor_x create error.";
    return;
  }
-
-  InitData<T>(data_a, element_num);
-  InitData<T>(data_b, element_num_b);
-  memset(data_c_ocl, 0, sizeof(T) * element_num);
-
-  if (is_bias_add) {
-    BoardcaseAdd(data_a, static_cast<T *>(data_b)[0], data_c_cpu, element_num);
-  } else {
-    ElementAdd(data_a, data_b, data_c_cpu, element_num);
-  }
-
-  std::vector<lite::Tensor *> inputs = {tensor_a};
-  if (!is_bias_add) {
-    inputs.push_back(tensor_b);
-  } else {
-    tensor_b->MallocData();
-    memcpy(tensor_b->data_c(), data_b, sizeof(T));
-  }
-  std::vector<lite::Tensor *> outputs = {tensor_c};
-
-  ArithmeticParameter *param = static_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
-  param->broadcasting_ = is_bias_add;
-  if (param == nullptr) {
-    MS_LOG(ERROR) << "Create parameter failed!";
-    delete tensor_a;
-    delete tensor_b;
-    delete tensor_c;
-    delete[] data_a;
-    delete[] data_b;
-    delete[] data_c_cpu;
-    delete[] data_c_ocl;
+  auto tensor_w_ptr = std::make_unique<lite::Tensor>(
+    TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape, schema::Format_NHWC,
+    input1_size != 1 ? lite::Tensor::Category::CONST_TENSOR : lite::Tensor::Category::CONST_SCALAR);
+  auto tensor_w = tensor_w_ptr.get();
+  if (tensor_w == nullptr) {
+    MS_LOG(ERROR) << "tensor_w create error.";
    return;
  }
-  param->ndim_ = 4;
-  param->op_parameter_.type_ = PrimitiveType_Add;
-
-  std::vector<lite::Tensor *> arithmetic_inputs = {tensor_a, tensor_b};
-  lite::InnerContext ctx;
-  ASSERT_EQ(lite::RET_OK, ctx.Init());
-  auto *arith_kernel = new (std::nothrow)
-    kernel::ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(param), arithmetic_inputs, outputs, &ctx);
-  if (arith_kernel == nullptr) {
-    MS_LOG(ERROR) << "Create ArithmeticOpenCLKernel failed!";
-    delete tensor_a;
-    delete tensor_b;
-    delete tensor_c;
-    delete[] data_a;
-    delete[] data_b;
-    delete[] data_c_cpu;
-    delete[] data_c_ocl;
-    free(param);
+  tensor_w->set_data(input_data1);
+  auto tensor_out_ptr =
+    std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape);
+  auto tensor_out = tensor_out_ptr.get();
+  if (tensor_out == nullptr) {
+    MS_LOG(ERROR) << "tensor_out create error.";
    return;
  }
-  arith_kernel->Init();
-
-  tensor_a->MallocData(allocator);
-  tensor_b->MallocData(allocator);
-  std::vector<kernel::LiteKernel *> kernels{arith_kernel};
-  auto *kernel = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
-  if (arith_kernel == nullptr) {
-    MS_LOG(ERROR) << "Create SubGraphOpenCLKernel failed!";
-    delete tensor_a;
-    delete tensor_b;
-    delete tensor_c;
-    delete[] data_a;
-    delete[] data_b;
-    delete[] data_c_cpu;
-    delete[] data_c_ocl;
-    delete arith_kernel;
+  std::vector<lite::Tensor *> inputs{tensor_x, tensor_w};
+  std::vector<lite::Tensor *> outputs{tensor_out};
+  auto op_kernel_ptr =
+    std::make_unique<kernel::ArithmeticOpenCLKernel>(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  auto op_kernel = op_kernel_ptr.release();
+  if (op_kernel == nullptr) {
+    MS_LOG(ERROR) << "op_kernel create error.";
    return;
  }
-  kernel->Init();
+  op_kernel->Init();
+  inputs[0]->MallocData(allocator);

-  memcpy(inputs[0]->data_c(), data_a, sizeof(T) * element_num);
-  if (!is_bias_add) {
-    memcpy(inputs[1]->data_c(), data_b, sizeof(T) * element_num_b);
-  }
-
-  kernel->Run();
-
-  memcpy(data_c_ocl, outputs[0]->data_c(), sizeof(T) * element_num);
+  std::vector<kernel::LiteKernel *> kernels{op_kernel};

-  if (is_log_data) {
-    LogData<T>(data_a, 10, "Data A : ");
-    LogData<T>(data_b, tensor_b->shape().empty() ? 1 : 10, "Data B : ");
-    LogData<T>(data_c_cpu, 10, "Expect compute : ");
-    LogData<T>(outputs[0]->data_c(), 10, "OpenCL compute : ");
+  std::vector<lite::Tensor *> inputs_g{tensor_x};
+  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
+  auto pGraph = pGraph_ptr.get();
+  if (pGraph == nullptr) {
+    MS_LOG(ERROR) << "pGraph create error.";
+    return;
  }
+  pGraph->Init();
+  memcpy(inputs[0]->MutableData(), input_data0, tensor_x->ElementsNum() * dtype_size);
+  pGraph->Run();
+  if (enable_fp16) {
+    CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float16_t>(1e-3),
+                  2e-2);
+  } else {
+    CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float>(1e-5));
+  }
-  bool cmp = DataCompare(data_c_cpu, data_c_ocl, element_num);
-  MS_LOG(INFO) << "Compare " << (cmp ? "success!" : "failed!");
-  EXPECT_EQ(true, cmp);
-
-  // free
-  delete[] data_a;
-  delete[] data_b;
-  delete[] data_c_cpu;
-  delete[] data_c_ocl;
-
-  delete kernel;
-  for (auto tensor : inputs) {
-    delete tensor;
+  for (auto t : inputs) {
+    t->set_data(nullptr);
  }
-  for (auto tensor : outputs) {
-    delete tensor;
+  for (auto t : outputs) {
+    t->set_data(nullptr);
  }
+  MS_LOG(INFO) << "TestArithmetic passed";
}

+class TestArithmeticOpenCL : public mindspore::CommonTest {
+ public:
+  TestArithmeticOpenCL() {}
+};
+TEST_F(TestArithmeticOpenCL, ArithmeticElementwiseAddFp32) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> in_shape1 = {n, h, w, c};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> weight_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> output_data = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f};
+  RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape,
+                        false, schema::PrimitiveType_Add);
}

-TEST_F(TestArithmeticOpenCL, AddElementwiseFP32) {
-  const std::vector<int> &shape_a = {1, 1024, 1024, 4};
-  const std::vector<int> &shape_b = {1, 1024, 1024, 4};
-  TestCase<float>(shape_a, shape_b);
+TEST_F(TestArithmeticOpenCL, ArithmeticScalarMulFp32) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> in_shape1 = {1};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> weight_data = {2.0f};
+  std::vector<float> output_data = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f};
+  RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape,
+                        false, schema::PrimitiveType_Mul);
}

-TEST_F(TestArithmeticOpenCL, AddBroadcastFP32) {
-  const std::vector<int> &shape_a = {1, 128, 128, 4};
-  const std::vector<int> &shape_b = {};
-  TestCase<float>(shape_a, shape_b);
+TEST_F(TestArithmeticOpenCL, ArithmeticBroadcastSubReLU6Fp32) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> in_shape1 = {c};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> weight_data = {1.0f, 2.0f, 3.0f};
+  std::vector<float> output_data = {0.0f, 0.0f, 0.0f, 3.0f, 3.0f, 3.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f};
+  RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape,
+                        false, schema::PrimitiveType_Sub, schema::ActivationType_RELU6);
}

-TEST_F(TestArithmeticOpenCL, AddElementwiseFP16) {
-  const std::vector<int> &shape_a = {1, 1024, 1024, 4};
-  const std::vector<int> &shape_b = {1, 1024, 1024, 4};
-  TestCase<float16_t>(shape_a, shape_b);
+TEST_F(TestArithmeticOpenCL, ArithmeticBroadcastSub2Fp32) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, c};
+  std::vector<int> in_shape1 = {n, h, w, c};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f};
+  std::vector<float> weight_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> output_data = {0.0f, 0.0f, 0.0f, -3.0f, -3.0f, -3.0f, -6.0f, -6.0f, -6.0f, -9.0f, -9.0f, -9.0f};
+  RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape,
+                        false, schema::PrimitiveType_Sub);
}

-TEST_F(TestArithmeticOpenCL, AddBroadcastFP16) {
-  const std::vector<int> &shape_a = {1, 128, 128, 4};
-  const std::vector<int> &shape_b = {};
-  TestCase<float16_t>(shape_a, shape_b);
+TEST_F(TestArithmeticOpenCL, ArithmeticElementwiseDivFp16) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> in_shape1 = {n, h, w, c};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float16_t> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float16_t> weight_data = {1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f};
+  std::vector<float16_t> output_data = {1.0f, 2.0f, 3.0f, 2.0f, 2.5, 3.0f, 7.0f, 8.0f, 9.0f, 5.0f, 5.5, 6.0f};
+  RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape,
+                        true, schema::PrimitiveType_Div);
}
} // namespace mindspore
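
The expected outputs in the rewritten tests follow ordinary elementwise/broadcast semantics. As a sanity check, here is a standalone sketch (a hypothetical helper, not part of the patch) that reproduces the ArithmeticBroadcastSubReLU6Fp32 expectations by subtracting the per-channel weight and clamping to [0, 6]:

  #include <algorithm>
  #include <cstdio>
  #include <vector>

  // Reference computation for ArithmeticBroadcastSubReLU6Fp32 above: the
  // {c}-shaped weight broadcasts over the channel axis of the NHWC input,
  // and ReLU6 clamps the difference to [0, 6].
  int main() {
    const int c = 3;
    std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    std::vector<float> weight = {1, 2, 3};
    for (size_t i = 0; i < in.size(); ++i) {
      float v = in[i] - weight[i % c];
      printf("%.1f ", std::min(std::max(v, 0.0f), 6.0f));  // 0 0 0 3 3 3 6 6 6 6 6 6
    }
    printf("\n");
    return 0;
  }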

mindspore/lite/test/ut/src/runtime/kernel/opencl/scale_tests.cc (+132 -214)

@@ -13,253 +13,171 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"

namespace mindspore {
-class TestScaleOpenCL : public mindspore::CommonTest {
- public:
-  TestScaleOpenCL() {}
-};
-
-template <class T>
-static void BoardcaseScale(const T *in, const T scale, const T offset, T *out, const int size) {
-  for (int i = 0; i < size; i++) {
-    out[i] = in[i] * scale + offset;
-  }
-}
-
-template <class T>
-static void Scale(const T *in, const T *scale, T *offset, T *out, const int size) {
-  for (int i = 0; i < size; i++) {
-    out[i] = in[i] * scale[i] + offset[i];
-  }
-}
-
-template <class T>
-static bool DataCompare(const T *a, const T *b, const int size, const T accuracy = 1e-4) {
-  for (int i = 0; i < size; i++) {
-    auto diff = fabs(a[i] - b[i]);
-    if (diff > accuracy) {
-      MS_LOG(ERROR) << "compare failed at " << i << " exp " << a[i] << " bug got " << b[i];
-      return false;
-    }
-  }
-  return true;
-}
-
-template <class T>
-static void InitData(void *data, const int size) {
-  T *data_float = reinterpret_cast<T *>(data);
-  static unsigned int seed = 123;
-  for (int i = 0; i < size; i++) {
-    data_float[i] = static_cast<int>(rand_r(&seed)) % 100;
-  }
-}
-
-template <class T>
-static void LogData(void *data, const int size, const std::string prefix) {
-  std::cout << prefix;
-  T *data_float = reinterpret_cast<T *>(data);
-  for (int i = 0; i < size; i++) {
-    std::cout << data_float[i] << ",";
-  }
-  std::cout << std::endl;
-}
-
-template <class T>
-static void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b) {
-  bool is_log_data = false;
+void RunTestCaseScale(void *input_data0, const std::vector<int> &input_shape, void *scale_data, void *offset_data,
+                      const std::vector<int> &weight_shape, void *output_data, const std::vector<int> &out_shape,
+                      bool enable_fp16, int axis, int act_type = schema::ActivationType_NO_ACTIVATION) {
  auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
  ocl_runtime->Init();
+  size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float);
+  ocl_runtime->SetFp16Enable(enable_fp16);
  auto allocator = ocl_runtime->GetAllocator();
-
-  bool is_broadcast = shape_b.empty();
-  auto format = schema::Format_NHWC4;
-
-  auto data_type = kNumberTypeFloat32;
-  if (sizeof(T) == 2) {
-    data_type = kNumberTypeFloat16;
-    ocl_runtime->SetFp16Enable(true);
-  }
-  lite::Tensor *tensor_in = new (std::nothrow) lite::Tensor(data_type, shape_a, format);
-  lite::Tensor *tensor_scale = new (std::nothrow) lite::Tensor(data_type, shape_b, format);
-  lite::Tensor *tensor_offset = new (std::nothrow) lite::Tensor(data_type, shape_b, format);
-  lite::Tensor *tensor_out = new (std::nothrow) lite::Tensor(data_type, shape_a, format);
-  if (tensor_in == nullptr || tensor_scale == nullptr || tensor_offset == nullptr) {
-    MS_LOG(ERROR) << "Create tensor failed!";
-    delete tensor_in;
-    delete tensor_scale;
-    delete tensor_offset;
-    delete tensor_out;
+  auto param = static_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter)));
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "param_ptr create error.";
    return;
  }
-
-  int64_t element_num = tensor_in->ElementsC4Num();
-  int64_t element_num_b = is_broadcast ? 1 : tensor_scale->ElementsC4Num();
-
-  T *data_in = new (std::nothrow) T[element_num];
-  T *data_scale = new (std::nothrow) T[element_num_b];
-  T *data_offset = new (std::nothrow) T[element_num_b];
-  T *data_out_cpu = new (std::nothrow) T[element_num];
-  T *data_out_ocl = new (std::nothrow) T[element_num];
-  if (data_in == nullptr || data_scale == nullptr || data_out_cpu == nullptr || data_out_ocl == nullptr) {
-    MS_LOG(ERROR) << "Create buffer failed!";
-    delete tensor_in;
-    delete tensor_scale;
-    delete tensor_offset;
-    delete tensor_out;
-    delete[] data_in;
-    delete[] data_scale;
-    delete[] data_offset;
-    delete[] data_out_cpu;
-    delete[] data_out_ocl;
+  param->axis_ = axis;
+  param->activation_type_ = act_type;
+  auto tensor_x_ptr =
+    std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape);
+  auto tensor_x = tensor_x_ptr.get();
+  if (tensor_x == nullptr) {
+    MS_LOG(ERROR) << "tensor_x create error.";
    return;
  }
-
-  InitData<T>(data_in, element_num);
-  InitData<T>(data_scale, element_num_b);
-  InitData<T>(data_offset, element_num_b);
-  memset(data_out_ocl, 0, sizeof(T) * element_num);
-
-  if (is_broadcast) {
-    BoardcaseScale(data_in, static_cast<T *>(data_scale)[0], static_cast<T *>(data_offset)[0], data_out_cpu,
-                   element_num);
-  } else {
-    Scale(data_in, data_scale, data_offset, data_out_cpu, element_num);
-  }
-
-  std::vector<lite::Tensor *> inputs = {tensor_in};
-  if (!is_broadcast) {
-    inputs.push_back(tensor_scale);
-    inputs.push_back(tensor_offset);
-  } else {
-    tensor_scale->MallocData();
-    tensor_offset->MallocData();
-    memcpy(tensor_scale->data_c(), data_scale, sizeof(T));
-    memcpy(tensor_offset->data_c(), data_offset, sizeof(T));
-  }
-  std::vector<lite::Tensor *> outputs = {tensor_out};
-
-  ScaleParameter *param = static_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter)));
-  if (param == nullptr) {
-    MS_LOG(ERROR) << "Create parameter failed!";
-    delete tensor_in;
-    delete tensor_scale;
-    delete tensor_offset;
-    delete tensor_out;
-    delete[] data_in;
-    delete[] data_scale;
-    delete[] data_offset;
-    delete[] data_out_cpu;
-    delete[] data_out_ocl;
+  auto tensor_scale_ptr =
+    std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape,
+                                   schema::Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
+  auto tensor_scale = tensor_scale_ptr.get();
+  if (tensor_scale == nullptr) {
+    MS_LOG(ERROR) << "tensor_scale create error.";
    return;
  }
-  param->axis_ = 0;
-  param->op_parameter_.type_ = schema::PrimitiveType_Scale;
-
-  std::vector<lite::Tensor *> scale_inputs = {tensor_in, tensor_scale, tensor_offset};
-  lite::InnerContext ctx;
-  ASSERT_EQ(lite::RET_OK, ctx.Init());
-  auto *scale_kernel =
-    new (std::nothrow) kernel::ScaleOpenCLKernel(reinterpret_cast<OpParameter *>(param), scale_inputs, outputs, &ctx);
-  if (scale_kernel == nullptr) {
-    MS_LOG(ERROR) << "Create ScaleOpenCLKernel failed!";
-    delete tensor_in;
-    delete tensor_scale;
-    delete tensor_offset;
-    delete tensor_out;
-    delete[] data_in;
-    delete[] data_scale;
-    delete[] data_offset;
-    delete[] data_out_cpu;
-    delete[] data_out_ocl;
-    free(param);
+  tensor_scale->set_data(scale_data);
+  auto tensor_offset_ptr =
+    std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape,
+                                   schema::Format_NHWC, lite::Tensor::Category::CONST_TENSOR);
+  auto tensor_offset = tensor_offset_ptr.get();
+  if (tensor_offset == nullptr) {
+    MS_LOG(ERROR) << "tensor_offset create error.";
    return;
  }
-  scale_kernel->Init();
-
-  tensor_in->MallocData(allocator);
-  tensor_scale->MallocData(allocator);
-  tensor_offset->MallocData(allocator);
-  std::vector<kernel::LiteKernel *> kernels{scale_kernel};
-  auto *kernel = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
-  if (scale_kernel == nullptr) {
-    MS_LOG(ERROR) << "Create SubGraphOpenCLKernel failed!";
-    delete tensor_in;
-    delete tensor_scale;
-    delete tensor_offset;
-    delete tensor_out;
-    delete[] data_in;
-    delete[] data_scale;
-    delete[] data_offset;
-    delete[] data_out_cpu;
-    delete[] data_out_ocl;
-    delete scale_kernel;
+  tensor_offset->set_data(offset_data);
+  auto tensor_out_ptr =
+    std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape);
+  auto tensor_out = tensor_out_ptr.get();
+  if (tensor_out == nullptr) {
+    MS_LOG(ERROR) << "tensor_out create error.";
    return;
  }
-  kernel->Init();
-
-  memcpy(inputs[0]->data_c(), data_in, sizeof(T) * element_num);
-  if (!is_broadcast) {
-    memcpy(inputs[1]->data_c(), data_scale, sizeof(T) * element_num_b);
-    memcpy(inputs[2]->data_c(), data_offset, sizeof(T) * element_num_b);
+  std::vector<lite::Tensor *> inputs{tensor_x, tensor_scale, tensor_offset};
+  std::vector<lite::Tensor *> outputs{tensor_out};
+  auto op_kernel_ptr =
+    std::make_unique<kernel::ScaleOpenCLKernel>(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  auto op_kernel = op_kernel_ptr.release();
+  if (op_kernel == nullptr) {
+    MS_LOG(ERROR) << "op_kernel create error.";
+    return;
  }
+  op_kernel->Init();
+  inputs[0]->MallocData(allocator);

-  kernel->Run();
-
-  memcpy(data_out_ocl, outputs[0]->data_c(), sizeof(T) * element_num);
+  std::vector<kernel::LiteKernel *> kernels{op_kernel};

-  if (is_log_data) {
-    LogData<T>(data_in, 10, "Data input : ");
-    LogData<T>(data_scale, tensor_scale->shape().empty() ? 1 : 10, "Data scale : ");
-    LogData<T>(data_offset, tensor_offset->shape().empty() ? 1 : 10, "Data offset : ");
-    LogData<T>(data_out_cpu, 10, "Expect compute : ");
-    LogData<T>(outputs[0]->data_c(), 10, "OpenCL compute : ");
+  std::vector<lite::Tensor *> inputs_g{tensor_x};
+  auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels);
+  auto pGraph = pGraph_ptr.get();
+  if (pGraph == nullptr) {
+    MS_LOG(ERROR) << "pGraph create error.";
+    return;
  }
+  pGraph->Init();
+  memcpy(inputs[0]->MutableData(), input_data0, tensor_x->ElementsNum() * dtype_size);
+  pGraph->Run();
+  if (enable_fp16) {
+    CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float16_t>(1e-3),
+                  2e-2);
+  } else {
+    CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float>(1e-5));
+  }
-  bool cmp = DataCompare(data_out_cpu, data_out_ocl, element_num);
-  MS_LOG(INFO) << "Compare " << (cmp ? "success!" : "failed!");
-  EXPECT_EQ(true, cmp);
-
-  // free
-  delete[] data_in;
-  delete[] data_scale;
-  delete[] data_offset;
-  delete[] data_out_cpu;
-  delete[] data_out_ocl;
-
-  delete kernel;
-  for (auto tensor : inputs) {
-    delete tensor;
+  for (auto t : inputs) {
+    t->set_data(nullptr);
  }
-  for (auto tensor : outputs) {
-    delete tensor;
+  for (auto t : outputs) {
+    t->set_data(nullptr);
  }
+  MS_LOG(INFO) << "TestScale passed";
}

+class TestScaleOpenCL : public mindspore::CommonTest {
+ public:
+  TestScaleOpenCL() {}
+};

-TEST_F(TestScaleOpenCL, ElementFP32) {
-  const std::vector<int> &shape_a = {1, 1024, 1024, 4};
-  const std::vector<int> &shape_b = {1, 1024, 1024, 4};
-  TestCase<float>(shape_a, shape_b);
+TEST_F(TestScaleOpenCL, ScaleAxis3Fp32) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> weight_shape = {c};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> scale_data = {1.0f, 2.0f, 3.0f};
+  std::vector<float> offset_data = {1.0f, 2.0f, 3.0f};
+  std::vector<float> output_data = {2.0f, 6.0f, 12.0f, 5.0f, 12.0f, 21.0f, 8.0f, 18.0f, 30.0f, 11.0f, 24.0f, 39.0f};
+  RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape,
+                   output_data.data(), out_shape, false, 3);
}

-TEST_F(TestScaleOpenCL, BroadcastFP32) {
-  const std::vector<int> &shape_a = {1, 128, 128, 4};
-  const std::vector<int> &shape_b = {};
-  TestCase<float>(shape_a, shape_b);
+TEST_F(TestScaleOpenCL, ScaleAxis1Fp32) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> weight_shape = {h};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> scale_data = {1.0f, 2.0f};
+  std::vector<float> offset_data = {1.0f, 2.0f};
+  std::vector<float> output_data = {2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f};
+  RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape,
+                   output_data.data(), out_shape, false, 1);
}

-TEST_F(TestScaleOpenCL, ElementFP16) {
-  const std::vector<int> &shape_a = {1, 1024, 1024, 4};
-  const std::vector<int> &shape_b = {1, 1024, 1024, 4};
-  TestCase<float16_t>(shape_a, shape_b);
+TEST_F(TestScaleOpenCL, ScaleAxis3ReLU6Fp32) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> weight_shape = {c};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float> scale_data = {1.0f, 2.0f, -1.0f};
+  std::vector<float> offset_data = {1.0f, 2.0f, 3.0f};
+  std::vector<float> output_data = {2.0f, 6.0f, 0.0f, 5.0f, 6.0f, 0.0f, 6.0f, 6.0f, 0.0f, 6.0f, 6.0f, 0.0f};
+  RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape,
+                   output_data.data(), out_shape, false, 3, schema::ActivationType_RELU6);
}

-TEST_F(TestScaleOpenCL, BroadcastFP16) {
-  const std::vector<int> &shape_a = {1, 128, 128, 4};
-  const std::vector<int> &shape_b = {};
-  TestCase<float16_t>(shape_a, shape_b);
+TEST_F(TestScaleOpenCL, ScaleAxis3Fp16) {
+  int n = 1;
+  int h = 2;
+  int w = 2;
+  int c = 3;
+  std::vector<int> in_shape0 = {n, h, w, c};
+  std::vector<int> weight_shape = {c};
+  std::vector<int> out_shape = {n, h, w, c};
+  std::vector<float16_t> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+  std::vector<float16_t> scale_data = {1.0f, 2.0f, 3.0f};
+  std::vector<float16_t> offset_data = {1.0f, 2.0f, 3.0f};
+  std::vector<float16_t> output_data = {2.0f, 6.0f, 12.0f, 5.0f, 12.0f, 21.0f, 8.0f, 18.0f, 30.0f, 11.0f, 24.0f, 39.0f};
+  RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape,
+                   output_data.data(), out_shape, true, 3);
}
} // namespace mindspore
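
The scale tests encode y = x * scale + offset, with scale and offset broadcast along the chosen axis. Note also that the rewritten tests attach the test-owned weight buffers via set_data() and detach them with set_data(nullptr) before the tensors are destroyed, so the tensors never free memory they do not own. A standalone sketch (hypothetical, not part of the patch) reproducing the ScaleAxis1Fp32 expectations:

  #include <cstdio>
  #include <vector>

  // Reference computation for ScaleAxis1Fp32 above: scale/offset have shape
  // {h} and broadcast along axis 1 (H) of the NHWC input.
  int main() {
    const int n = 1, h = 2, w = 2, c = 3;
    std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    std::vector<float> scale = {1, 2};
    std::vector<float> offset = {1, 2};
    for (int i = 0; i < n * h * w * c; ++i) {
      int hi = (i / (w * c)) % h;  // H index recovered from the flat NHWC offset
      printf("%.1f ", in[i] * scale[hi] + offset[hi]);  // 2 3 4 5 6 7 16 18 20 22 24 26
    }
    printf("\n");
    return 0;
  }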
