From 4499c74309fc748eb61d5a9230c14ebde7eeffa6 Mon Sep 17 00:00:00 2001 From: Pengyongrong Date: Sun, 8 Nov 2020 18:26:38 -0800 Subject: [PATCH] add new ops named sparse_to_dense,fill,shape for GPU --- .../kernel/opencl/cl/sparse_to_dense.cl | 131 +++++ .../src/runtime/kernel/opencl/kernel/fill.cc | 115 +++++ .../src/runtime/kernel/opencl/kernel/fill.h | 49 ++ .../kernel/opencl/kernel/sparse_to_dense.cc | 203 ++++++++ .../kernel/opencl/kernel/sparse_to_dense.h | 58 +++ .../lite/src/runtime/opencl/opencl_wrapper.cc | 10 + .../lite/src/runtime/opencl/opencl_wrapper.h | 3 + .../src/runtime/kernel/opencl/fill_tests.cc | 145 ++++++ .../kernel/opencl/sparse_to_dense_tests.cc | 459 ++++++++++++++++++ 9 files changed, 1173 insertions(+) create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/sparse_to_dense.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/sparse_to_dense.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/sparse_to_dense.cl new file mode 100644 index 0000000000..c46029561e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/sparse_to_dense.cl @@ -0,0 +1,131 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +#define C4NUM 4 +__kernel void SparseToDenseScalarDim0(__read_only image2d_t input, __write_only image2d_t output, float weight, + int2 input_shape, float default_value) { + FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, 0)); + 
FLT4 result = {default_value, default_value, default_value, default_value}; + int integer = index_input.x / C4NUM; + int decimal = (int)(index_input.x) % C4NUM; + if (decimal == 0) { + result.x = weight; + } else if (decimal == 1) { + result.y = weight; + } else if (decimal == 2) { + result.z = weight; + } else { + result.w = weight; + } + WRITE_IMAGE(output, (int2)(0, integer), result); + return; +} + +__kernel void SparseToDenseScalarDim1(__read_only image2d_t input, __write_only image2d_t output, float weight, + int2 input_shape, float default_value) { + for (int i = 0; i < input_shape.x; ++i) { + FLT4 result = READ_IMAGE(input, smp_zero, (int2)(0, i)); + int Y = result.x; + result.x = weight; + WRITE_IMAGE(output, (int2)(0, Y), result); + } +} + +__kernel void SparseToDenseVectorDim1(__read_only image2d_t input, __write_only image2d_t output, + __global float *weight, int2 input_shape, float default_value) { + int index_weight = 0; + for (int i = 0; i < input_shape.x; ++i) { + FLT4 result = READ_IMAGE(input, smp_zero, (int2)(0, i)); + int Y = result.x; + result.x = weight[index_weight++]; + WRITE_IMAGE(output, (int2)(0, Y), result); + } +} + +__kernel void SparseToDenseScalarDim2Shape2(__read_only image2d_t input, __write_only image2d_t output, float weight, + int2 input_shape, float default_value) { + FLT temp[8] = {default_value, default_value, default_value, default_value, + default_value, default_value, default_value, default_value}; + FLT result_temp[8] = {default_value, default_value, default_value, default_value, + default_value, default_value, default_value, default_value}; + int index = 0; // 0~4 + int X = 0; + FLT4 index_begin = READ_IMAGE(input, smp_zero, (int2)(0, 0)); + int Y = (int)index_begin.x; // N + temp[index] = index_begin.y; // c/4 + for (int i = 1; i < input_shape.x && index < C4NUM; ++i) { + FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, i)); + if ((((int)temp[index]) / C4NUM == ((int)index_input.y) / C4NUM) && (Y == 
(int)index_input.x)) { + index++; + if (index < C4NUM) { + temp[index] = index_input.y; + } + } else { + for (int j = 0; j <= index && index < C4NUM; ++j) { + int decimal = (int)temp[j] % C4NUM; + result_temp[decimal] = weight; + X = ((int)temp[0]) / C4NUM; + } + FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; + WRITE_IMAGE(output, (int2)(X, Y), result); + index = 0; + Y = (int)index_input.x; + temp[0] = index_input.y; + temp[1] = temp[2] = temp[3] = default_value; + result_temp[0] = result_temp[1] = result_temp[2] = result_temp[3] = default_value; + } + } + + // judge the last element for input + X = ((int)temp[0]) / C4NUM; + for (int i = 0; i <= index && index < C4NUM; ++i) { + int decimal = (int)temp[i] % C4NUM; + result_temp[decimal] = weight; + } + FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; + WRITE_IMAGE(output, (int2)(X, Y), result); +} + +__kernel void SparseToDenseVectorDim2Shape2(__read_only image2d_t input, __write_only image2d_t output, + __global float *weight, int2 input_shape, float default_value) { + FLT temp[8] = {default_value, default_value, default_value, default_value, + default_value, default_value, default_value, default_value}; + FLT result_temp[8] = {default_value, default_value, default_value, default_value, + default_value, default_value, default_value, default_value}; + int index = 0; // 0~4 + int weight_index = 0; + int X = 0; + FLT4 index_begin = READ_IMAGE(input, smp_zero, (int2)(0, 0)); + int Y = (int)index_begin.x; // N + temp[index] = index_begin.y; // c/4 + for (int i = 1; i < input_shape.x && index < C4NUM; ++i) { + FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, i)); + if ((((int)temp[index]) / C4NUM == ((int)index_input.y) / C4NUM) && (Y == (int)index_input.x)) { + index++; + if (index < C4NUM) { + temp[index] = index_input.y; + } + } else { + for (int j = 0; j <= index && index < C4NUM; ++j) { + int decimal = (int)temp[j] % C4NUM; + 
result_temp[decimal] = weight[weight_index++]; + X = ((int)temp[0]) / C4NUM; + } + FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; + WRITE_IMAGE(output, (int2)(X, Y), result); + index = 0; + Y = (int)index_input.x; + temp[0] = index_input.y; + temp[1] = temp[2] = temp[3] = default_value; + result_temp[0] = result_temp[1] = result_temp[2] = result_temp[3] = default_value; + } + } + + // judge the last element for input + X = ((int)temp[0]) / C4NUM; + for (int i = 0; i <= index && index < C4NUM; ++i) { + int decimal = (int)temp[i] % C4NUM; + result_temp[decimal] = weight[weight_index++]; + } + FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; + WRITE_IMAGE(output, (int2)(X, Y), result); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc new file mode 100644 index 0000000000..4125007a5d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/runtime/kernel/opencl/kernel/fill.h" +#include +#include +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/kernel/opencl/utils.h" + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Fill; +using mindspore::schema::PrimitiveType_Shape; + +namespace mindspore::kernel { + +int FillOpenCLKernel::RunFill() { + auto allocator_ = ocl_runtime_->GetAllocator(); + auto param = reinterpret_cast(this->op_parameter_); + default_ = param->num_dims_; + std::vector img_size; + cl_float4 fill_value = {}; + fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_; + auto src_data = out_tensors_[0]->data_c(); + allocator_->GetImageSize(src_data, &img_size); + auto src_origin = cl::array{0, 0, 0}; + auto region = cl::array{img_size[0], img_size[1], 1}; + cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(src_data)); + ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region); + return RET_OK; +} + +int FillOpenCLKernel::RunShape() { + auto allocator_ = ocl_runtime_->GetAllocator(); + auto src_data = out_tensors_[0]->data_c(); + cl_float4 fill_value = {default_, default_, default_, default_}; + for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) { + fill_value.s[0] = in_tensors_[0]->shape()[i]; + size_t index = static_cast(i); + auto src_origin = cl::array{0, index, 0}; + auto region = cl::array{1, 1, 1}; + cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(src_data)); + ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region); + } + return RET_OK; +} + +int FillOpenCLKernel::Init() { + auto param = this->op_parameter_; + + if (out_tensors_[0]->shape().size() > 4) { + MS_LOG(ERROR) << " only support dim <= 4"; + return RET_ERROR; + } + if 
(in_tensors_[0]->shape().size() > 1 && param->type_ == PrimitiveType_Fill) { + MS_LOG(ERROR) << " fill only support dim = 1"; + return RET_ERROR; + } + return RET_OK; +} + +int FillOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->name() << " Running! "; + auto param = this->op_parameter_; + if (param->type_ == PrimitiveType_Fill) { + RunFill(); + } else { + RunShape(); + } + + return RET_OK; +} + +kernel::LiteKernel *FillOpenCLKernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) FillOpenCLKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << " new FillOpenCLKernel failed "; + free(opParameter); + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << " Init kernel failed, name: fill "; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Fill, FillOpenCLKernelCreator); +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Shape, FillOpenCLKernelCreator); +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Fill, FillOpenCLKernelCreator); +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Shape, FillOpenCLKernelCreator); + +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h new file mode 100644 index 0000000000..5d6da3120d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_ + +#include +#include "mindspore/lite/nnacl/fp32/fill.h" +#include "mindspore/lite/nnacl/shape.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" + +namespace mindspore::kernel { + +class FillOpenCLKernel : public OpenCLKernel { + public: + FillOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) {} + + ~FillOpenCLKernel() override = default; + + int Init() override; + + int Run() override; + + private: + int RunFill(); + int RunShape(); + cl::Kernel kernel_; + + private: + float default_{0.0f}; +}; + +} // namespace mindspore::kernel +#endif diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc new file mode 100644 index 0000000000..a1ed209cfc --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc @@ -0,0 +1,203 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/opencl/kernel/sparse_to_dense.h" +#include +#include +#include +#include +#include "src/kernel_registry.h" +#include "src/runtime/kernel/opencl/utils.h" +#include "src/runtime/kernel/opencl/cl/sparse_to_dense.cl.inc" + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_SparseToDense; + +namespace mindspore::kernel { + +int SparseToDenseOpenCLKernel::InitOutputToDefault() { + auto allocator_ = ocl_runtime_->GetAllocator(); + std::vector img_size; + cl_float4 fill_value = {}; + fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_; + auto src_data = out_tensors_[0]->data_c(); + allocator_->GetImageSize(src_data, &img_size); + auto src_origin = cl::array{0, 0, 0}; + auto region = cl::array{img_size[0], img_size[1], 1}; + cl::Image2D *out_image = reinterpret_cast(allocator_->GetImage(src_data)); + ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region); + return RET_OK; +} + +int SparseToDenseOpenCLKernel::InitWeights() { + auto allocator = ocl_runtime_->GetAllocator(); + auto weight_tensor = in_tensors_[2]; + size_t size = 1; + for (int i = 0; i < weight_tensor->shape().size(); ++i) { + size *= weight_tensor->shape()[i]; + } + if (weight_scalar_) { + if (weight_tensor->data_type() == kNumberTypeFloat16) { + weight_scalar_ = static_cast(*reinterpret_cast(weight_tensor->data_c())); + } else { + 
weight_scalar_ = *reinterpret_cast(weight_tensor->data_c()); + } + } else { + auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float); + size_t weight_size = UP_ROUND(size, C4NUM) * sizeof_FLT; + weight_vector_ = allocator->Malloc(weight_size); + allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true); + memset(weight_vector_, 0x00, weight_size); + if (weight_tensor->data_type() == kNumberTypeFloat16) { + if (enable_fp16_) { + memcpy(weight_vector_, weight_tensor->data_c(), size * sizeof_FLT); + } else { + auto weight_fp32 = reinterpret_cast(weight_vector_); + auto origin_bias_fp16 = reinterpret_cast(weight_tensor->data_c()); + for (int i = 0; i < size; ++i) { + weight_fp32[i] = static_cast(origin_bias_fp16[i]); + } + } + } else { + if (enable_fp16_) { + auto weight_fp16 = reinterpret_cast(weight_vector_); + auto origin_bias_fp32 = reinterpret_cast(weight_tensor->data_c()); + for (int i = 0; i < size; ++i) { + weight_fp16[i] = static_cast(origin_bias_fp32[i]); + } + } else { + memcpy(weight_vector_, weight_tensor->data_c(), size * sizeof_FLT); + } + } + allocator->UnmapBuffer(weight_vector_); + } + return RET_OK; +} + +int SparseToDenseOpenCLKernel::Init() { + if (out_tensors_[0]->shape().size() > 2 || in_tensors_.size() < 3) { + MS_LOG(ERROR) << " only support dim <= 2 and in_tensors_.size >= 3"; + return RET_ERROR; + } + if ((in_tensors_[0]->shape()[1] > 3) && (input_dim_ == 2)) { + MS_LOG(ERROR) << "in_tensors_indices shape[1] must be 1 2 or 3 && input_dim_=2 ,but your shapes is: " + << in_tensors_[0]->shape()[1] << "your input_dim_ is: " << input_dim_; + return ERROR; + } + input_dim_ = in_tensors_[0]->shape().size(); + weight_scalar_ = in_tensors_[2]->IsScalar(); + std::string kernel_name = "SparseToDense" + std::string(weight_scalar_ ? "ScalarDim" : "VectorDim") + + std::to_string(in_tensors_[0]->shape()[1] == 1 ? 
1 : input_dim_); + if (input_dim_ == 2 && in_tensors_[0]->shape()[1] != 1) { + kernel_name += "Shape" + std::to_string(in_tensors_[0]->shape()[1]); + } + + std::set build_options; + std::string source = sparse_to_dense_source; + std::string program_name = "SparseToDense"; + ocl_runtime_->LoadSource(program_name, source); + ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); + + if (in_tensors_.size() > 3) { + auto input_tensor3 = in_tensors_[3]; + if (input_tensor3->data_type() == kNumberTypeFloat16) { + default_ = static_cast(*reinterpret_cast(input_tensor3->data_c())); + } else { + default_ = *reinterpret_cast(input_tensor3->data_c()); + } + } + + InitWeights(); + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return RET_OK; +} + +int SparseToDenseOpenCLKernel::InferShapeTo4D() { + if (in_tensors_[0]->shape().size() <= 4) { + if (in_tensors_[0]->shape().size() == 1) { + N_ = in_tensors_[0]->shape()[0]; + } else if (in_tensors_[0]->shape().size() == 2) { + N_ = in_tensors_[0]->shape()[0]; + C_ = in_tensors_[0]->shape()[1]; + } else if (in_tensors_[0]->shape().size() == 3) { + N_ = in_tensors_[0]->shape()[0]; + W_ = in_tensors_[0]->shape()[1]; + C_ = in_tensors_[0]->shape()[2]; + } else { + N_ = in_tensors_[0]->shape()[0]; + H_ = in_tensors_[0]->shape()[1]; + W_ = in_tensors_[0]->shape()[2]; + C_ = in_tensors_[0]->shape()[3]; + } + } else { + MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size(); + return RET_ERROR; + } + return RET_OK; +} + +int SparseToDenseOpenCLKernel::Run() { + MS_LOG(DEBUG) << this->name() << " Running! 
"; + InferShapeTo4D(); + cl_int2 input_shape = {static_cast(N_ * H_), static_cast(W_ * UP_DIV(C_, C4NUM))}; + InitOutputToDefault(); + std::vector local = {1, 1}; + std::vector global = {1, 1}; + int arg_cn = 0; + ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c()); + ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c()); + if (weight_scalar_) { + ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_scalar_); + } else { + ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_vector_); + } + ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape); + ocl_runtime_->SetKernelArg(kernel_, arg_cn++, default_); + ocl_runtime_->RunKernel(kernel_, global, local, nullptr); + return RET_OK; +} + +kernel::LiteKernel *SparseToDenseOpenCLKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::InnerContext *ctx, + const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + if (inputs.empty()) { + MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); + free(opParameter); + return nullptr; + } + auto *kernel = new (std::nothrow) SparseToDenseOpenCLKernel(opParameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << " new HswishOpenCLKernel failed "; + free(opParameter); + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << " Init kernel failed, name: hswish "; + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, SparseToDenseOpenCLKernelCreator); +REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SparseToDense, SparseToDenseOpenCLKernelCreator); +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h new file mode 100644 index 0000000000..542bc417bf --- /dev/null +++ 
b/mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPARSE_TO_DENSE_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPARSE_TO_DENSE_H_ + +#include +#include "src/runtime/kernel/opencl/opencl_kernel.h" +#include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" + +namespace mindspore::kernel { + +class SparseToDenseOpenCLKernel : public OpenCLKernel { + public: + SparseToDenseOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) {} + + ~SparseToDenseOpenCLKernel() override = default; + + int Init() override; + int Run() override; + int InitWeights() override; + + private: + int InferShapeTo4D(); + int InitOutputToDefault(); + + private: + cl::Kernel kernel_; + // bool IndicesIsScalar{false}; + bool enable_fp16_{false}; + float default_{0.0f}; + float weight_scalar_{0.f}; + void *weight_vector_{nullptr}; + int input_dim_{1}; + std::vector output_shape_; + + size_t N_{1}; + size_t H_{1}; + size_t W_{1}; + size_t C_{1}; +}; +} // namespace mindspore::kernel +#endif diff --git a/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc b/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc index 7d3594c01c..7cbac3205b 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc 
+++ b/mindspore/lite/src/runtime/opencl/opencl_wrapper.cc @@ -142,6 +142,7 @@ bool LoadLibraryFromPath(const std::string &library_path, void *handle) { LOAD_OPENCL_FUNCTION_PTR(clRetainDevice); LOAD_OPENCL_FUNCTION_PTR(clReleaseDevice); LOAD_OPENCL_FUNCTION_PTR(clCreateImage); + LOAD_OPENCL_FUNCTION_PTR(clEnqueueFillImage); #endif #if CL_HPP_TARGET_OPENCL_VERSION >= 200 LOAD_OPENCL_FUNCTION_PTR(clCreateCommandQueueWithProperties); @@ -228,6 +229,7 @@ CL_DEFINE_FUNC_PTR(clEnqueueCopyImageToBuffer); CL_DEFINE_FUNC_PTR(clRetainDevice); CL_DEFINE_FUNC_PTR(clReleaseDevice); CL_DEFINE_FUNC_PTR(clCreateImage); +CL_DEFINE_FUNC_PTR(clEnqueueFillImage); #endif #if CL_HPP_TARGET_OPENCL_VERSION >= 200 CL_DEFINE_FUNC_PTR(clGetKernelSubGroupInfoKHR); @@ -666,6 +668,14 @@ cl_mem clCreateImage(cl_context context, cl_mem_flags flags, const cl_image_form return func(context, flags, image_format, image_desc, host_ptr, errcode_ret); } +cl_int clEnqueueFillImage(cl_command_queue command_queue, cl_mem image, const void *fill_color, const size_t *origin, + const size_t *region, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, + cl_event *event) { + auto func = mindspore::lite::opencl::clEnqueueFillImage; + MS_ASSERT(func != nullptr); + return func(command_queue, image, fill_color, origin, region, num_events_in_wait_list, event_wait_list, event); +} + #endif #if CL_HPP_TARGET_OPENCL_VERSION >= 200 diff --git a/mindspore/lite/src/runtime/opencl/opencl_wrapper.h b/mindspore/lite/src/runtime/opencl/opencl_wrapper.h index 76b7b11586..2f9a3cb5a7 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_wrapper.h +++ b/mindspore/lite/src/runtime/opencl/opencl_wrapper.h @@ -127,6 +127,8 @@ using clRetainDeviceFunc = cl_int (*)(cl_device_id); using clReleaseDeviceFunc = cl_int (*)(cl_device_id); using clCreateImageFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, const cl_image_desc *, void *, cl_int *); +using clEnqueueFillImageFunc = cl_int (*)(cl_command_queue, 
cl_mem, const void *, const size_t *, const size_t *, + cl_uint, const cl_event *, cl_event *); #endif #if CL_HPP_TARGET_OPENCL_VERSION >= 200 using clCreateProgramWithILFunc = cl_program (*)(cl_context, const void *, size_t, cl_int *); @@ -199,6 +201,7 @@ CL_DECLARE_FUNC_PTR(clEnqueueCopyImageToBuffer); CL_DECLARE_FUNC_PTR(clRetainDevice); CL_DECLARE_FUNC_PTR(clReleaseDevice); CL_DECLARE_FUNC_PTR(clCreateImage); +CL_DECLARE_FUNC_PTR(clEnqueueFillImage); #endif #if CL_HPP_TARGET_OPENCL_VERSION >= 200 CL_DECLARE_FUNC_PTR(clGetKernelSubGroupInfoKHR); diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc new file mode 100644 index 0000000000..c93603e90f --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/fill_tests.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include "src/common/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h" +using mindspore::lite::Tensor; +using mindspore::schema::PrimitiveType_Fill; +using mindspore::schema::PrimitiveType_Shape; +using mindspore::schema::Format::Format_NHWC; +namespace mindspore { +class TestFillOpenCLCI : public mindspore::CommonTest { + public: + TestFillOpenCLCI() {} +}; + +TEST_F(TestFillOpenCLCI, Fp32testfill) { + MS_LOG(INFO) << " begin test "; + auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); + auto runtime = runtime_wrapper.GetInstance(); + runtime->Init(); + auto allocator = runtime->GetAllocator(); + + MS_LOG(INFO) << " init tensors "; + std::vector input_shape1 = {2}; + float input_data1[] = {3, 3}; + float correctOutput[] = {9, 9, 9, 9, 9, 9, 9, 9, 9}; + auto data_type = kNumberTypeFloat32; + std::vector output_shape = {3, 3}; + auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); + auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); + std::vector inputs{&in_tensor1}; + std::vector outputs{&output_tensor}; + + MS_LOG(INFO) << " initialize tensors "; + auto param = reinterpret_cast(malloc(sizeof(FillParameter))); + param->num_dims_ = 9; + param->op_parameter_.type_ = PrimitiveType_Fill; + if (param == nullptr) { + MS_LOG(INFO) << " new FillParameter failed "; + return; + } + + auto *fill_kernel = + new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (fill_kernel == nullptr) { + MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; + delete param; + return; + } + fill_kernel->Init(); + MS_LOG(INFO) << " initialize sub_graph "; + std::vector kernels{fill_kernel}; + auto *sub_graph = new (std::nothrow) 
kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; + delete param; + delete fill_kernel; + return; + } + // to allocate memory for inputs + in_tensor1.MallocData(allocator); + sub_graph->Init(); + MS_LOG(INFO) << " initialize input data "; + memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + auto *output_data_gpu = reinterpret_cast(output_tensor.data_c()); + CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); + delete sub_graph; +} + +TEST_F(TestFillOpenCLCI, Fp32testshape) { + MS_LOG(INFO) << " begin test "; + auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); + auto runtime = runtime_wrapper.GetInstance(); + runtime->Init(); + auto allocator = runtime->GetAllocator(); + + MS_LOG(INFO) << " init tensors "; + std::vector input_shape1 = {2, 4}; + float input_data1[] = {-0.4045, -0.0924, -0.617, -0.10114, -0.9893, 0.3342, 2.445, -2.182}; + float correctOutput[] = {2, 4}; + auto data_type = kNumberTypeFloat32; + std::vector output_shape = {2}; + auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); + auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); + std::vector inputs{&in_tensor1}; + std::vector outputs{&output_tensor}; + + MS_LOG(INFO) << " initialize tensors "; + auto param = reinterpret_cast(malloc(sizeof(ShapeParameter))); + param->op_parameter_.type_ = PrimitiveType_Shape; + if (param == nullptr) { + MS_LOG(INFO) << " new FillParameter failed "; + return; + } + + auto *fill_kernel = + new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast(param), inputs, outputs); + if (fill_kernel == nullptr) { + MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; + delete param; + return; + } + fill_kernel->Init(); + 
MS_LOG(INFO) << " initialize sub_graph "; + std::vector kernels{fill_kernel}; + auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); + if (sub_graph == nullptr) { + MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; + delete param; + delete fill_kernel; + return; + } + // to allocate memory for inputs + in_tensor1.MallocData(allocator); + sub_graph->Init(); + MS_LOG(INFO) << " initialize input data "; + memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + auto *output_data_gpu = reinterpret_cast(output_tensor.data_c()); + CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); + delete sub_graph; +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc new file mode 100644 index 0000000000..87c9d2d88a --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/sparse_to_dense_tests.cc @@ -0,0 +1,459 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <cstdlib>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <vector>
+#include "src/common/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h"
+using mindspore::lite::Tensor;
+using mindspore::schema::Format::Format_NHWC;
+namespace mindspore {
+// Fixture for the SparseToDense OpenCL kernel tests. Every test case follows the same steps (build four
+// input tensors and one output tensor, wrap the kernel in a sub-graph, run it, compare the result), so
+// those steps live in RunCase() and each TEST_F only supplies its data.
+class TestSparseToDenseOpenCLCI : public mindspore::CommonTest {
+ public:
+  TestSparseToDenseOpenCLCI() {}
+
+ protected:
+  // Type of the lite::Tensor category enumerators (VAR, CONST_TENSOR, CONST_SCALAR) used by these tests.
+  using TensorCategory = decltype(lite::Tensor::VAR);
+
+  // Shape, fp32 contents and category of one tensor that takes part in a test case.
+  struct TensorSpec {
+    std::vector<int> shape;
+    std::vector<float> data;
+    TensorCategory category;
+  };
+
+  // Runs SparseToDenseOpenCLKernel once inside a SubGraphOpenCLKernel and compares the output tensor with
+  // expected.data (tolerance 0.0001).
+  //
+  // The four specs become kernel inputs 0..3 in that order. Judging from the test data they look like
+  // (indices, dense shape, sparse values, default value) - NOTE(review): confirm against the kernel.
+  // Only input 0 is handed to the sub-graph as a graph input and allocated through the OpenCL allocator;
+  // inputs 1..3 are allocated with the plain MallocData() and filled before the kernel is created.
+  // `expected` is taken by value because CompareOutputData is given a mutable float pointer.
+  //
+  // If an allocation fails the helper logs and returns without comparing anything, as the original tests did.
+  void RunCase(const TensorSpec &indices, const TensorSpec &dense_shape, const TensorSpec &values,
+               const TensorSpec &default_value, TensorSpec expected) {
+    MS_LOG(INFO) << " begin test ";
+    auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
+    auto runtime = runtime_wrapper.GetInstance();
+    runtime->Init();
+    auto allocator = runtime->GetAllocator();
+
+    MS_LOG(INFO) << " init tensors ";
+    auto data_type = kNumberTypeFloat32;
+    Tensor in_tensor1(data_type, indices.shape, Format_NHWC, indices.category);
+    Tensor in_tensor2(data_type, dense_shape.shape, Format_NHWC, dense_shape.category);
+    Tensor in_tensor3(data_type, values.shape, Format_NHWC, values.category);
+    Tensor in_tensor4(data_type, default_value.shape, Format_NHWC, default_value.category);
+    Tensor output_tensor(data_type, expected.shape, Format_NHWC, expected.category);
+    // allocate memory for weights
+    in_tensor2.MallocData();
+    in_tensor3.MallocData();
+    in_tensor4.MallocData();
+    std::vector<lite::Tensor *> inputs{&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4};
+    std::vector<lite::Tensor *> outputs{&output_tensor};
+    // initialize weights
+    memcpy(inputs[1]->data_c(), dense_shape.data.data(), dense_shape.data.size() * sizeof(float));
+    memcpy(inputs[2]->data_c(), values.data.data(), values.data.size() * sizeof(float));
+    memcpy(inputs[3]->data_c(), default_value.data.data(), default_value.data.size() * sizeof(float));
+    MS_LOG(INFO) << " initialize tensors ";
+    auto *param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter)));
+    if (param == nullptr) {
+      MS_LOG(INFO) << " malloc SparseToDenseParameter failed ";
+      return;
+    }
+
+    auto *sparse_to_dense_kernel =
+      new (std::nothrow) kernel::SparseToDenseOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+    if (sparse_to_dense_kernel == nullptr) {
+      MS_LOG(INFO) << " new kernel::SparseToDenseOpenCLKernel failed ";
+      free(param);  // param comes from malloc, so it has to be released with free, never delete
+      return;
+    }
+    sparse_to_dense_kernel->Init();
+    MS_LOG(INFO) << " initialize sub_graph ";
+    std::vector<kernel::LiteKernel *> kernels{sparse_to_dense_kernel};
+    auto *sub_graph =
+      new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels);
+    if (sub_graph == nullptr) {
+      MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed ";
+      // NOTE(review): the success path only deletes sub_graph, which suggests the kernel may own param.
+      // If its destructor releases param, freeing it here as well is a double free - confirm and drop one.
+      free(param);
+      delete sparse_to_dense_kernel;
+      return;
+    }
+    // allocate memory for the graph input through the OpenCL runtime's allocator
+    in_tensor1.MallocData(allocator);
+    sub_graph->Init();
+    MS_LOG(INFO) << " initialize input data ";
+    memcpy(inputs[0]->data_c(), indices.data.data(), indices.data.size() * sizeof(float));
+
+    std::cout << "==================output data================" << std::endl;
+    sub_graph->Run();
+    auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c());
+    CompareOutputData(output_data_gpu, expected.data.data(), output_tensor.ElementsNum(), 0.0001);
+    delete sub_graph;
+  }
+};
+
+// 6 two-component index rows, one scalar value, output 6x10.
+TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Scalar) {
+  RunCase({{6, 2}, {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}, lite::Tensor::VAR},
+          {{2}, {6, 10}, lite::Tensor::CONST_TENSOR},
+          {{1}, {6.0}, lite::Tensor::CONST_SCALAR},
+          {{1}, {0.0}, lite::Tensor::CONST_SCALAR},
+          {{6, 10},
+           {6, 0, 0, 0, 0, 0, 0, 0, 0, 0,   // one output row of 10 per line
+            0, 0, 6, 0, 0, 0, 0, 0, 0, 0,   //
+            0, 0, 0, 6, 0, 0, 0, 0, 0, 0,   //
+            0, 0, 0, 0, 0, 0, 6, 0, 0, 0,   //
+            0, 0, 0, 0, 0, 0, 0, 6, 0, 0,   //
+            0, 0, 0, 0, 0, 0, 0, 0, 0, 6},  //
+           lite::Tensor::VAR});
+}
+
+// 6 two-component index rows, one value per row, output 6x10.
+TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Vector) {
+  RunCase({{6, 2}, {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}, lite::Tensor::VAR},
+          {{2}, {6, 10}, lite::Tensor::CONST_TENSOR},
+          {{6}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, lite::Tensor::CONST_TENSOR},
+          {{1}, {0.0}, lite::Tensor::CONST_SCALAR},
+          {{6, 10},
+           {1, 0, 0, 0, 0, 0, 0, 0, 0, 0,   // one output row of 10 per line
+            0, 0, 2, 0, 0, 0, 0, 0, 0, 0,   //
+            0, 0, 0, 3, 0, 0, 0, 0, 0, 0,   //
+            0, 0, 0, 0, 0, 0, 4, 0, 0, 0,   //
+            0, 0, 0, 0, 0, 0, 0, 5, 0, 0,   //
+            0, 0, 0, 0, 0, 0, 0, 0, 0, 6},  //
+           lite::Tensor::VAR});
+}
+
+// Index tensor of shape {6, 1}, one value per index, output of 10 elements.
+TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Shape1Vector) {
+  RunCase({{6, 1}, {0, 2, 3, 6, 7, 9}, lite::Tensor::VAR},
+          {{1}, {10}, lite::Tensor::CONST_TENSOR},
+          {{6}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, lite::Tensor::CONST_TENSOR},
+          {{1}, {0.0}, lite::Tensor::CONST_SCALAR},
+          {{10}, {1, 0, 2, 3, 0, 0, 4, 5, 0, 6}, lite::Tensor::VAR});
+}
+
+// Index tensor of shape {7, 1} (shape[1] = 1), one scalar value, output of 10 elements.
+TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Shape1Scalar) {
+  RunCase({{7, 1}, {0, 1, 2, 3, 4, 5, 9}, lite::Tensor::VAR},
+          {{1}, {10}, lite::Tensor::CONST_TENSOR},
+          {{1}, {6.0}, lite::Tensor::CONST_SCALAR},
+          {{1}, {0.0}, lite::Tensor::CONST_SCALAR},
+          {{10}, {6, 6, 6, 6, 6, 6, 0, 0, 0, 6}, lite::Tensor::VAR});
+}
+
+// One-dimensional index tensor, one scalar value, non-zero default value.
+TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim1Scalar) {
+  RunCase({{6}, {1, 3, 4, 5, 6, 7}, lite::Tensor::CONST_TENSOR},
+          {{1}, {10}, lite::Tensor::CONST_TENSOR},
+          {{1}, {1.0}, lite::Tensor::CONST_SCALAR},
+          {{1}, {2.0}, lite::Tensor::CONST_TENSOR},
+          {{10}, {2, 1, 2, 1, 1, 1, 1, 1, 2, 2}, lite::Tensor::CONST_TENSOR});
+}
+
+// One-dimensional index tensor, one value per index, non-zero default value.
+TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim1Vector) {
+  RunCase({{6}, {1, 3, 4, 5, 6, 7}, lite::Tensor::CONST_TENSOR},
+          {{1}, {10}, lite::Tensor::CONST_TENSOR},
+          {{6}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, lite::Tensor::CONST_TENSOR},
+          {{1}, {2.0}, lite::Tensor::CONST_TENSOR},
+          {{10}, {2, 1, 2, 2, 3, 4, 5, 6, 2, 2}, lite::Tensor::CONST_TENSOR});
+}
+
+}  // namespace mindspore