Browse Source

!11438 【MS】【LITE】【GPU】optimize malloc api

From: @wangdongxu6
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
da92f1affb
31 changed files with 299 additions and 496 deletions
  1. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/cl/space_to_depth.cl
  2. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h
  3. +36
    -89
      mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc
  4. +4
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h
  5. +37
    -85
      mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
  6. +2
    -4
      mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h
  7. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
  8. +3
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
  9. +12
    -5
      mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc
  10. +3
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc
  11. +3
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
  12. +7
    -5
      mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
  13. +0
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h
  14. +2
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h
  15. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h
  16. +47
    -78
      mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc
  17. +1
    -1
      mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h
  18. +2
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h
  19. +4
    -3
      mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc
  20. +4
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/split.cc
  21. +3
    -2
      mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc
  22. +14
    -27
      mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc
  23. +2
    -0
      mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.h
  24. +3
    -3
      mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.cc
  25. +4
    -3
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc
  26. +2
    -1
      mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h
  27. +32
    -59
      mindspore/lite/src/runtime/kernel/opencl/utils.cc
  28. +5
    -57
      mindspore/lite/src/runtime/kernel/opencl/utils.h
  29. +37
    -42
      mindspore/lite/src/runtime/opencl/opencl_allocator.cc
  30. +24
    -11
      mindspore/lite/src/runtime/opencl/opencl_allocator.h
  31. +2
    -2
      mindspore/lite/src/runtime/opencl/opencl_runtime.cc

+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/cl/space_to_depth.cl View File

@@ -63,7 +63,7 @@ __kernel void DepthToSpace(__read_only image2d_t src_data, __write_only image2d_
int Y = get_global_id(1); // W
int Z = get_global_id(2); // H * N
if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;
if (out_shape.y == 0 || co_size == 0) return;
if (out_shape.y == 0 || block_size == 0) return;
int N = Z / out_shape.y;
int H = Z % out_shape.y;
int co_base = X * C4NUM;


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/activation.h View File

@@ -45,7 +45,7 @@ class ActivationOpenCLKernel : public OpenCLKernel {
static std::string GetActTypeString(int act_type);
int type_;
float alpha_;
GpuTensorInfo outShape = GpuTensorInfo(nullptr);
GpuTensorInfo outShape;
};

} // namespace mindspore::kernel


+ 36
- 89
mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.cc View File

