From: @pengyongrong Reviewed-by: @ddwsky,@zhanghaibo5 Signed-off-by: @ddwskytags/v1.1.0
| @@ -0,0 +1,131 @@ | |||
| #pragma OPENCL EXTENSION cl_khr_fp16 : enable | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| #define C4NUM 4 | |||
| __kernel void SparseToDenseScalarDim0(__read_only image2d_t input, __write_only image2d_t output, float weight, | |||
| int2 input_shape, float default_value) { | |||
| FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, 0)); | |||
| FLT4 result = {default_value, default_value, default_value, default_value}; | |||
| int integer = index_input.x / C4NUM; | |||
| int decimal = (int)(index_input.x) % C4NUM; | |||
| if (decimal == 0) { | |||
| result.x = weight; | |||
| } else if (decimal == 1) { | |||
| result.y = weight; | |||
| } else if (decimal == 2) { | |||
| result.z = weight; | |||
| } else { | |||
| result.w = weight; | |||
| } | |||
| WRITE_IMAGE(output, (int2)(0, integer), result); | |||
| return; | |||
| } | |||
| __kernel void SparseToDenseScalarDim1(__read_only image2d_t input, __write_only image2d_t output, float weight, | |||
| int2 input_shape, float default_value) { | |||
| for (int i = 0; i < input_shape.x; ++i) { | |||
| FLT4 result = READ_IMAGE(input, smp_zero, (int2)(0, i)); | |||
| int Y = result.x; | |||
| result.x = weight; | |||
| WRITE_IMAGE(output, (int2)(0, Y), result); | |||
| } | |||
| } | |||
| __kernel void SparseToDenseVectorDim1(__read_only image2d_t input, __write_only image2d_t output, | |||
| __global float *weight, int2 input_shape, float default_value) { | |||
| int index_weight = 0; | |||
| for (int i = 0; i < input_shape.x; ++i) { | |||
| FLT4 result = READ_IMAGE(input, smp_zero, (int2)(0, i)); | |||
| int Y = result.x; | |||
| result.x = weight[index_weight++]; | |||
| WRITE_IMAGE(output, (int2)(0, Y), result); | |||
| } | |||
| } | |||
| __kernel void SparseToDenseScalarDim2Shape2(__read_only image2d_t input, __write_only image2d_t output, float weight, | |||
| int2 input_shape, float default_value) { | |||
| FLT temp[8] = {default_value, default_value, default_value, default_value, | |||
| default_value, default_value, default_value, default_value}; | |||
| FLT result_temp[8] = {default_value, default_value, default_value, default_value, | |||
| default_value, default_value, default_value, default_value}; | |||
| int index = 0; // 0~4 | |||
| int X = 0; | |||
| FLT4 index_begin = READ_IMAGE(input, smp_zero, (int2)(0, 0)); | |||
| int Y = (int)index_begin.x; // N | |||
| temp[index] = index_begin.y; // c/4 | |||
| for (int i = 1; i < input_shape.x && index < C4NUM; ++i) { | |||
| FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, i)); | |||
| if ((((int)temp[index]) / C4NUM == ((int)index_input.y) / C4NUM) && (Y == (int)index_input.x)) { | |||
| index++; | |||
| if (index < C4NUM) { | |||
| temp[index] = index_input.y; | |||
| } | |||
| } else { | |||
| for (int j = 0; j <= index && index < C4NUM; ++j) { | |||
| int decimal = (int)temp[j] % C4NUM; | |||
| result_temp[decimal] = weight; | |||
| X = ((int)temp[0]) / C4NUM; | |||
| } | |||
| FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; | |||
| WRITE_IMAGE(output, (int2)(X, Y), result); | |||
| index = 0; | |||
| Y = (int)index_input.x; | |||
| temp[0] = index_input.y; | |||
| temp[1] = temp[2] = temp[3] = default_value; | |||
| result_temp[0] = result_temp[1] = result_temp[2] = result_temp[3] = default_value; | |||
| } | |||
| } | |||
| // judge the last element for input | |||
| X = ((int)temp[0]) / C4NUM; | |||
| for (int i = 0; i <= index && index < C4NUM; ++i) { | |||
| int decimal = (int)temp[i] % C4NUM; | |||
| result_temp[decimal] = weight; | |||
| } | |||
| FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; | |||
| WRITE_IMAGE(output, (int2)(X, Y), result); | |||
| } | |||
| __kernel void SparseToDenseVectorDim2Shape2(__read_only image2d_t input, __write_only image2d_t output, | |||
| __global float *weight, int2 input_shape, float default_value) { | |||
| FLT temp[8] = {default_value, default_value, default_value, default_value, | |||
| default_value, default_value, default_value, default_value}; | |||
| FLT result_temp[8] = {default_value, default_value, default_value, default_value, | |||
| default_value, default_value, default_value, default_value}; | |||
| int index = 0; // 0~4 | |||
| int weight_index = 0; | |||
| int X = 0; | |||
| FLT4 index_begin = READ_IMAGE(input, smp_zero, (int2)(0, 0)); | |||
| int Y = (int)index_begin.x; // N | |||
| temp[index] = index_begin.y; // c/4 | |||
| for (int i = 1; i < input_shape.x && index < C4NUM; ++i) { | |||
| FLT4 index_input = READ_IMAGE(input, smp_zero, (int2)(0, i)); | |||
| if ((((int)temp[index]) / C4NUM == ((int)index_input.y) / C4NUM) && (Y == (int)index_input.x)) { | |||
| index++; | |||
| if (index < C4NUM) { | |||
| temp[index] = index_input.y; | |||
| } | |||
| } else { | |||
| for (int j = 0; j <= index && index < C4NUM; ++j) { | |||
| int decimal = (int)temp[j] % C4NUM; | |||
| result_temp[decimal] = weight[weight_index++]; | |||
| X = ((int)temp[0]) / C4NUM; | |||
| } | |||
| FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; | |||
| WRITE_IMAGE(output, (int2)(X, Y), result); | |||
| index = 0; | |||
| Y = (int)index_input.x; | |||
| temp[0] = index_input.y; | |||
| temp[1] = temp[2] = temp[3] = default_value; | |||
| result_temp[0] = result_temp[1] = result_temp[2] = result_temp[3] = default_value; | |||
| } | |||
| } | |||
| // judge the last element for input | |||
| X = ((int)temp[0]) / C4NUM; | |||
| for (int i = 0; i <= index && index < C4NUM; ++i) { | |||
| int decimal = (int)temp[i] % C4NUM; | |||
| result_temp[decimal] = weight[weight_index++]; | |||
| } | |||
| FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]}; | |||
| WRITE_IMAGE(output, (int2)(X, Y), result); | |||
| } | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/opencl/kernel/fill.h" | |||
| #include <cstring> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Fill; | |||
| using mindspore::schema::PrimitiveType_Shape; | |||
| namespace mindspore::kernel { | |||
| int FillOpenCLKernel::RunFill() { | |||
| auto allocator_ = ocl_runtime_->GetAllocator(); | |||
| auto param = reinterpret_cast<FillParameter *>(this->op_parameter_); | |||
| default_ = param->num_dims_; | |||
| std::vector<size_t> img_size; | |||
| cl_float4 fill_value = {}; | |||
| fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_; | |||
| auto src_data = out_tensors_[0]->data_c(); | |||
| allocator_->GetImageSize(src_data, &img_size); | |||
| auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0}; | |||
| auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1}; | |||
| cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data)); | |||
| ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region); | |||
| return RET_OK; | |||
| } | |||
| int FillOpenCLKernel::RunShape() { | |||
| auto allocator_ = ocl_runtime_->GetAllocator(); | |||
| auto src_data = out_tensors_[0]->data_c(); | |||
| cl_float4 fill_value = {default_, default_, default_, default_}; | |||
| for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) { | |||
| fill_value.s[0] = in_tensors_[0]->shape()[i]; | |||
| size_t index = static_cast<size_t>(i); | |||
| auto src_origin = cl::array<cl::size_type, 3U>{0, index, 0}; | |||
| auto region = cl::array<cl::size_type, 3U>{1, 1, 1}; | |||
| cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data)); | |||
| ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int FillOpenCLKernel::Init() { | |||
| auto param = this->op_parameter_; | |||
| if (out_tensors_[0]->shape().size() > 4) { | |||
| MS_LOG(ERROR) << " only support dim <= 4"; | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[0]->shape().size() > 1 && param->type_ == PrimitiveType_Fill) { | |||
| MS_LOG(ERROR) << " fill only support dim = 1"; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int FillOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << this->name() << " Running! "; | |||
| auto param = this->op_parameter_; | |||
| if (param->type_ == PrimitiveType_Fill) { | |||
| RunFill(); | |||
| } else { | |||
| RunShape(); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *FillOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| auto *kernel = new (std::nothrow) FillOpenCLKernel(opParameter, inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << " new FillOpenCLKernel failed "; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << " Init kernel failed, name: fill "; | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Fill, FillOpenCLKernelCreator); | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Shape, FillOpenCLKernelCreator); | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Fill, FillOpenCLKernelCreator); | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Shape, FillOpenCLKernelCreator); | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_FILL_H_ | |||
| #include <vector> | |||
| #include "mindspore/lite/nnacl/fp32/fill.h" | |||
| #include "mindspore/lite/nnacl/shape.h" | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| namespace mindspore::kernel { | |||
| class FillOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| FillOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| ~FillOpenCLKernel() override = default; | |||
| int Init() override; | |||
| int Run() override; | |||
| private: | |||
| int RunFill(); | |||
| int RunShape(); | |||
| cl::Kernel kernel_; | |||
| private: | |||
| float default_{0.0f}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif | |||
| @@ -0,0 +1,203 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/opencl/kernel/sparse_to_dense.h" | |||
| #include <cstring> | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| #include "src/runtime/kernel/opencl/cl/sparse_to_dense.cl.inc" | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_SparseToDense; | |||
| namespace mindspore::kernel { | |||
| int SparseToDenseOpenCLKernel::InitOutputToDefault() { | |||
| auto allocator_ = ocl_runtime_->GetAllocator(); | |||
| std::vector<size_t> img_size; | |||
| cl_float4 fill_value = {}; | |||
| fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_; | |||
| auto src_data = out_tensors_[0]->data_c(); | |||
| allocator_->GetImageSize(src_data, &img_size); | |||
| auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0}; | |||
| auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1}; | |||
| cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data)); | |||
| ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region); | |||
| return RET_OK; | |||
| } | |||
| int SparseToDenseOpenCLKernel::InitWeights() { | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| auto weight_tensor = in_tensors_[2]; | |||
| size_t size = 1; | |||
| for (int i = 0; i < weight_tensor->shape().size(); ++i) { | |||
| size *= weight_tensor->shape()[i]; | |||
| } | |||
| if (weight_scalar_) { | |||
| if (weight_tensor->data_type() == kNumberTypeFloat16) { | |||
| weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c())); | |||
| } else { | |||
| weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data_c()); | |||
| } | |||
| } else { | |||
| auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float); | |||
| size_t weight_size = UP_ROUND(size, C4NUM) * sizeof_FLT; | |||
| weight_vector_ = allocator->Malloc(weight_size); | |||
| allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true); | |||
| memset(weight_vector_, 0x00, weight_size); | |||
| if (weight_tensor->data_type() == kNumberTypeFloat16) { | |||
| if (enable_fp16_) { | |||
| memcpy(weight_vector_, weight_tensor->data_c(), size * sizeof_FLT); | |||
| } else { | |||
| auto weight_fp32 = reinterpret_cast<float *>(weight_vector_); | |||
| auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c()); | |||
| for (int i = 0; i < size; ++i) { | |||
| weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]); | |||
| } | |||
| } | |||
| } else { | |||
| if (enable_fp16_) { | |||
| auto weight_fp16 = reinterpret_cast<float16_t *>(weight_vector_); | |||
| auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c()); | |||
| for (int i = 0; i < size; ++i) { | |||
| weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]); | |||
| } | |||
| } else { | |||
| memcpy(weight_vector_, weight_tensor->data_c(), size * sizeof_FLT); | |||
| } | |||
| } | |||
| allocator->UnmapBuffer(weight_vector_); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SparseToDenseOpenCLKernel::Init() { | |||
| if (out_tensors_[0]->shape().size() > 2 || in_tensors_.size() < 3) { | |||
| MS_LOG(ERROR) << " only support dim <= 2 and in_tensors_.size >= 3"; | |||
| return RET_ERROR; | |||
| } | |||
| if ((in_tensors_[0]->shape()[1] > 3) && (input_dim_ == 2)) { | |||
| MS_LOG(ERROR) << "in_tensors_indices shape[1] must be 1 2 or 3 && input_dim_=2 ,but your shapes is: " | |||
| << in_tensors_[0]->shape()[1] << "your input_dim_ is: " << input_dim_; | |||
| return ERROR; | |||
| } | |||
| input_dim_ = in_tensors_[0]->shape().size(); | |||
| weight_scalar_ = in_tensors_[2]->IsScalar(); | |||
| std::string kernel_name = "SparseToDense" + std::string(weight_scalar_ ? "ScalarDim" : "VectorDim") + | |||
| std::to_string(in_tensors_[0]->shape()[1] == 1 ? 1 : input_dim_); | |||
| if (input_dim_ == 2 && in_tensors_[0]->shape()[1] != 1) { | |||
| kernel_name += "Shape" + std::to_string(in_tensors_[0]->shape()[1]); | |||
| } | |||
| std::set<std::string> build_options; | |||
| std::string source = sparse_to_dense_source; | |||
| std::string program_name = "SparseToDense"; | |||
| ocl_runtime_->LoadSource(program_name, source); | |||
| ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| if (in_tensors_.size() > 3) { | |||
| auto input_tensor3 = in_tensors_[3]; | |||
| if (input_tensor3->data_type() == kNumberTypeFloat16) { | |||
| default_ = static_cast<float>(*reinterpret_cast<float16_t *>(input_tensor3->data_c())); | |||
| } else { | |||
| default_ = *reinterpret_cast<float *>(input_tensor3->data_c()); | |||
| } | |||
| } | |||
| InitWeights(); | |||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | |||
| return RET_OK; | |||
| } | |||
| int SparseToDenseOpenCLKernel::InferShapeTo4D() { | |||
| if (in_tensors_[0]->shape().size() <= 4) { | |||
| if (in_tensors_[0]->shape().size() == 1) { | |||
| N_ = in_tensors_[0]->shape()[0]; | |||
| } else if (in_tensors_[0]->shape().size() == 2) { | |||
| N_ = in_tensors_[0]->shape()[0]; | |||
| C_ = in_tensors_[0]->shape()[1]; | |||
| } else if (in_tensors_[0]->shape().size() == 3) { | |||
| N_ = in_tensors_[0]->shape()[0]; | |||
| W_ = in_tensors_[0]->shape()[1]; | |||
| C_ = in_tensors_[0]->shape()[2]; | |||
| } else { | |||
| N_ = in_tensors_[0]->shape()[0]; | |||
| H_ = in_tensors_[0]->shape()[1]; | |||
| W_ = in_tensors_[0]->shape()[2]; | |||
| C_ = in_tensors_[0]->shape()[3]; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported inputdim: " << in_tensors_[0]->shape().size(); | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int SparseToDenseOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << this->name() << " Running! "; | |||
| InferShapeTo4D(); | |||
| cl_int2 input_shape = {static_cast<cl_int>(N_ * H_), static_cast<cl_int>(W_ * UP_DIV(C_, C4NUM))}; | |||
| InitOutputToDefault(); | |||
| std::vector<size_t> local = {1, 1}; | |||
| std::vector<size_t> global = {1, 1}; | |||
| int arg_cn = 0; | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c()); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c()); | |||
| if (weight_scalar_) { | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_scalar_); | |||
| } else { | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_vector_); | |||
| } | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape); | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, default_); | |||
| ocl_runtime_->RunKernel(kernel_, global, local, nullptr); | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *SparseToDenseOpenCLKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, | |||
| OpParameter *opParameter, const lite::InnerContext *ctx, | |||
| const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| if (inputs.empty()) { | |||
| MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| auto *kernel = new (std::nothrow) SparseToDenseOpenCLKernel(opParameter, inputs, outputs); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << " new HswishOpenCLKernel failed "; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << " Init kernel failed, name: hswish "; | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SparseToDense, SparseToDenseOpenCLKernelCreator); | |||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SparseToDense, SparseToDenseOpenCLKernelCreator); | |||
| } // namespace mindspore::kernel | |||
| @@ -0,0 +1,58 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPARSE_TO_DENSE_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPARSE_TO_DENSE_H_ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||
| #include "mindspore/lite/nnacl/fp32/sparse_to_dense.h" | |||
| namespace mindspore::kernel { | |||
| class SparseToDenseOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| SparseToDenseOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs) | |||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||
| ~SparseToDenseOpenCLKernel() override = default; | |||
| int Init() override; | |||
| int Run() override; | |||
| int InitWeights() override; | |||
| private: | |||
| int InferShapeTo4D(); | |||
| int InitOutputToDefault(); | |||
| private: | |||
| cl::Kernel kernel_; | |||
| // bool IndicesIsScalar{false}; | |||
| bool enable_fp16_{false}; | |||
| float default_{0.0f}; | |||
| float weight_scalar_{0.f}; | |||
| void *weight_vector_{nullptr}; | |||
| int input_dim_{1}; | |||
| std::vector<int32_t> output_shape_; | |||
| size_t N_{1}; | |||
| size_t H_{1}; | |||
| size_t W_{1}; | |||
| size_t C_{1}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif | |||
| @@ -142,6 +142,7 @@ bool LoadLibraryFromPath(const std::string &library_path, void *handle) { | |||
| LOAD_OPENCL_FUNCTION_PTR(clRetainDevice); | |||
| LOAD_OPENCL_FUNCTION_PTR(clReleaseDevice); | |||
| LOAD_OPENCL_FUNCTION_PTR(clCreateImage); | |||
| LOAD_OPENCL_FUNCTION_PTR(clEnqueueFillImage); | |||
| #endif | |||
| #if CL_HPP_TARGET_OPENCL_VERSION >= 200 | |||
| LOAD_OPENCL_FUNCTION_PTR(clCreateCommandQueueWithProperties); | |||
| @@ -228,6 +229,7 @@ CL_DEFINE_FUNC_PTR(clEnqueueCopyImageToBuffer); | |||
| CL_DEFINE_FUNC_PTR(clRetainDevice); | |||
| CL_DEFINE_FUNC_PTR(clReleaseDevice); | |||
| CL_DEFINE_FUNC_PTR(clCreateImage); | |||
| CL_DEFINE_FUNC_PTR(clEnqueueFillImage); | |||
| #endif | |||
| #if CL_HPP_TARGET_OPENCL_VERSION >= 200 | |||
| CL_DEFINE_FUNC_PTR(clGetKernelSubGroupInfoKHR); | |||
| @@ -666,6 +668,14 @@ cl_mem clCreateImage(cl_context context, cl_mem_flags flags, const cl_image_form | |||
| return func(context, flags, image_format, image_desc, host_ptr, errcode_ret); | |||
| } | |||
| cl_int clEnqueueFillImage(cl_command_queue command_queue, cl_mem image, const void *fill_color, const size_t *origin, | |||
| const size_t *region, cl_uint num_events_in_wait_list, const cl_event *event_wait_list, | |||
| cl_event *event) { | |||
| auto func = mindspore::lite::opencl::clEnqueueFillImage; | |||
| MS_ASSERT(func != nullptr); | |||
| return func(command_queue, image, fill_color, origin, region, num_events_in_wait_list, event_wait_list, event); | |||
| } | |||
| #endif | |||
| #if CL_HPP_TARGET_OPENCL_VERSION >= 200 | |||
| @@ -127,6 +127,8 @@ using clRetainDeviceFunc = cl_int (*)(cl_device_id); | |||
| using clReleaseDeviceFunc = cl_int (*)(cl_device_id); | |||
| using clCreateImageFunc = cl_mem (*)(cl_context, cl_mem_flags, const cl_image_format *, const cl_image_desc *, void *, | |||
| cl_int *); | |||
| using clEnqueueFillImageFunc = cl_int (*)(cl_command_queue, cl_mem, const void *, const size_t *, const size_t *, | |||
| cl_uint, const cl_event *, cl_event *); | |||
| #endif | |||
| #if CL_HPP_TARGET_OPENCL_VERSION >= 200 | |||
| using clCreateProgramWithILFunc = cl_program (*)(cl_context, const void *, size_t, cl_int *); | |||
| @@ -199,6 +201,7 @@ CL_DECLARE_FUNC_PTR(clEnqueueCopyImageToBuffer); | |||
| CL_DECLARE_FUNC_PTR(clRetainDevice); | |||
| CL_DECLARE_FUNC_PTR(clReleaseDevice); | |||
| CL_DECLARE_FUNC_PTR(clCreateImage); | |||
| CL_DECLARE_FUNC_PTR(clEnqueueFillImage); | |||
| #endif | |||
| #if CL_HPP_TARGET_OPENCL_VERSION >= 200 | |||
| CL_DECLARE_FUNC_PTR(clGetKernelSubGroupInfoKHR); | |||
| @@ -0,0 +1,145 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include "src/common/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/fill.h" | |||
| using mindspore::lite::Tensor; | |||
| using mindspore::schema::PrimitiveType_Fill; | |||
| using mindspore::schema::PrimitiveType_Shape; | |||
| using mindspore::schema::Format::Format_NHWC; | |||
| namespace mindspore { | |||
| class TestFillOpenCLCI : public mindspore::CommonTest { | |||
| public: | |||
| TestFillOpenCLCI() {} | |||
| }; | |||
| TEST_F(TestFillOpenCLCI, Fp32testfill) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {2}; | |||
| float input_data1[] = {3, 3}; | |||
| float correctOutput[] = {9, 9, 9, 9, 9, 9, 9, 9, 9}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| std::vector<int> output_shape = {3, 3}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<FillParameter *>(malloc(sizeof(FillParameter))); | |||
| param->num_dims_ = 9; | |||
| param->op_parameter_.type_ = PrimitiveType_Fill; | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new FillParameter failed "; | |||
| return; | |||
| } | |||
| auto *fill_kernel = | |||
| new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (fill_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| fill_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{fill_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete fill_kernel; | |||
| return; | |||
| } | |||
| // to allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| TEST_F(TestFillOpenCLCI, Fp32testshape) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {2, 4}; | |||
| float input_data1[] = {-0.4045, -0.0924, -0.617, -0.10114, -0.9893, 0.3342, 2.445, -2.182}; | |||
| float correctOutput[] = {2, 4}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| std::vector<int> output_shape = {2}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<ShapeParameter *>(malloc(sizeof(ShapeParameter))); | |||
| param->op_parameter_.type_ = PrimitiveType_Shape; | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new FillParameter failed "; | |||
| return; | |||
| } | |||
| auto *fill_kernel = | |||
| new (std::nothrow) kernel::FillOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (fill_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::FillOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| fill_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{fill_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete fill_kernel; | |||
| return; | |||
| } | |||
| // to allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,459 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include "src/common/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.h" | |||
| using mindspore::lite::Tensor; | |||
| using mindspore::schema::Format::Format_NHWC; | |||
| namespace mindspore { | |||
| class TestSparseToDenseOpenCLCI : public mindspore::CommonTest { | |||
| public: | |||
| TestSparseToDenseOpenCLCI() {} | |||
| }; | |||
| TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Scalar) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {6, 2}; | |||
| std::vector<int> input_shape2 = {2}; | |||
| std::vector<int> input_shape3 = {1}; | |||
| std::vector<int> input_shape4 = {1}; | |||
| float input_data1[] = {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}; | |||
| float input_data2[] = {6, 10}; | |||
| float input_data3[] = {6.0}; | |||
| float input_data4[] = {0.0}; | |||
| float correctOutput[] = {6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| std::vector<int> output_shape = {6, 10}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); | |||
| auto in_tensor2 = Tensor(data_type, input_shape2, Format_NHWC, lite::Tensor::CONST_TENSOR); | |||
| auto in_tensor3 = Tensor(data_type, input_shape3, Format_NHWC, lite::Tensor::CONST_SCALAR); | |||
| auto in_tensor4 = Tensor(data_type, input_shape4, Format_NHWC, lite::Tensor::CONST_SCALAR); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); | |||
| // allocate memory for weights | |||
| in_tensor2.MallocData(); | |||
| in_tensor3.MallocData(); | |||
| in_tensor4.MallocData(); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| // initialize weights | |||
| memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2)); | |||
| memcpy(inputs[2]->data_c(), input_data3, sizeof(input_data3)); | |||
| memcpy(inputs[3]->data_c(), input_data4, sizeof(input_data4)); | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new ActivationParameter failed "; | |||
| return; | |||
| } | |||
| auto *sparse_to_dense_kernel = | |||
| new (std::nothrow) kernel::SparseToDenseOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (sparse_to_dense_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SparseToDenseOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| sparse_to_dense_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{sparse_to_dense_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete sparse_to_dense_kernel; | |||
| return; | |||
| } | |||
| // to do allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Vector) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {6, 2}; | |||
| std::vector<int> input_shape2 = {2}; | |||
| std::vector<int> input_shape3 = {6}; | |||
| std::vector<int> input_shape4 = {1}; | |||
| float input_data1[] = {0, 0, 1, 2, 2, 3, 3, 6, 4, 7, 5, 9}; | |||
| float input_data2[] = {6, 10}; | |||
| float input_data3[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; | |||
| float input_data4[] = {0.0}; | |||
| float correctOutput[] = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| std::vector<int> output_shape = {6, 10}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); | |||
| auto in_tensor2 = Tensor(data_type, input_shape2, Format_NHWC, lite::Tensor::CONST_TENSOR); | |||
| auto in_tensor3 = Tensor(data_type, input_shape3, Format_NHWC, lite::Tensor::CONST_TENSOR); | |||
| auto in_tensor4 = Tensor(data_type, input_shape4, Format_NHWC, lite::Tensor::CONST_SCALAR); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); | |||
| // allocate memory for weights | |||
| in_tensor2.MallocData(); | |||
| in_tensor3.MallocData(); | |||
| in_tensor4.MallocData(); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| // initialize weights | |||
| memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2)); | |||
| memcpy(inputs[2]->data_c(), input_data3, sizeof(input_data3)); | |||
| memcpy(inputs[3]->data_c(), input_data4, sizeof(input_data4)); | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new ActivationParameter failed "; | |||
| return; | |||
| } | |||
| auto *sparse_to_dense_kernel = | |||
| new (std::nothrow) kernel::SparseToDenseOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (sparse_to_dense_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SparseToDenseOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| sparse_to_dense_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{sparse_to_dense_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete sparse_to_dense_kernel; | |||
| return; | |||
| } | |||
| // to do allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Shape1Vector) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {6, 1}; | |||
| std::vector<int> input_shape2 = {1}; | |||
| std::vector<int> input_shape3 = {6}; | |||
| std::vector<int> input_shape4 = {1}; | |||
| float input_data1[] = {0, 2, 3, 6, 7, 9}; | |||
| float input_data2[] = {10}; | |||
| float input_data3[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; | |||
| float input_data4[] = {0.0}; | |||
| float correctOutput[] = {1, 0, 2, 3, 0, 0, 4, 5, 0, 6}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| std::vector<int> output_shape = {10}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); | |||
| auto in_tensor2 = Tensor(data_type, input_shape2, Format_NHWC, lite::Tensor::CONST_TENSOR); | |||
| auto in_tensor3 = Tensor(data_type, input_shape3, Format_NHWC, lite::Tensor::CONST_TENSOR); | |||
| auto in_tensor4 = Tensor(data_type, input_shape4, Format_NHWC, lite::Tensor::CONST_SCALAR); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); | |||
| // allocate memory for weights | |||
| in_tensor2.MallocData(); | |||
| in_tensor3.MallocData(); | |||
| in_tensor4.MallocData(); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| // initialize weights | |||
| memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2)); | |||
| memcpy(inputs[2]->data_c(), input_data3, sizeof(input_data3)); | |||
| memcpy(inputs[3]->data_c(), input_data4, sizeof(input_data4)); | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new ActivationParameter failed "; | |||
| return; | |||
| } | |||
| auto *sparse_to_dense_kernel = | |||
| new (std::nothrow) kernel::SparseToDenseOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (sparse_to_dense_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SparseToDenseOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| sparse_to_dense_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{sparse_to_dense_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete sparse_to_dense_kernel; | |||
| return; | |||
| } | |||
| // to do allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim2Shape1Scalar) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {7, 1}; // shape[1] = 1 | |||
| std::vector<int> input_shape2 = {1}; | |||
| std::vector<int> input_shape3 = {1}; | |||
| std::vector<int> input_shape4 = {1}; | |||
| float input_data1[] = {0, 1, 2, 3, 4, 5, 9}; | |||
| float input_data2[] = {10}; | |||
| float input_data3[] = {6.0}; | |||
| float input_data4[] = {0.0}; | |||
| float correctOutput[] = {6, 6, 6, 6, 6, 6, 0, 0, 0, 6}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| std::vector<int> output_shape = {10}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, lite::Tensor::VAR); | |||
| auto in_tensor2 = Tensor(data_type, input_shape2, Format_NHWC, lite::Tensor::CONST_TENSOR); | |||
| auto in_tensor3 = Tensor(data_type, input_shape3, Format_NHWC, lite::Tensor::CONST_SCALAR); | |||
| auto in_tensor4 = Tensor(data_type, input_shape4, Format_NHWC, lite::Tensor::CONST_SCALAR); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, lite::Tensor::VAR); | |||
| // allocate memory for weights | |||
| in_tensor2.MallocData(); | |||
| in_tensor3.MallocData(); | |||
| in_tensor4.MallocData(); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| // initialize weights | |||
| memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2)); | |||
| memcpy(inputs[2]->data_c(), input_data3, sizeof(input_data3)); | |||
| memcpy(inputs[3]->data_c(), input_data4, sizeof(input_data4)); | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new ActivationParameter failed "; | |||
| return; | |||
| } | |||
| auto *sparse_to_dense_kernel = | |||
| new (std::nothrow) kernel::SparseToDenseOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (sparse_to_dense_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SparseToDenseOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| sparse_to_dense_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{sparse_to_dense_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete sparse_to_dense_kernel; | |||
| return; | |||
| } | |||
| // to do allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim1Scalar) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {6}; | |||
| std::vector<int> input_shape2 = {1}; | |||
| std::vector<int> input_shape3 = {1}; | |||
| std::vector<int> input_shape4 = {1}; | |||
| float input_data1[] = {1, 3, 4, 5, 6, 7}; | |||
| float input_data2[] = {10}; | |||
| float input_data3[] = {1.0}; | |||
| float input_data4[] = {2.0}; | |||
| float correctOutput[] = {2, 1, 2, 1, 1, 1, 1, 1, 2, 2}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = lite::Tensor::CONST_TENSOR; | |||
| std::vector<int> output_shape = {10}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, tensor_type); | |||
| auto in_tensor2 = Tensor(data_type, input_shape2, Format_NHWC, tensor_type); | |||
| auto in_tensor3 = Tensor(data_type, input_shape3, Format_NHWC, lite::Tensor::CONST_SCALAR); | |||
| auto in_tensor4 = Tensor(data_type, input_shape4, Format_NHWC, tensor_type); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, tensor_type); | |||
| // allocate memory for weights | |||
| in_tensor2.MallocData(); | |||
| in_tensor3.MallocData(); | |||
| in_tensor4.MallocData(); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| // initialize weights | |||
| memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2)); | |||
| memcpy(inputs[2]->data_c(), input_data3, sizeof(input_data3)); | |||
| memcpy(inputs[3]->data_c(), input_data4, sizeof(input_data4)); | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new ActivationParameter failed "; | |||
| return; | |||
| } | |||
| auto *sparse_to_dense_kernel = | |||
| new (std::nothrow) kernel::SparseToDenseOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (sparse_to_dense_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SparseToDenseOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| sparse_to_dense_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{sparse_to_dense_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete sparse_to_dense_kernel; | |||
| return; | |||
| } | |||
| // to do allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| TEST_F(TestSparseToDenseOpenCLCI, Fp32Dim1Vector) { | |||
| MS_LOG(INFO) << " begin test "; | |||
| auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||
| auto runtime = runtime_wrapper.GetInstance(); | |||
| runtime->Init(); | |||
| auto allocator = runtime->GetAllocator(); | |||
| MS_LOG(INFO) << " init tensors "; | |||
| std::vector<int> input_shape1 = {6}; | |||
| std::vector<int> input_shape2 = {1}; | |||
| std::vector<int> input_shape3 = {6}; | |||
| std::vector<int> input_shape4 = {1}; | |||
| float input_data1[] = {1, 3, 4, 5, 6, 7}; | |||
| float input_data2[] = {10}; | |||
| float input_data3[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; | |||
| float input_data4[] = {2.0}; | |||
| float correctOutput[] = {2, 1, 2, 2, 3, 4, 5, 6, 2, 2}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensor_type = lite::Tensor::CONST_TENSOR; | |||
| std::vector<int> output_shape = {10}; | |||
| auto in_tensor1 = Tensor(data_type, input_shape1, Format_NHWC, tensor_type); | |||
| auto in_tensor2 = Tensor(data_type, input_shape2, Format_NHWC, tensor_type); | |||
| auto in_tensor3 = Tensor(data_type, input_shape3, Format_NHWC, tensor_type); | |||
| auto in_tensor4 = Tensor(data_type, input_shape4, Format_NHWC, tensor_type); | |||
| auto output_tensor = Tensor(data_type, output_shape, Format_NHWC, tensor_type); | |||
| // allocate memory for weights | |||
| in_tensor2.MallocData(); | |||
| in_tensor3.MallocData(); | |||
| in_tensor4.MallocData(); | |||
| std::vector<lite::Tensor *> inputs{&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4}; | |||
| std::vector<lite::Tensor *> outputs{&output_tensor}; | |||
| // initialize weights | |||
| memcpy(inputs[1]->data_c(), input_data2, sizeof(input_data2)); | |||
| memcpy(inputs[2]->data_c(), input_data3, sizeof(input_data3)); | |||
| memcpy(inputs[3]->data_c(), input_data4, sizeof(input_data4)); | |||
| MS_LOG(INFO) << " initialize tensors "; | |||
| auto param = reinterpret_cast<SparseToDenseParameter *>(malloc(sizeof(SparseToDenseParameter))); | |||
| if (param == nullptr) { | |||
| MS_LOG(INFO) << " new ActivationParameter failed "; | |||
| return; | |||
| } | |||
| auto *sparse_to_dense_kernel = | |||
| new (std::nothrow) kernel::SparseToDenseOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| if (sparse_to_dense_kernel == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SparseToDenseOpenCLKernel failed "; | |||
| delete param; | |||
| return; | |||
| } | |||
| sparse_to_dense_kernel->Init(); | |||
| MS_LOG(INFO) << " initialize sub_graph "; | |||
| std::vector<kernel::LiteKernel *> kernels{sparse_to_dense_kernel}; | |||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({&in_tensor1}, outputs, kernels, kernels, kernels); | |||
| if (sub_graph == nullptr) { | |||
| MS_LOG(INFO) << " new kernel::SubGraphOpenCLKernel failed "; | |||
| delete param; | |||
| delete sparse_to_dense_kernel; | |||
| return; | |||
| } | |||
| // to do allocate memory for inputs | |||
| in_tensor1.MallocData(allocator); | |||
| sub_graph->Init(); | |||
| MS_LOG(INFO) << " initialize input data "; | |||
| memcpy(inputs[0]->data_c(), input_data1, sizeof(input_data1)); | |||
| std::cout << "==================output data================" << std::endl; | |||
| sub_graph->Run(); | |||
| auto *output_data_gpu = reinterpret_cast<float *>(output_tensor.data_c()); | |||
| CompareOutputData(output_data_gpu, correctOutput, output_tensor.ElementsNum(), 0.0001); | |||
| delete sub_graph; | |||
| } | |||
| } // namespace mindspore | |||