@@ -30,6 +30,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::lite::opencl::MemType;
using mindspore::schema::ActivationType_NO_ACTIVATION;
using mindspore::schema::ActivationType_RELU;
@@ -45,7 +46,7 @@ int ArithmeticOpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
if (param->broadcasting_ && out_tensors_[0]->shape()[0] > 1) {
if (param->broadcasting_ && out_tensors_.front()->DimensionSize(0) > 1) {
MS_LOG(ERROR) << "Broadcasting don't support N > 1";
return RET_ERROR;
}
@@ -63,85 +64,29 @@ int ArithmeticOpenCLKernel::CheckSpecs() {

void ArithmeticOpenCLKernel::SetGlobalLocal() {
if (element_flag_) {
local_size_ = {};
auto out_shape = out_tensors_[0]->shape();
if (out_shape.size() == 2) {
size_t H = out_shape[0];
size_t W = UP_DIV(out_shape[1], C4NUM);
global_size_ = {W, H};
} else {
size_t H = out_shape[0] * out_shape[1];
size_t W = out_shape[2] * UP_DIV(out_shape[3], C4NUM);
global_size_ = {W, H};
}
global_size_ = {out_shape_.width, out_shape_.height};
} else {
local_size_ = {};
auto out_shape = GetNHWCShape(out_tensors_[0]->shape());
global_size_ = {static_cast<size_t>(UP_DIV(out_shape[3], C4NUM)), static_cast<size_t>(out_shape[2]),
static_cast<size_t>(out_shape[1] * out_shape[0])};
global_size_ = {out_shape_.Slice, out_shape_.W, out_shape_.H * out_shape_.N};
}
AlignGlobalLocal(global_size_, local_size_);
AlignGlobalLocal(global_size_, {});
}

int ArithmeticOpenCLKernel::InitWeights() {
auto allocator = ocl_runtime_->GetAllocator();
auto fp16_enable = ocl_runtime_->GetFp16Enable();
auto data_size = fp16_enable ? sizeof(float16_t) : sizeof(float);
for (auto in_tensor_ : in_tensors_) {
auto nhwc_shape = GetNHWCShape(in_tensor_->shape());
inputs_nhwc_shapes_.push_back(nhwc_shape);
if (!in_tensor_->IsConst()) {
inputs_weight_ptrs_.push_back(nullptr);
for (int i = 0; i < 2; ++i) {
const auto &in_tensor = in_tensors_.at(i);
GpuTensorInfo *in_shape = (i == 0) ? &in0_shape_ : &in1_shape_;
if (in_tensor->IsConst()) {
std::vector<char> weight(in_shape->Image2DSize, 0);
bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, *in_shape);
size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{in_shape->width, in_shape->height, dtype};
auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
weight_ptrs_.push_back(weight_ptr_);
} else {
auto allocator = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size = GetImage2dShapeFromNHWC(nhwc_shape, schema::Format_NHWC4);
int pack_weight_size = img_size[0] * img_size[1] * C4NUM;
int plane = nhwc_shape[1] * nhwc_shape[2];
int channel = nhwc_shape[3];
int batch = nhwc_shape[0];
img_size.push_back(fp16_enable ? CL_HALF_FLOAT : CL_FLOAT);
if (!fp16_enable) {
float *weight = new (std::nothrow) float[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNHWCToNHWC4<float, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
PackNHWCToNHWC4<float16_t, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
inputs_weight_ptrs_.push_back(weight_ptr_);
delete[] weight;
} else {
float16_t *weight = new (std::nothrow) float16_t[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNHWCToNHWC4<float, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float16_t(float16_t)> to_dtype = [](float16_t x) -> float16_t { return x; };
PackNHWCToNHWC4<float16_t, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
inputs_weight_ptrs_.push_back(weight_ptr_);
delete[] weight;
}
weight_ptrs_.push_back(nullptr);
}
}
return RET_OK;
@@ -150,21 +95,21 @@ int ArithmeticOpenCLKernel::InitWeights() {
void ArithmeticOpenCLKernel::SetConstArgs() {
int arg_idx = 3;
if (!element_flag_) {
cl_int4 input0_shape = {inputs_nhwc_shapes_[0][0], inputs_nhwc_shapes_[0][1], inputs_nhwc_shapes_[0][2],
UP_DIV(inputs_nhwc_shapes_[0][3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input0_shape);
cl_int4 input1_shape = {inputs_nhwc_shapes_[1][0], inputs_nhwc_shapes_[1][1], inputs_nhwc_shapes_[1][2],
UP_DIV(inputs_nhwc_shapes_[1][3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input1_shape);
auto out_shape = GetNHWCShape(out_tensors_[0]->shape());
cl_int4 output_shape{out_shape[0], out_shape[1], out_shape[2], UP_DIV(out_shape[3], C4NUM)};
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
cl_int4 in0_shape = {static_cast<int>(in0_shape_.N), static_cast<int>(in0_shape_.H), static_cast<int>(in0_shape_.W),
static_cast<int>(in0_shape_.Slice)};
cl_int4 in1_shape = {static_cast<int>(in1_shape_.N), static_cast<int>(in1_shape_.H), static_cast<int>(in1_shape_.W),
static_cast<int>(in1_shape_.Slice)};
cl_int4 out_shape = {static_cast<int>(out_shape_.N), static_cast<int>(out_shape_.H), static_cast<int>(out_shape_.W),
static_cast<int>(out_shape_.Slice)};
int broadcastC_flag = 0; // do not need broadcast in C4
if (inputs_nhwc_shapes_[0][3] == 1 && inputs_nhwc_shapes_[1][3] != 1) {
if (in0_shape_.C == 1 && in1_shape_.C != 1) {
broadcastC_flag = 1; // BroadCast C4 in input0
} else if (inputs_nhwc_shapes_[0][3] != 1 && inputs_nhwc_shapes_[1][3] == 1) {
} else if (in0_shape_.C != 1 && in1_shape_.C == 1) {
broadcastC_flag = 2; // BroadCast C4 in input1
}
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in0_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in1_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_shape);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, broadcastC_flag);
} else {
cl_int2 output_shape{static_cast<int>(global_range_[0]), static_cast<int>(global_range_[1])};
@@ -175,11 +120,14 @@ void ArithmeticOpenCLKernel::SetConstArgs() {
}

int ArithmeticOpenCLKernel::Prepare() {
lite::STATUS error_code = RET_OK;
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name_);
#else

in0_shape_ = GpuTensorInfo(in_tensors_[0]);
in1_shape_ = GpuTensorInfo(in_tensors_[1]);
out_shape_ = GpuTensorInfo(out_tensors_[0]);

auto *param = reinterpret_cast<const ArithmeticParameter *>(op_parameter_);
if (Type() == PrimitiveType_BiasAdd) {
const_cast<ArithmeticParameter *>(param)->broadcasting_ = true;
@@ -197,7 +145,7 @@ int ArithmeticOpenCLKernel::Prepare() {
std::string program_name = "Arithmetic";
std::string source = arithmetic_source;
ocl_runtime_->LoadSource(program_name, source);
error_code = ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name_);
int error_code = ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name_);
#endif
if (error_code != RET_OK) {
return error_code;
@@ -212,11 +160,10 @@ int ArithmeticOpenCLKernel::Prepare() {

int ArithmeticOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";

auto input_0_ptr = weight_ptrs_[0] == nullptr ? in_tensors_[0]->data_c() : weight_ptrs_[0];
auto input_1_ptr = weight_ptrs_[1] == nullptr ? in_tensors_[1]->data_c() : weight_ptrs_[1];
int arg_idx = 0;
auto input_0_ptr = inputs_weight_ptrs_[0] == nullptr ? in_tensors_[0]->data_c() : inputs_weight_ptrs_[0];
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_0_ptr);
auto input_1_ptr = inputs_weight_ptrs_[1] == nullptr ? in_tensors_[1]->data_c() : inputs_weight_ptrs_[1];
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_1_ptr);
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);


+ 4
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h View File

@@ -43,8 +43,10 @@ class ArithmeticOpenCLKernel : public OpenCLKernel {
bool element_flag_{true};
float activation_min_{-FLT_MAX};
float activation_max_{FLT_MAX};
std::vector<std::vector<int>> inputs_nhwc_shapes_;
std::vector<void *> inputs_weight_ptrs_;
GpuTensorInfo in0_shape_;
GpuTensorInfo in1_shape_;
GpuTensorInfo out_shape_;
std::vector<void *> weight_ptrs_;
std::string kernel_name_;
};
} // namespace mindspore::kernel


+ 37
- 85
mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc View File

@@ -18,7 +18,6 @@
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/cl/concat.cl.inc"
@@ -27,22 +26,23 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Concat;

namespace mindspore::kernel {

int ConcatOpenCLKernel::RunAxis0() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
ImageSize img_size;
auto dst_data = out_tensors_[0]->data_c();
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
auto *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
for (int i = 0; i < in_tensors_.size(); i++) {
auto src_data = inputs_weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : inputs_weight_ptrs_.at(i);
auto src_data = weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : weight_ptrs_.at(i);
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
cl::Image2D *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
auto *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region);
dst_origin[1] += region[1];
}
@@ -75,8 +75,8 @@ int ConcatOpenCLKernel::CheckSpecs() {
MS_LOG(ERROR) << " GPU Unsupported shape.size > 4 ";
return RET_ERROR;
}
for (int i = 0; i < in_tensors_.size(); ++i) {
auto in_tensors_shape_size = in_tensors_[i]->shape().size();
for (auto &in_tensor : in_tensors_) {
auto in_tensors_shape_size = in_tensor->shape().size();
if (in_tensors_shape_size > 4) {
MS_LOG(ERROR) << " GPU Unsupported in_tensor shape.size > 4 ";
return RET_ERROR;
@@ -109,7 +109,7 @@ int ConcatOpenCLKernel::CheckSpecs() {

void ConcatOpenCLKernel::SetConstArgs() {
GpuTensorInfo img_info(out_tensors_[0]);
size_t dtype = enable_fp16_ ? sizeof(cl_half) : sizeof(cl_float);
size_t dtype = ocl_runtime_->GetFp16Enable() ? sizeof(cl_half) : sizeof(cl_float);
stride_w = img_info.RowPitch() / dtype;
cl_int4 output_shape_ = {};
for (int i = 0; i < out_tensors_[0]->shape().size(); ++i) {
@@ -118,22 +118,22 @@ void ConcatOpenCLKernel::SetConstArgs() {
Broadcast2GpuShape(out_shape_.s, output_shape_.s, out_tensors_[0]->shape().size(), 1);
int arg_cn = in_tensors_.size() + 1;
if (axis_ == 3 && !Align_) {
for (int i = 0; i < in_tensors_.size(); ++i) {
for (auto &in_tensor : in_tensors_) {
cl_int4 temp = {};
for (int j = 0; j < in_tensors_[i]->shape().size(); ++j) {
temp.s[j] = in_tensors_[i]->shape()[j];
for (int j = 0; j < in_tensor->shape().size(); ++j) {
temp.s[j] = in_tensor->shape()[j];
}
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensors_[i]->shape().size(), 1);
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensor->shape().size(), 1);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_shape_);
}
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, stride_w);
} else {
for (int i = 0; i < in_tensors_.size(); ++i) {
for (auto &in_tensor : in_tensors_) {
cl_int4 temp = {};
for (int j = 0; j < in_tensors_[i]->shape().size(); ++j) {
temp.s[j] = in_tensors_[i]->shape()[j];
for (int j = 0; j < in_tensor->shape().size(); ++j) {
temp.s[j] = in_tensor->shape()[j];
}
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensors_[i]->shape().size(), 1);
Broadcast2GpuShape(in_shape_.s, temp.s, in_tensor->shape().size(), 1);
in_shape_.s[3] = UP_DIV(in_shape_.s[3], C4NUM);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_shape_);
}
@@ -160,84 +160,36 @@ void ConcatOpenCLKernel::SetGlobalLocal() {
OpenCLKernel::AlignGlobalLocal(global_size_, local_size_);
}

int ConcatOpenCLKernel::ConvertWeightToTensor(const std::vector<lite::Tensor *> &in_tensors,
std::vector<void *> *inputs_weight_ptrs, bool fp16_enable,
size_t data_size) {
for (auto in_tensor_ : in_tensors) {
auto nhwc_shape = GetNHWCShape(in_tensor_->shape());
if (!in_tensor_->IsConst()) {
(*inputs_weight_ptrs).push_back(nullptr);
int ConcatOpenCLKernel::ConvertWeightToTensor() {
auto allocator = ocl_runtime_->GetAllocator();
bool fp16_enable = ocl_runtime_->GetFp16Enable();
for (auto in_tensor : in_tensors_) {
auto in_shape = GpuTensorInfo(in_tensor);
if (in_tensor->IsConst()) {
std::vector<char> weight(in_shape.Image2DSize, 0);
bool src_is_fp16 = in_tensor->data_type() == kNumberTypeFloat16;
PackNHWCToNHWC4(in_tensor->data_c(), weight.data(), src_is_fp16, fp16_enable, in_shape);
size_t dtype = fp16_enable ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{in_shape.width, in_shape.height, dtype};
auto weight_ptr_ = allocator->Malloc(img_size, weight.data());
weight_ptrs_.push_back(weight_ptr_);
} else {
auto allocator = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size = GetImage2dShapeFromNHWC(nhwc_shape, schema::Format_NHWC4);
int pack_weight_size = img_size[0] * img_size[1] * C4NUM;
int plane = nhwc_shape[1] * nhwc_shape[2];
int channel = nhwc_shape[3];
int batch = nhwc_shape[0];
img_size.push_back(fp16_enable ? CL_HALF_FLOAT : CL_FLOAT);
if (!fp16_enable) {
float *weight = new (std::nothrow) float[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNHWCToNHWC4<float, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float(float16_t)> to_dtype = [](float16_t x) -> float { return static_cast<float>(x); };
PackNHWCToNHWC4<float16_t, float>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
(*inputs_weight_ptrs).push_back(weight_ptr_);
delete[] weight;
} else {
float16_t *weight = new (std::nothrow) float16_t[pack_weight_size];
if (weight == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
memset(weight, 0x00, pack_weight_size * data_size);
if (in_tensor_->data_type() == kNumberTypeFloat32) {
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNHWCToNHWC4<float, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
} else if (in_tensor_->data_type() == kNumberTypeFloat16) {
std::function<float16_t(float16_t)> to_dtype = [](float16_t x) -> float16_t { return x; };
PackNHWCToNHWC4<float16_t, float16_t>(in_tensor_->data_c(), weight, batch, plane, channel, to_dtype);
}
if (batch * plane * channel == 1) {
// scalar
weight[3] = weight[2] = weight[1] = weight[0];
}
auto weight_ptr_ = allocator->Malloc(pack_weight_size, img_size, weight);
(*inputs_weight_ptrs).push_back(weight_ptr_);
delete[] weight;
}
weight_ptrs_.push_back(nullptr);
}
}
return RET_OK;
}

int ConcatOpenCLKernel::Prepare() {
enable_fp16_ = ocl_runtime_->GetFp16Enable();
auto data_size = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
ConvertWeightToTensor(in_tensors_, &inputs_weight_ptrs_, enable_fp16_, data_size);
ConvertWeightToTensor();
if (axis_ == 0) {
for (int i = 0; i < in_tensors_.size(); ++i) {
if (in_tensors_.at(i)->shape().size() != 1) {
return RET_OK;
}
if (std::any_of(in_tensors_.begin(), in_tensors_.end(), [](lite::Tensor *t) { return t->shape().size() != 1; })) {
return RET_OK;
}
axis_ = 3;
}
for (int i = 0; i < in_tensors_.size(); ++i) {
int length = in_tensors_[0]->shape().size();
if (in_tensors_[i]->shape()[length - 1] % C4NUM != 0) {
for (auto const &in_tensor : in_tensors_) {
if (in_tensor->shape().back() % C4NUM != 0) {
Align_ = false;
}
}
@@ -268,7 +220,7 @@ int ConcatOpenCLKernel::Run() {
}
int arg_cn = 0;
for (int i = 0; i < in_tensors_.size(); ++i) {
auto input_ptr = inputs_weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : inputs_weight_ptrs_.at(i);
auto input_ptr = weight_ptrs_.at(i) == nullptr ? in_tensors_[i]->data_c() : weight_ptrs_.at(i);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_ptr);
}
if (axis_ == 3 && !Align_) {


+ 2
- 4
mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h View File

@@ -43,8 +43,7 @@ class ConcatOpenCLKernel : public OpenCLKernel {
uint32_t OC = {1};
std::vector<size_t> global;
bool Align_{true};
std::vector<void *> inputs_weight_ptrs_;
bool enable_fp16_{false};
std::vector<void *> weight_ptrs_;
cl_int stride_w{1};
cl_int4 in_shape_{};
cl_int4 out_shape_{};
@@ -52,8 +51,7 @@ class ConcatOpenCLKernel : public OpenCLKernel {

private:
int RunAxis0();
int ConvertWeightToTensor(const std::vector<lite::Tensor *> &in_tensors, std::vector<void *> *inputs_weight_ptrs,
bool fp16_enable, size_t data_size);
int ConvertWeightToTensor();
};

} // namespace mindspore::kernel


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc View File

@@ -255,7 +255,7 @@ void Conv2DOpenCLKernel::InitFilter() {
size_t height = KH_ * KW_ * UP_ROUND(CI_, CI_TILE);
size_t dtype = use_fp16_ ? CL_HALF_FLOAT : CL_FLOAT;
size = width * height * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size, {width, height, dtype});
packed_filter_ = allocator->Malloc({width, height, dtype});
} else {
size = UP_DIV(CO_SLICES_, Ogroup) * KH_ * KW_ * CI_SLICES_ * Ogroup * CI_TILE * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size);


+ 3
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc View File

@@ -28,6 +28,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::PrimitiveType_DeConv2D;
@@ -193,8 +194,8 @@ int Conv2dTransposeOpenCLKernel::InitWeights() {
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size);
ImageSize img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(img_size);
bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
memset(bias_, 0x00, div_co * C4NUM * data_size);
if (in_tensors_.size() == 3) {


+ 12
- 5
mindspore/lite/src/runtime/kernel/opencl/kernel/depthwise_conv2d.cc View File

@@ -37,6 +37,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::lite::opencl::MemType;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;

@@ -61,6 +62,7 @@ int DepthwiseConv2dOpenCLKernel::CheckSpecs() {
}
return RET_OK;
}

int DepthwiseConv2dOpenCLKernel::Prepare() {
std::string kernel_name = "DepthwiseConv2d";
if (out_mem_type_ == MemType::BUF) {
@@ -114,13 +116,10 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {

int plane_in = parameter->kernel_h_ * parameter->kernel_w_;
int plane_out = plane_in * C4NUM;
std::vector<size_t> img_size;
if (filter_type_ == MemType::IMG) {
int alignment = ocl_runtime_->GetImagePitchAlignment();
plane_out = UP_ROUND(plane_out, alignment) * C4NUM;
pack_weight_size = plane_out * CO4;
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
img_size = {(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype};
}
pack_weight_size = pack_weight_size * dtype_size;
auto ConvertFilter = [](void *src, void *dst, TypeId src_type, TypeId dst_type, size_t plane_in, size_t plane_out,
@@ -153,7 +152,13 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
auto src_type = in_tensors_.at(kWeightIndex)->data_type();
auto dst_type = is_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32;
ConvertFilter(origin_weight, temp_filter.data(), src_type, dst_type, plane_in, plane_out, out_info.C);
packed_weight_ = allocator->Malloc(pack_weight_size, img_size, temp_filter.data());
if (filter_type_ == MemType::IMG) {
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype};
packed_weight_ = allocator->Malloc(img_size, temp_filter.data());
} else {
packed_weight_ = allocator->Malloc(pack_weight_size, temp_filter.data());
}
FreeDequantedWeight();
if (packed_weight_ == nullptr) {
return RET_ERROR;
@@ -182,12 +187,13 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
auto element_size = in_tensors_.at(kBiasIndex)->ElementsNum();
ConvertBias(in_tensors_.at(kBiasIndex)->data_c(), temp_bias.data(), element_size, dtype_size, src_type, dst_type);
}
bias_data_ = allocator->Malloc(bias_size, {}, temp_bias.data());
bias_data_ = allocator->Malloc(bias_size, temp_bias.data());
if (bias_data_ == nullptr) {
return RET_ERROR;
}
return mindspore::lite::RET_OK;
}

void DepthwiseConv2dOpenCLKernel::SetConstArgs() {
auto parameter = reinterpret_cast<ConvParameter *>(op_parameter_);
auto in_info = GpuTensorInfo(in_tensors_[0]);
@@ -216,6 +222,7 @@ void DepthwiseConv2dOpenCLKernel::SetConstArgs() {
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, relu_clips[parameter->act_type_].first);
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, relu_clips[parameter->act_type_].second);
}

void DepthwiseConv2dOpenCLKernel::SetGlobalLocal() {
auto out_info = GpuTensorInfo(out_tensors_[0]);
// set global


+ 3
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc View File

@@ -26,6 +26,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Fill;
using mindspore::schema::PrimitiveType_Shape;

@@ -35,13 +36,13 @@ int FillOpenCLKernel::RunFill() {
auto allocator_ = ocl_runtime_->GetAllocator();
auto param = reinterpret_cast<FillParameter *>(this->op_parameter_);
default_ = param->num_dims_;
std::vector<size_t> img_size;
ImageSize img_size;
cl_float4 fill_value = {};
fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
auto src_data = out_tensors_[0]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK;


+ 3
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc View File

@@ -29,6 +29,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::ActivationType_RELU;
using mindspore::schema::ActivationType_RELU6;
using mindspore::schema::ActivationType_TANH;
@@ -211,8 +212,8 @@ int FullConnectionOpenCLKernel::InitBias() {
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
std::vector<size_t> img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * dtype_size, img_size);
ImageSize img_size{im_dst_x, im_dst_y, img_dtype};
bias_ = allocator->Malloc(img_size);
bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
memset(bias_, 0x00, co4 * C4NUM * dtype_size);
if (in_tensors_.size() == 3) {


+ 7
- 5
mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc View File

@@ -27,6 +27,8 @@

using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_MatMul;

namespace mindspore::kernel {
@@ -55,7 +57,7 @@ int MatMulOpenCLKernel::CheckSpecs() {
transposeA = param->a_transpose_;
if (transposeA) {
MS_LOG(ERROR) << "matmul only support a_transpose_=false yet.";
return mindspore::lite::RET_ERROR;
return RET_ERROR;
}
transposeB = param->b_transpose_;
act_weight_ = !in_tensors_[1]->IsConst();
@@ -63,7 +65,7 @@ int MatMulOpenCLKernel::CheckSpecs() {
if (in_tensors_[0]->shape().size() != out_tensors_[0]->shape().size() || in_tensors_[0]->shape().size() < 2 ||
in_tensors_[0]->shape().size() > 4) {
MS_LOG(ERROR) << "matmul only support input shape size= 2, 3 or 4.";
return mindspore::lite::RET_ERROR;
return RET_ERROR;
}
return RET_OK;
}
@@ -100,7 +102,7 @@ int MatMulOpenCLKernel::Prepare() {
SetConstArgs();
SetGlobalLocal();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return mindspore::lite::RET_OK;
return RET_OK;
}

int MatMulOpenCLKernel::InitWeights() {
@@ -207,7 +209,7 @@ int MatMulOpenCLKernel::Run() {
ocl_runtime_->SetKernelArg(kernel_, arg_count++, in_tensors_[1]->data_c());
}
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return mindspore::lite::RET_OK;
return RET_OK;
}

kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::Tensor *> &inputs,
@@ -244,7 +246,7 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector<lite::Tensor *>
return kernel;
}
auto ret = kernel->CheckSpecs();
if (ret != mindspore::lite::RET_OK) {
if (ret != RET_OK) {
MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!";
delete kernel;
return nullptr;


+ 0
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.h View File

@@ -22,7 +22,6 @@
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/common/utils.h"
#include "nnacl/matmul_parameter.h"
#define MAXDEPTH 5

namespace mindspore::kernel {



+ 2
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/one_hot.h View File

@@ -41,8 +41,8 @@ class OneHotOpenCLKernel : public OpenCLKernel {
float on_value_{1.0f};
float off_value_{0.0f};
int axis_{0};
GpuTensorInfo in_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo out_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo in_shape_;
GpuTensorInfo out_shape_;
};
} // namespace mindspore::kernel



+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/reduce.h View File

@@ -39,7 +39,7 @@ class ReduceOpenCLKernel : public OpenCLKernel {
private:
cl_float4 GenC4Mask();
static std::string GetReduceTypeStr(int type);
GpuTensorInfo outShape = GpuTensorInfo(nullptr);
GpuTensorInfo outShape;
bool use_local_{false};
bool wc_reduce_{false};
static const size_t LOCAL_CACHE_THREAD{16};


+ 47
- 78
mindspore/lite/src/runtime/kernel/opencl/kernel/scale.cc View File

@@ -30,6 +30,7 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::lite::opencl::MemType;
using mindspore::schema::PrimitiveType_Scale;

@@ -64,87 +65,55 @@ void ScaleOpenCLKernel::Image2dGetWorkGroupSize() {
}

int ScaleOpenCLKernel::InitWeights() {
if (!weight_vector_flag_) {
auto *in_tensor = in_tensors_[0];
auto *scale_tensor = in_tensors_[1];
auto *offset_tensor = in_tensors_[2];
auto scale_dtype = scale_tensor->data_type();
if (!weight_vector_flag_ || !scale_tensor->IsConst()) {
return RET_OK;
}
if (in_tensors_[1]->IsConst()) {
auto allocator = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
GetImageSize(0, &img_size);
img_size[2] = in_tensors_[1]->data_type() == kNumberTypeFloat16 ? CL_HALF_FLOAT : CL_FLOAT;
if (broadcast_flag_) {
img_size[1] = 1;
img_size[0] = UP_DIV(in_tensors_[1]->shape()[0], C4NUM);
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, in_tensors_[1]->data_c());
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, in_tensors_[2]->data_c());
return RET_OK;
auto allocator = ocl_runtime_->GetAllocator();
auto fp16_enable = ocl_runtime_->GetFp16Enable();
ImageSize img_size;
GetImageSize(0, &img_size);
img_size.dtype = scale_dtype == kNumberTypeFloat16 ? CL_HALF_FLOAT : CL_FLOAT;

if (broadcast_flag_) {
img_size.height = 1;
img_size.width = UP_DIV(scale_tensor->shape()[0], C4NUM);
scale_ptr_ = allocator->Malloc(img_size, scale_tensor->data_c());
offset_ptr_ = allocator->Malloc(img_size, offset_tensor->data_c());
return RET_OK;
}

if (in_tensor->format() == scale_tensor->format()) {
if (in_tensor->data_type() == scale_tensor->data_type()) {
scale_ptr_ = allocator->Malloc(img_size, scale_tensor->data_c());
offset_ptr_ = allocator->Malloc(img_size, offset_tensor->data_c());
} else {
MS_LOG(ERROR) << "Unsupported data type transpose from " << scale_tensor->data_type() << "to "
<< in_tensor->data_type();
return RET_ERROR;
}
auto image2d_info = GpuTensorInfo(in_tensors_[1]);
int pack_weight_size = image2d_info.ElementsC4Num;
int plane = image2d_info.H * image2d_info.W;
int channel = image2d_info.C;
int batch = image2d_info.N;
if (in_tensors_[0]->format() == in_tensors_[1]->format()) {
if (in_tensors_[0]->data_type() == in_tensors_[1]->data_type()) {
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, in_tensors_[1]->data_c());
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, in_tensors_[2]->data_c());
} else {
MS_LOG(ERROR) << "Unsupport data type transpose from " << in_tensors_[1]->data_type() << "to "
<< in_tensors_[0]->data_type();
return RET_ERROR;
}
} else if (in_tensors_[0]->format() == schema::Format_NHWC) {
if (in_tensors_[1]->format() == schema::Format_NHWC) {
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
auto *scale = new (std::nothrow) float[pack_weight_size];
if (scale == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
auto *offset = new (std::nothrow) float[pack_weight_size];
if (offset == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
delete[] scale;
return RET_ERROR;
}
std::function<float(float)> to_dtype = [](float x) -> float { return x; };
PackNHWCToNHWC4<float, float>(in_tensors_[1]->data_c(), scale, batch, plane, channel, to_dtype);
PackNHWCToNHWC4<float, float>(in_tensors_[2]->data_c(), offset, batch, plane, channel, to_dtype);
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, scale);
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, offset);
delete[] scale;
delete[] offset;
} else if (in_tensors_[0]->data_type() == kNumberTypeFloat16) {
auto *scale = new (std::nothrow) float16_t[pack_weight_size];
if (scale == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
return RET_ERROR;
}
auto *offset = new (std::nothrow) float16_t[pack_weight_size];
if (offset == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed!";
delete[] scale;
return RET_ERROR;
}
std::function<float16_t(float)> to_dtype = [](float x) -> float16_t { return static_cast<float16_t>(x); };
PackNHWCToNHWC4<float, float16_t>(in_tensors_[1]->data_c(), scale, batch, plane, channel, to_dtype);
PackNHWCToNHWC4<float, float16_t>(in_tensors_[2]->data_c(), offset, batch, plane, channel, to_dtype);
scale_ptr_ = allocator->Malloc(in_tensors_[1]->ElementsNum(), img_size, scale);
offset_ptr_ = allocator->Malloc(in_tensors_[2]->ElementsNum(), img_size, offset);
delete[] scale;
delete[] offset;
} else {
MS_LOG(ERROR) << "Unsupport data type transpose from " << in_tensors_[1]->data_type() << "to "
<< in_tensors_[0]->data_type();
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "Unsupport format transpose from " << in_tensors_[1]->format() << "to "
<< in_tensors_[0]->format();
return RET_ERROR;
}
} else if (in_tensor->format() == schema::Format_NHWC && scale_tensor->format() == schema::Format_NHWC) {
if (scale_dtype == kNumberTypeFloat32 || scale_dtype == kNumberTypeFloat16) {
auto image2d_info = GpuTensorInfo(scale_tensor);
int pack_weight_size = image2d_info.ElementsC4Num;
std::vector<char> scale(pack_weight_size, 0);
std::vector<char> offset(pack_weight_size, 0);
bool src_is_fp16 = scale_dtype == kNumberTypeFloat16;
PackNHWCToNHWC4(scale_tensor->data_c(), scale.data(), src_is_fp16, fp16_enable, image2d_info);
PackNHWCToNHWC4(offset_tensor->data_c(), offset.data(), src_is_fp16, fp16_enable, image2d_info);
scale_ptr_ = allocator->Malloc(img_size, scale.data());
offset_ptr_ = allocator->Malloc(img_size, offset.data());
} else {
MS_LOG(ERROR) << "Unsupported data type transpose from " << scale_tensor->data_type() << "to "
<< in_tensor->data_type();
return RET_ERROR;
}
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupported format transpose from " << scale_tensor->format() << "to " << in_tensor->format();
return RET_ERROR;
}
return RET_OK;
}
@@ -231,7 +200,7 @@ int ScaleOpenCLKernel::Run() {
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, static_cast<float>(scale));
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, static_cast<float>(offset));
} else {
MS_LOG(ERROR) << "Unsupport data type " << in_tensors_[1]->data_type();
MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[1]->data_type();
return RET_ERROR;
}
}


+ 1
- 1
mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h View File

@@ -52,7 +52,7 @@ class SoftmaxOpenCLKernel : public OpenCLKernel {
std::vector<size_t> local_size_;
std::vector<size_t> global_size_;
int axis_{0};
GpuTensorInfo out_shape = GpuTensorInfo(nullptr);
GpuTensorInfo out_shape;
};

} // namespace mindspore::kernel


+ 2
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_depth.h View File

@@ -36,8 +36,8 @@ class SpaceToDepthOpenCLKernel : public OpenCLKernel {
void SetGlobalLocal() override;

private:
GpuTensorInfo in_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo out_shape_ = GpuTensorInfo(nullptr);
GpuTensorInfo in_shape_;
GpuTensorInfo out_shape_;
};
} // namespace mindspore::kernel



+ 4
- 3
mindspore/lite/src/runtime/kernel/opencl/kernel/sparse_to_dense.cc View File

@@ -27,19 +27,20 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_SparseToDense;

namespace mindspore::kernel {

int SparseToDenseOpenCLKernel::InitOutputToDefault() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
ImageSize img_size;
cl_float4 fill_value = {};
fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_;
auto src_data = out_tensors_[0]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK;
@@ -113,7 +114,7 @@ int SparseToDenseOpenCLKernel::CheckSpecs() {
}
auto param = reinterpret_cast<SparseToDenseParameter *>(op_parameter_);
if (param->validate_indices_) {
MS_LOG(ERROR) << "Unspported unordered for in_tensors_indices";
MS_LOG(ERROR) << "Unsupported unordered for in_tensors_indices";
return RET_ERROR;
}
return RET_OK;


+ 4
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/split.cc View File

@@ -26,12 +26,13 @@ using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Split;

namespace mindspore::kernel {

int SplitOpenCLKernel::RunAxis0() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
auto src_data = in_tensors_[0]->data_c();
cl::Image2D *in_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
if (in_image == nullptr) {
@@ -41,9 +42,10 @@ int SplitOpenCLKernel::RunAxis0() {
auto src_area = cl::array<cl::size_type, 3U>{0, 0, 0};
for (int i = 0; i < out_tensors_.size(); i++) {
auto dst_data = out_tensors_[i]->data_c();
ImageSize img_size;
allocator_->GetImageSize(dst_data, &img_size);
auto dst_area = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
if (out_image == nullptr) {
MS_LOG(ERROR) << "RunAxis0 out_image can not be nullptr";


+ 3
- 2
mindspore/lite/src/runtime/kernel/opencl/kernel/stack.cc View File

@@ -25,13 +25,14 @@

using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::opencl::ImageSize;
using mindspore::schema::PrimitiveType_Stack;

namespace mindspore::kernel {

int StackOpenCLKernel::RunAxis0() {
auto allocator_ = ocl_runtime_->GetAllocator();
std::vector<size_t> img_size;
ImageSize img_size;
auto dst_data = out_tensors_[0]->data_c();
auto dst_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(dst_data));
@@ -39,7 +40,7 @@ int StackOpenCLKernel::RunAxis0() {
auto src_data = in_tensors_[i]->data_c();
allocator_->GetImageSize(src_data, &img_size);
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{img_size[0], img_size[1], 1};
auto region = cl::array<cl::size_type, 3U>{img_size.width, img_size.height, 1};
cl::Image2D *input_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueCopyImage(*input_image, *out_image, src_origin, dst_origin, region);
dst_origin[1] += region[1];


+ 14
- 27
mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.cc View File

@@ -20,19 +20,14 @@
#include "src/runtime/kernel/opencl/kernel/matmul.h"
#include "src/runtime/kernel/opencl/kernel/strassen.h"
#include "src/common/utils.h"

#ifndef PROGRAM_WITH_IL

#include "src/runtime/kernel/opencl/cl/strassen.cl.inc"

#endif
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;

namespace mindspore::kernel {

int StrassenOpenCLKernel::Prepare() {
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
std::string kernel_name = "MatMul_Strassen_NHWC4_2d";
std::string source = strassen_source;
std::string program_name = "MatMul";
@@ -43,8 +38,6 @@ int StrassenOpenCLKernel::Prepare() {
ocl_runtime_->BuildKernel(kernel_back_result, program_name, "Strassen_Back_Result");
ocl_runtime_->BuildKernel(MatMul_StrassenBUFFilled, program_name, "MatMul_BUF_Filled");
ocl_runtime_->BuildKernel(MatMul_StrassenIMGFilled, program_name, "MatMul_IMG_Filled");

#endif
auto ret = InitWeights();
if (ret != RET_OK) {
return ret;
@@ -52,31 +45,25 @@ int StrassenOpenCLKernel::Prepare() {
SetConstArgs();
SetGlobalLocal();
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return mindspore::lite::RET_OK;
return RET_OK;
}

void StrassenOpenCLKernel::AllocatorMemoryForStrassen(int NumA, int NumB) {
std::vector<size_t> img_size;
img_size.push_back(UP_DIV(NumA, C4NUM));
img_size.push_back(NumA);
auto allocator = ocl_runtime_->GetAllocator();
size_t img_dtype = enable_fp16_ ? CL_HALF_FLOAT : CL_FLOAT;
ImageSize img_size{static_cast<size_t>(UP_DIV(NumA, C4NUM)), static_cast<size_t>(NumA), img_dtype};
size_t dtype_size = enable_fp16_ ? sizeof(cl_half) : sizeof(cl_float);
img_size.push_back(img_dtype);
auto allocator = ocl_runtime_->GetAllocator();
size_t memA = NumA * NumA;

size_t memB = NumB * NumB * dtype_size;
for (int depth = 0; depth < MAXDEPTH; depth++) {
B_temp[depth] = allocator->Malloc(memB);
A_temp[depth] = allocator->Malloc(memA, img_size);

M1[depth] = allocator->Malloc(memA, img_size);
M2[depth] = allocator->Malloc(memA, img_size);
M3[depth] = allocator->Malloc(memA, img_size);
M4[depth] = allocator->Malloc(memA, img_size);
M5[depth] = allocator->Malloc(memA, img_size);
M6[depth] = allocator->Malloc(memA, img_size);
M7[depth] = allocator->Malloc(memA, img_size);
A_temp[depth] = allocator->Malloc(img_size);
M1[depth] = allocator->Malloc(img_size);
M2[depth] = allocator->Malloc(img_size);
M3[depth] = allocator->Malloc(img_size);
M4[depth] = allocator->Malloc(img_size);
M5[depth] = allocator->Malloc(img_size);
M6[depth] = allocator->Malloc(img_size);
M7[depth] = allocator->Malloc(img_size);
}
}

@@ -333,6 +320,6 @@ int StrassenOpenCLKernel::Run() {
}
DoStrassen(in_tensors_.at(0)->data_c(), padWeight_, out_tensors_.at(0)->data_c(), in_tensors_.at(0)->shape()[0], 0,
threshold);
return mindspore::lite::RET_OK;
return RET_OK;
}
} // namespace mindspore::kernel

+ 2
- 0
mindspore/lite/src/runtime/kernel/opencl/kernel/strassen.h View File

@@ -21,6 +21,8 @@
#include <vector>
#include "src/runtime/kernel/opencl/kernel/matmul.h"

#define MAXDEPTH 5

namespace mindspore::kernel {

class StrassenOpenCLKernel : public MatMulOpenCLKernel {


+ 3
- 3
mindspore/lite/src/runtime/kernel/opencl/kernel/winograd.cc View File

@@ -98,7 +98,7 @@ void WinogradOpenCLKernel::InitFilter() {
size_t height = CO_SLICES_;
size_t dtype = use_fp16_ ? CL_HALF_FLOAT : CL_FLOAT;
size = width * height * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size, {width, height, dtype});
packed_filter_ = allocator->Malloc({width, height, dtype});
} else {
size = UP_DIV(CO_SLICES_, Ogroup) * 6 * 6 * CI_SLICES_ * Ogroup * CI_TILE * CO_TILE * sizeof_FLT_;
packed_filter_ = allocator->Malloc(size);
@@ -136,11 +136,11 @@ void WinogradOpenCLKernel::AllocateMemory() {

size_t width = TILE_HW_;
size_t height = CI_SLICES_ * 36;
winograd_mem0_ = allocator->Malloc(width * height * sizeof_FLT_, {width, height, img_dtype});
winograd_mem0_ = allocator->Malloc({width, height, img_dtype});

width = TILE_HW_;
height = CO_SLICES_ * 36;
winograd_mem1_ = allocator->Malloc(width * height * sizeof_FLT_, {width, height, img_dtype});
winograd_mem1_ = allocator->Malloc({width, height, img_dtype});
}

void WinogradOpenCLKernel::SetConstArgs() {


+ 4
- 3
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc View File

@@ -19,6 +19,7 @@

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::opencl::ImageSize;

namespace mindspore::kernel {

@@ -60,7 +61,7 @@ int OpenCLKernel::AlignGlobalLocal(const std::vector<size_t> &global, const std:
return RET_OK;
}

int OpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) {
MS_ASSERT(img_size);
if (idx >= out_tensors_.size()) {
return RET_ERROR;
@@ -133,13 +134,13 @@ int OpenCLKernel::PreProcess() {
auto *output = out_tensors_.at(i);
MS_ASSERT(output);
if (GetMemType() == lite::opencl::MemType::IMG) {
std::vector<size_t> img_size;
ImageSize img_size;
ret = GetImageSize(i, &img_size);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetImageSize failed";
return ret;
}
auto data_ptr = allocator->Malloc(output->Size(), img_size);
auto data_ptr = allocator->Malloc(img_size);
if (data_ptr == nullptr) {
MS_LOG(ERROR) << "Malloc data failed";
return RET_ERROR;


+ 2
- 1
mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.h View File

@@ -78,6 +78,7 @@ void Broadcast2GpuShape(DstT *dst, const SrcT *src, int src_num, DstT default_va
}

struct GpuTensorInfo {
GpuTensorInfo() = default;
explicit GpuTensorInfo(const lite::Tensor *tensor) {
if (tensor == nullptr) {
return;
@@ -194,7 +195,7 @@ class OpenCLKernel : public LiteKernel {
virtual int AssignTuningParam(const BaseTuningParameter &param);
virtual int Tune();

int GetImageSize(size_t idx, std::vector<size_t> *img_size);
int GetImageSize(size_t idx, lite::opencl::ImageSize *img_size);
void PrintOutput(int print_num = 10, const std::string &out_file = "");
lite::opencl::MemType GetMemType() { return out_mem_type_; }
void SetMemType(lite::opencl::MemType mem_type) { out_mem_type_ = mem_type; }


+ 32
- 59
mindspore/lite/src/runtime/kernel/opencl/utils.cc View File

@@ -123,29 +123,6 @@ int GetMaxDivisorStrategy1(int x, int divisor) {
}
}

std::vector<size_t> GetCommonGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global) {
MS_ASSERT(local.size() == global.size() && local.size() == 3);
std::vector<size_t> result(3);
for (int i = 0; i < 3; ++i) {
result[i] = UP_ROUND(global[i], local[i]);
}
return result;
}

std::vector<size_t> GetCommonLocalSize(const std::vector<size_t> &global, int max_size) {
MS_ASSERT(global.size() == 3);
size_t local_z = GetMaxDivisorStrategy0(global[2], 8);
if (local_z == 0) {
MS_LOG(ERROR) << "Divide by zero";
return {};
}
size_t local_xy = max_size / local_z;
size_t local_x = std::min(UP_DIV(global[0], 2), local_xy);
size_t local_y = std::min(local_xy / local_x, global[1]);
std::vector<size_t> local = {local_x, local_y, local_z};
return local;
}

std::string CLErrorCode(cl_int error_code) {
switch (error_code) {
case CL_SUCCESS:
@@ -295,42 +272,6 @@ int WriteToBin(const std::string &file_path, void *data, size_t size) {
return 0;
}

std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape) {
int n, h, w, c;
n = h = w = c = 1;
if (tensor_shape.size() == 1) {
c = tensor_shape[0];
} else if (tensor_shape.size() == 2) {
n = tensor_shape[0];
c = tensor_shape[1];
} else if (tensor_shape.size() == 3) {
n = tensor_shape[0];
h = tensor_shape[1];
c = tensor_shape[2];
} else if (tensor_shape.size() == 4) {
n = tensor_shape[0];
h = tensor_shape[1];
w = tensor_shape[2];
c = tensor_shape[3];
}
return {n, h, w, c};
}

std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format) {
if (tensor_shape.size() != 4) {
return {1, 1};
}
size_t image_x, image_y;
image_x = image_y = 1;
if (format == schema::Format_NHWC4) {
image_x = tensor_shape[2] * UP_DIV(tensor_shape[3], C4NUM);
image_y = tensor_shape[0] * tensor_shape[1];
} else if (format == schema::Format_NC4HW4) {
image_x = tensor_shape[2];
image_y = tensor_shape[0] * tensor_shape[1] * UP_DIV(tensor_shape[3], C4NUM);
}
return {image_x, image_y};
}
int GetBroadcastGpuAxis(int ndim, int ori_axis) {
if (ori_axis >= ndim) {
return ndim - 1;
@@ -349,4 +290,36 @@ int GetBroadcastGpuAxis(int ndim, int ori_axis) {
}
return axis;
}

void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor) {
MS_ASSERT(src);
MS_ASSERT(dst);
auto src_fp16 = reinterpret_cast<float16_t *>(src);
auto src_fp32 = reinterpret_cast<float32_t *>(src);
auto dst_fp16 = reinterpret_cast<float16_t *>(dst);
auto dst_fp32 = reinterpret_cast<float32_t *>(dst);
for (int n = 0, src_idx = 0; n < tensor.N; n++) {
for (int h = 0; h < tensor.H; ++h) {
for (int w = 0; w < tensor.W; ++w) {
for (int c = 0; c < tensor.C; ++c, ++src_idx) {
int dst_idx = ((n * tensor.H + h) * tensor.W + w) * tensor.Slice * C4NUM + c;
if (dst_is_fp16) {
dst_fp16[dst_idx] = src_is_fp16 ? src_fp16[src_idx] : static_cast<float16_t>(src_fp32[src_idx]);
} else {
dst_fp32[dst_idx] = src_is_fp16 ? static_cast<float32_t>(src_fp16[src_idx]) : src_fp32[src_idx];
}
}
}
}
}
// scalar
if (tensor.ElementsNum == 1) {
if (dst_is_fp16) {
dst_fp16[3] = dst_fp16[2] = dst_fp16[1] = dst_fp16[0];
} else {
dst_fp32[3] = dst_fp32[2] = dst_fp32[1] = dst_fp32[0];
}
}
}

} // namespace mindspore::kernel

+ 5
- 57
mindspore/lite/src/runtime/kernel/opencl/utils.h View File

@@ -25,6 +25,7 @@
#include "nnacl/op_base.h"
#include "src/lite_kernel.h"
#include "src/common/utils.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"

namespace mindspore::lite {
kernel::LiteKernel *GetOpenCLKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
@@ -33,6 +34,8 @@ kernel::LiteKernel *GetOpenCLKernel(const std::vector<Tensor *> &in_tensors, con

namespace mindspore::kernel {

struct GpuTensorInfo;

// for fusion
extern const std::set<schema::PrimitiveType> ArithmeticPrimitives;
extern const std::set<schema::PrimitiveType> ArithmeticSelfPrimitives;
@@ -49,20 +52,14 @@ int GetMaxDivisorStrategy0(int x, int divisor);

int GetMaxDivisorStrategy1(int x, int divisor);

std::vector<size_t> GetCommonGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global);

std::vector<size_t> GetCommonLocalSize(const std::vector<size_t> &global, int max_size);

std::string CLErrorCode(cl_int error_code);

int WriteToBin(const std::string &file_path, void *data, size_t size);

std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape);

std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format);

int GetBroadcastGpuAxis(int ndim, int ori_axis);

void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, const GpuTensorInfo &tensor);

template <class T1, class T2>
void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel,
const std::function<T2(T1)> &to_dtype) {
@@ -86,55 +83,6 @@ void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_o
}
}

template <class T1, class T2>
void PackNHWCToNHWC4(void *src, void *dst, int batch, int plane, int channel, const std::function<T2(T1)> &to_dtype) {
MS_ASSERT(src);
MS_ASSERT(dst);
int c4 = UP_DIV(channel, C4NUM);
int nhwc4_batch_unit_offset = c4 * C4NUM * plane;
int ic_remainder_ = channel % C4NUM;
if (ic_remainder_ != 0) {
int nhwc4_batch_offset = 0;
for (int b = 0; b < batch; b++) {
int batch_offset = b * channel * plane;
for (int i = 0; i < plane; ++i) {
for (int c = 0; c < channel; ++c) {
(static_cast<T2 *>(dst) + nhwc4_batch_offset + i * c4 * C4NUM + c)[0] =
to_dtype((static_cast<T1 *>(src) + batch_offset + i * channel + c)[0]);
}
}
nhwc4_batch_offset += nhwc4_batch_unit_offset;
}
} else {
size_t ori_input_size = batch * plane * channel;
for (size_t n = 0; n < ori_input_size; ++n) {
(static_cast<T2 *>(dst) + n)[0] = to_dtype((static_cast<T1 *>(src) + n)[0]);
}
}
}

template <class T1, class T2>
void PackNHWCToNC4HW4(void *src, void *dst, int batch, int plane, int channel, const std::function<T2(T1)> &to_dtype) {
MS_ASSERT(src);
MS_ASSERT(dst);
int c4 = UP_DIV(channel, C4NUM);
for (int b = 0; b < batch; b++) {
int src_oc_offset = b * plane * channel;
int dst_oc_offset = b * plane * c4 * C4NUM;
for (int k = 0; k < plane; k++) {
int src_kernel_offset = src_oc_offset + k * channel;
int dst_kernel_offset = dst_oc_offset + k * C4NUM;
for (int i = 0; i < channel; i++) {
int c4_block_num = i / C4NUM;
int c4_block_rem = i % C4NUM;
int src_ic_offset = src_kernel_offset + i;
int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem;
(static_cast<T2 *>(dst) + dst_ic_offset)[0] = to_dtype((static_cast<T1 *>(src) + src_ic_offset)[0]);
}
}
}
}

template <class T>
std::vector<T> MatrixMultiply(const T A[], const T B[], int M, int N, int K) {
std::vector<T> C(M * K);


+ 37
- 42
mindspore/lite/src/runtime/opencl/opencl_allocator.cc View File

@@ -44,14 +44,14 @@ void OpenCLAllocator::UnLock() {
}
}

void *OpenCLAllocator::MinimumFit(size_t size, const std::vector<size_t> &img_size) {
void *OpenCLAllocator::MinimumFit(MemType mem_type, size_t size, const ImageSize &img_size) {
auto iter = free_list_.lower_bound(size);
while (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) {
auto mem_buf = iter->second;
bool is_match{mem_buf->img_size.size() == img_size.size()};
for (int i = 0; i < img_size.size() && is_match; ++i) {
is_match &= img_size[i] == mem_buf->img_size[i];
bool is_match = mem_buf->mem_type_ == mem_type;
if (mem_type == MemType::IMG) {
is_match &= mem_buf->device_ptr_ != nullptr;
is_match &= mem_buf->img_size_ == img_size;
}
if (is_match) {
free_list_.erase(iter);
@@ -88,22 +88,22 @@ void *OpenCLAllocator::CreateBuffer(size_t size, void *data, size_t flags, cl::B
return host_ptr;
}

void *OpenCLAllocator::CreateImage2D(size_t size, const std::vector<size_t> &img_size, void *data, size_t flags,
bool is_map, cl::Buffer **buffer, cl::Image2D **image) {
void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map,
cl::Buffer **buffer, cl::Image2D **image) {
cl_int ret = CL_SUCCESS;
MS_ASSERT(buffer);
MS_ASSERT(image);
MS_ASSERT(img_size.size() == 3);
if (data == nullptr) {
// copy from cl2.hpp
cl_image_desc desc = {CL_MEM_OBJECT_IMAGE2D, img_size[0], img_size[1], 0, 0, 0, 0, 0, 0, (**buffer).get()};
cl_image_desc desc = {CL_MEM_OBJECT_IMAGE2D, img_size.width, img_size.height, 0, 0, 0, 0, 0, 0, (**buffer).get()};
const cl::Context &context = *ocl_runtime_->Context();
cl_image_format image_format{CL_RGBA, static_cast<uint32_t>(img_size[2])};
cl_image_format image_format{CL_RGBA, static_cast<uint32_t>(img_size.dtype)};
*image = new (std::nothrow) cl::Image2D(clCreateImage(context.get(), 0, &image_format, &desc, nullptr, &ret));
} else {
cl::ImageFormat image_format(CL_RGBA, img_size[2]);
cl::ImageFormat image_format(CL_RGBA, img_size.dtype);
*image = new (std::nothrow) cl::Image2D(*ocl_runtime_->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
image_format, img_size[0], img_size[1], 0, data, &ret);
image_format, img_size.width, img_size.height, 0, data, &ret);
}
if (*image == nullptr) {
delete *buffer;
@@ -116,10 +116,10 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const std::vector<size_t> &img
MS_LOG(ERROR) << "Create OpenCL Image2D (ERROR CODE: " << mindspore::kernel::CLErrorCode(ret) << ")";
return nullptr;
}
MS_LOG(DEBUG) << "Malloc a new Image2D, width=" << img_size[0] << ", height=" << img_size[1];
MS_LOG(DEBUG) << "Malloc a new Image2D, width=" << img_size.width << ", height=" << img_size.height;
void *host_ptr = nullptr;
if (is_map) {
std::vector<size_t> region{img_size[0], img_size[1], 1};
std::vector<size_t> region{img_size.width, img_size.height, 1};
host_ptr = ocl_runtime_->MapBuffer(**image, true, CL_MAP_READ | CL_MAP_WRITE, region);
if (host_ptr == nullptr) {
delete *buffer;
@@ -136,22 +136,20 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const std::vector<size_t> &img
return host_ptr;
}

void *OpenCLAllocator::Malloc(size_t size) { return Malloc(size, std::vector<size_t>{}); }

void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t> &img_size, void *data) {
void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const ImageSize &img_size) {
auto svm_capabilities = ocl_runtime_->GetSVMCapabilities();
MS_ASSERT(img_size.size() == 0 || img_size.size() == 3);
if (!img_size.empty()) {
size_t dtype_size = img_size[2] == CL_FLOAT ? sizeof(cl_float4) : sizeof(cl_half4);
if (mem_type == MemType::IMG) {
size_t dtype_size = img_size.dtype == CL_FLOAT ? sizeof(cl_float4) : sizeof(cl_half4);
uint32_t image_alignment = ocl_runtime_->GetImagePitchAlignment();
size = UP_ROUND(img_size[0], image_alignment) * img_size[1] * dtype_size;
size = UP_ROUND(img_size.width, image_alignment) * img_size.height * dtype_size;
}
if (size > ocl_runtime_->GetMaxAllocSize()) {
MS_LOG(ERROR) << "MallocData out of max_size, size: " << size;
return nullptr;
}
Lock();
void *host_ptr = MinimumFit(size, img_size);
void *host_ptr = MinimumFit(mem_type, size, img_size);
if (host_ptr != nullptr && data == nullptr) {
UnLock();
return host_ptr;
@@ -172,14 +170,14 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t> &img_size,
host_ptr = clSVMAlloc((*ocl_runtime_->Context())(), flags, size, 0);
} else {
flags |= (data == nullptr) ? CL_MEM_ALLOC_HOST_PTR : CL_MEM_COPY_HOST_PTR;
if (img_size.empty() || data == nullptr) {
if (mem_type == MemType::BUF || data == nullptr) {
host_ptr = CreateBuffer(size, data, flags, &buffer);
if (host_ptr == nullptr) {
UnLock();
return nullptr;
}
}
if (!img_size.empty()) {
if (mem_type == MemType::IMG) {
void *host_ptr_im = CreateImage2D(size, img_size, data, flags, data != nullptr, &buffer, &image);
if (data != nullptr && host_ptr_im == nullptr) {
UnLock();
@@ -199,10 +197,11 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t> &img_size,
mem_buf->device_ptr_ = static_cast<void *>(buffer);
mem_buf->host_ptr_ = host_ptr;
mem_buf->image_ptr_ = static_cast<void *>(image);
mem_buf->img_size = img_size;
mem_buf->mem_type_ = mem_type;
mem_buf->img_size_ = img_size;
allocated_list_[host_ptr] = mem_buf;
UnLock();
std::string type_name = img_size.empty() ? "buffer" : "Image2D";
std::string type_name = mem_type == MemType::BUF ? "buffer" : "Image2D";
MS_LOG(DEBUG) << "Malloc a new " << type_name << ". size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_
<< ", device addr: " << mem_buf->device_ptr_ << ", image_addr: " << image
<< ", total size: " << total_size_;
@@ -216,12 +215,12 @@ void OpenCLAllocator::Free(void *buf) {
Lock();
auto iter = allocated_list_.find(buf);
if (iter != allocated_list_.end()) {
if (iter->second->map_flags) {
if (iter->second->map_flags_) {
int ret = UnmapBuffer(buf);
if (ret != RET_OK) {
MS_LOG(WARNING) << "UnmapBuffer failed.";
}
iter->second->map_flags = false;
iter->second->map_flags_ = false;
}
auto mem_buf = iter->second;
allocated_list_.erase(iter);
@@ -271,7 +270,7 @@ template <typename T>
void OpenCLAllocator::ClearMemList(T *list) {
auto svm_capabilities = ocl_runtime_->GetSVMCapabilities();
for (auto it = list->begin(); it != list->end(); it++) {
if (it->second->map_flags) {
if (it->second->map_flags_) {
int ret = UnmapBuffer(it->second->host_ptr_);
if (ret != RET_OK) {
MS_LOG(WARNING) << "UnmapBuffer failed.";
@@ -330,7 +329,7 @@ void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue,
return nullptr;
}

if (it->second->map_flags) {
if (it->second->map_flags_) {
UnLock();
MS_LOG(WARNING) << "Host ptr " << host_ptr << " has mapped";
return host_ptr;
@@ -338,12 +337,12 @@ void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue,
MemBuf *mem_buf = it->second;
MS_ASSERT(mem_buf);
void *new_host_ptr{nullptr};
if (mem_buf->img_size.empty()) {
if (mem_buf->mem_type_ == MemType::BUF) {
cl::Buffer *buffer = static_cast<cl::Buffer *>(mem_buf->device_ptr_);
MS_ASSERT(buffer);
new_host_ptr = ocl_runtime_->MapBuffer(*buffer, flags, mem_buf->size_, nullptr, sync);
} else {
std::vector<size_t> region{mem_buf->img_size[0], mem_buf->img_size[1], 1};
std::vector<size_t> region{mem_buf->img_size_.width, mem_buf->img_size_.height, 1};
cl::Image2D *image = static_cast<cl::Image2D *>(mem_buf->image_ptr_);
MS_ASSERT(image);
new_host_ptr = ocl_runtime_->MapBuffer(*image, sync, CL_MAP_READ | CL_MAP_WRITE, region);
@@ -355,7 +354,7 @@ void *OpenCLAllocator::MapBuffer(void *host_ptr, int flags, void *command_queue,
return nullptr;
}

mem_buf->map_flags = true;
mem_buf->map_flags_ = true;
mem_buf->host_ptr_ = new_host_ptr;
allocated_list_.erase(it);
allocated_list_[new_host_ptr] = mem_buf;
@@ -377,10 +376,10 @@ int OpenCLAllocator::UnmapBuffer(void *host_ptr, void *command_queue) {
MS_LOG(ERROR) << "Map buffer failed, can not found buffer :" << host_ptr;
return RET_ERROR;
}
if (it->second->map_flags) {
it->second->map_flags = false;
cl::Memory *mem =
static_cast<cl::Memory *>(it->second->img_size.empty() ? it->second->device_ptr_ : it->second->image_ptr_);
if (it->second->map_flags_) {
it->second->map_flags_ = false;
cl::Memory *mem = static_cast<cl::Memory *>(it->second->mem_type_ == MemType::BUF ? it->second->device_ptr_
: it->second->image_ptr_);
return ocl_runtime_->UnmapBuffer(*mem, it->second->host_ptr_, static_cast<cl::CommandQueue *>(command_queue));
} else {
MS_LOG(WARNING) << "Host ptr " << host_ptr << " do not mapped";
@@ -399,16 +398,12 @@ MemType OpenCLAllocator::GetMemType(void *host_ptr) {
}
MemBuf *mem_buf = it->second;
MS_ASSERT(mem_buf);
if (mem_buf->img_size.empty()) {
mem_type = MemType::BUF;
} else {
mem_type = MemType::IMG;
}
mem_type = mem_buf->mem_type_;
UnLock();
return mem_type;
}

int OpenCLAllocator::GetImageSize(void *host_ptr, std::vector<size_t> *img_size) {
int OpenCLAllocator::GetImageSize(void *host_ptr, ImageSize *img_size) {
MS_ASSERT(img_size);
Lock();
auto it = allocated_list_.find(host_ptr);
@@ -419,8 +414,8 @@ int OpenCLAllocator::GetImageSize(void *host_ptr, std::vector<size_t> *img_size)
}
MemBuf *mem_buf = it->second;
MS_ASSERT(mem_buf);
if (!mem_buf->img_size.empty()) {
*img_size = mem_buf->img_size;
if (mem_buf->mem_type_ == MemType::IMG) {
*img_size = mem_buf->img_size_;
}
UnLock();
return RET_OK;


+ 24
- 11
mindspore/lite/src/runtime/opencl/opencl_allocator.h View File

@@ -31,14 +31,25 @@ namespace mindspore::lite::opencl {

class OpenCLRuntime;
enum class MemType : char { BUF, IMG };
struct ImageSize {
size_t width = 0;
size_t height = 0;
size_t dtype = CL_FLOAT;
bool operator==(const struct ImageSize &other) const {
return width == other.width && height == other.height && dtype == other.dtype;
}
};

class OpenCLAllocator : public Allocator {
public:
explicit OpenCLAllocator(OpenCLRuntime *ocl_runtime);
~OpenCLAllocator() override;
void SetContext(const AllocatorContext &ctx) override;
void *Malloc(size_t size) override;
void *Malloc(size_t size, const std::vector<size_t> &img_size, void *data = nullptr);
// malloc buffer
void *Malloc(size_t size) override { return _Malloc(MemType::BUF, nullptr, size); }
void *Malloc(size_t size, void *data) { return _Malloc(MemType::BUF, data, size); }
// malloc image
void *Malloc(const ImageSize &img_size, void *data = nullptr) { return _Malloc(MemType::IMG, data, 0, img_size); }
void Free(void *ptr) override;
size_t total_size() override;

@@ -48,7 +59,7 @@ class OpenCLAllocator : public Allocator {
void *MapBuffer(void *host_ptr, int flags, void *command_queue = nullptr, bool sync = true);
int UnmapBuffer(void *host_ptr, void *command_queue = nullptr);
MemType GetMemType(void *host_ptr);
int GetImageSize(void *host_ptr, std::vector<size_t> *img_size);
int GetImageSize(void *host_ptr, ImageSize *img_size);
void *Prepare(void *ptr) override {
if (ptr != nullptr) {
ptr = MapBuffer(ptr, CL_MAP_READ | CL_MAP_WRITE, nullptr, true);
@@ -59,9 +70,10 @@ class OpenCLAllocator : public Allocator {
private:
void Lock();
void UnLock();
void *MinimumFit(size_t size, const std::vector<size_t> &img_size);
void *MinimumFit(MemType mem_type, size_t size, const ImageSize &img_size);
void *_Malloc(MemType mem_type, void *data, size_t size = 0, const ImageSize &img_size = ImageSize());
void *CreateBuffer(size_t size, void *data, size_t flags, cl::Buffer **buffer);
void *CreateImage2D(size_t size, const std::vector<size_t> &img_size, void *data, size_t flags, bool is_map,
void *CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map,
cl::Buffer **buffer, cl::Image2D **image);
template <typename T>
void ClearMemList(T *list);
@@ -70,12 +82,13 @@ class OpenCLAllocator : public Allocator {
OpenCLRuntime *ocl_runtime_{nullptr};
std::mutex lock;
struct MemBuf {
size_t size_;
void *device_ptr_;
void *host_ptr_;
void *image_ptr_;
std::vector<size_t> img_size;
bool map_flags{false};
size_t size_{0};
void *device_ptr_{nullptr};
void *host_ptr_{nullptr};
void *image_ptr_{nullptr};
MemType mem_type_{MemType::BUF};
ImageSize img_size_;
bool map_flags_{false};
};

// <membuf->buf, membuf>


+ 2
- 2
mindspore/lite/src/runtime/opencl/opencl_runtime.cc View File

@@ -497,14 +497,14 @@ int OpenCLRuntime::ReadOrWriteImage(void *buffer, void *data, bool is_read) {
MS_LOG(WARNING) << "Can't get Image2D for " << buffer;
return RET_ERROR;
}
std::vector<size_t> img_size;
ImageSize img_size;
int ret = allocator_->GetImageSize(buffer, &img_size);
if (ret != RET_OK) {
MS_LOG(WARNING) << "Can't get GetImageSize for " << buffer;
return RET_ERROR;
}
cl::array<size_t, 3> origin = {0, 0, 0};
cl::array<size_t, 3> region = {img_size[0], img_size[1], 1};
cl::array<size_t, 3> region = {img_size.width, img_size.height, 1};
if (is_read) {
ret = command_queue->enqueueReadImage(*image, true, origin, region, 0, 0, data, nullptr, nullptr);
} else {


Loading…
Cancel
Save