
solve some problems for power ops

tags/v1.2.0-rc1
Pengyongrong committed 5 years ago · commit cf2b868892
9 changed files with 60 additions and 108 deletions
  1. +5  -13  mindspore/lite/src/runtime/kernel/opencl/cl/power.cl
  2. +8  -8   mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc
  3. +9  -9   mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
  4. +5  -5   mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc
  5. +19 -19  mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc
  6. +4  -4   mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc
  7. +2  -38  mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc
  8. +0  -4   mindspore/lite/src/runtime/kernel/opencl/kernel/power.h
  9. +8  -8   mindspore/lite/test/ut/src/runtime/kernel/opencl/power_tests.cc

+5 -13  mindspore/lite/src/runtime/kernel/opencl/cl/power.cl

@@ -24,27 +24,19 @@ FLT OptimizedPowerImpl(FLT x, int exponent) {
   return exponent >= 0 ? result : 1 / result;
 }
 
-__kernel void power(__read_only image2d_t input0, __global FLT *input1, __write_only image2d_t output,
+__kernel void power(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,
                     int4 output_shape, FLT4 parameter) {
   CHECK_IDX;
   int n = X / output_shape.y;
   int h = X % output_shape.y;
-  int unalign_w = (int)parameter.w;
   FLT4 result;
   FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
-  int index_weight = (n * output_shape.y + h) * output_shape.z * unalign_w + Y * unalign_w + Z * C4NUM;
+  FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));
+
   FLT tmp_result[4];
   FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};
-  FLT tmp_result1[4] = {0.0f, 0.0f, 0.0f, 0.0f};
-  if ((Z + 1) * C4NUM <= unalign_w) {
-    for (int i = 0; i < C4NUM; ++i) {
-      tmp_result1[i] = input1[index_weight + i];
-    }
-  } else {
-    for (int i = 0; i < unalign_w % C4NUM; ++i) {
-      tmp_result1[i] = input1[index_weight + i];
-    }
-  }
+  FLT tmp_result1[4] = {result1.x, result1.y, result1.z, result1.w};
+
   for (int i = 0; i < 4; ++i) {
     tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;
     if (floor(tmp_result1[i]) == tmp_result1[i]) {


+8 -8  mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc

@@ -65,6 +65,14 @@ int Conv2DOpenCLKernel::CheckSpecs() {
     MS_LOG(ERROR) << "Conv2D only supports 4D output Tensor but get " << out_tensors_.front()->shape().size() << "D.";
     return RET_ERROR;
   }
+  if (!in_tensors_.at(1)->IsConst()) {
+    MS_LOG(ERROR) << "Conv2D don't support non-constant filter yet.";
+    return RET_ERROR;
+  }
+  if (in_tensors_.size() == 3 && !in_tensors_.at(2)->IsConst()) {
+    MS_LOG(ERROR) << "Conv2D don't support non-constant bias yet.";
+    return RET_ERROR;
+  }
   // for fusion: ActivationType_LEAKY_RELU ActivationType_TANH
   switch (static_cast<int>(param_->act_type_)) {
     case ActType_No:

@@ -302,16 +310,8 @@ int Conv2DOpenCLKernel::InitBias() {
 }
 
 int Conv2DOpenCLKernel::InitWeights() {
-  if (!in_tensors_.at(1)->IsConst()) {
-    MS_LOG(ERROR) << "Conv2D don't support non-constant filter yet.";
-    return RET_ERROR;
-  }
   InitFilter();
   if (has_bias_) {
-    if (!in_tensors_.at(2)->IsConst()) {
-      MS_LOG(ERROR) << "Conv2D don't support non-constant bias yet.";
-      return RET_ERROR;
-    }
     InitBias();
   }
   return RET_OK;
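
The same refactor repeats in conv2d_transpose.cc, fullconnection.cc, and matmul.cc below: the constant-tensor requirements move out of the weight/bias initialization paths and into CheckSpecs(), so unsupported graphs are rejected before any weight buffers are allocated. A standalone sketch of that pattern follows; all names here are illustrative, not taken from the source.

#include <cstdio>
#include <vector>

// Toy tensor: only the property the checks care about.
struct Tensor {
  bool is_const;
  bool IsConst() const { return is_const; }
};

struct ExampleKernel {
  std::vector<Tensor> in_tensors_;

  // Runs when the kernel is selected: fail fast on unsupported inputs.
  int CheckSpecs() const {
    if (!in_tensors_.at(1).IsConst()) {
      std::fprintf(stderr, "non-constant filter is not supported yet\n");
      return -1;
    }
    if (in_tensors_.size() == 3 && !in_tensors_.at(2).IsConst()) {
      std::fprintf(stderr, "non-constant bias is not supported yet\n");
      return -1;
    }
    return 0;
  }

  // Runs later: may now assume filter/bias are constant, no error paths needed.
  int InitWeights() const { return 0; }
};

int main() {
  ExampleKernel k{{{false}, {true}, {true}}};  // activation, const filter, const bias
  return k.CheckSpecs() == 0 ? k.InitWeights() : 1;
}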


+9 -9  mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc

@@ -49,6 +49,14 @@ int Conv2dTransposeOpenCLKernel::CheckSpecs() {
     MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_;
     return RET_ERROR;
   }
+  if (!in_tensors_.at(1)->IsConst()) {
+    MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant filter yet.";
+    return RET_ERROR;
+  }
+  if (in_tensors_.size() == 3 && !in_tensors_.at(2)->IsConst()) {
+    MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant bias yet.";
+    return RET_ERROR;
+  }
   return RET_OK;
 }
 

@@ -117,10 +125,6 @@ void Conv2dTransposeOpenCLKernel::SetConstArgs() {
 }
 
 int Conv2dTransposeOpenCLKernel::InitWeights() {
-  if (!in_tensors_.at(1)->IsConst()) {
-    MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant filter yet.";
-    return RET_ERROR;
-  }
   ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_);
   int ci = in_tensors_[0]->shape()[3];
   int co = out_tensors_[0]->shape()[3];

@@ -189,11 +193,7 @@ int Conv2dTransposeOpenCLKernel::InitWeights() {
   bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size);
   bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
   memset(bias_, 0x00, div_co * C4NUM * data_size);
-  if (in_tensors_.size() >= 3) {
-    if (!in_tensors_.at(2)->IsConst()) {
-      MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant bias yet.";
-      return RET_ERROR;
-    }
+  if (in_tensors_.size() == 3) {
     auto bias_dtype = in_tensors_[2]->data_type();
     if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) {
       for (int i = 0; i < co; i++) {


+5 -5  mindspore/lite/src/runtime/kernel/opencl/kernel/fullconnection.cc

@@ -92,6 +92,10 @@ int FullConnectionOpenCLKernel::CheckSpecs() {
       return RET_ERROR;
     }
   }
+  if (in_tensors_.size() == 3 && !in_tensors_.at(2)->IsConst()) {
+    MS_LOG(ERROR) << "FullConnection don't support non-constant bias yet.";
+    return RET_ERROR;
+  }
   CI_remainder_ = input_nhw / N_;
   return RET_OK;
 }

@@ -211,11 +215,7 @@ int FullConnectionOpenCLKernel::InitBias() {
   bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * dtype_size, img_size);
   bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true);
   memset(bias_, 0x00, co4 * C4NUM * dtype_size);
-  if (in_tensors_.size() >= 3) {
-    if (!in_tensors_.at(2)->IsConst()) {
-      MS_LOG(ERROR) << "FullConnection don't support non-constant bias yet.";
-      return RET_ERROR;
-    }
+  if (in_tensors_.size() == 3) {
     if (in_tensors_[2]->data_type() == kNumberTypeFloat32 && enable_fp16_) {
       for (int i = 0; i < CO_; i++) {
         reinterpret_cast<float16_t *>(bias_)[i] = reinterpret_cast<float *>(in_tensors_[2]->data_c())[i];


+19 -19  mindspore/lite/src/runtime/kernel/opencl/kernel/layer_norm.cc

@@ -33,24 +33,24 @@ namespace mindspore::kernel {
 
 int LayerNormOpenCLKernel::CheckSpecs() {
   auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_);
-  if (param->elementwise_mode_ == ELEMENTWISE_PER_CHANNEL) {
-    if (in_tensors_.size() != 3) {
-      MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
-      return RET_ERROR;
-    }
-    if (param->normalized_dims_ > in_tensors_.at(0)->shape().size()) {
-      MS_LOG(ERROR) << " invalid normalized_shape_ size" << param->normalized_dims_ << std::endl;
-      return RET_ERROR;
-    }
-  } else if (param->elementwise_mode_ == ELEMENTWISE_NOT) {
-    if (in_tensors_.size() != 1) {
-      MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
-      return RET_ERROR;
-    }
-  } else {
-    MS_LOG(ERROR) << "Unsupported elementwise_mode_" << param->elementwise_mode_;
-    return RET_ERROR;
-  }
+  // if (param->elementwise_mode_ == ELEMENTWISE_PER_CHANNEL) {
+  //   if (in_tensors_.size() != 3) {
+  //     MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
+  //     return RET_ERROR;
+  //   }
+  //   if (param->normalized_dims_ > in_tensors_.at(0)->shape().size()) {
+  //     MS_LOG(ERROR) << " invalid normalized_shape_ size" << param->normalized_dims_ << std::endl;
+  //     return RET_ERROR;
+  //   }
+  // } else if (param->elementwise_mode_ == ELEMENTWISE_NOT) {
+  //   if (in_tensors_.size() != 1) {
+  //     MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
+  //     return RET_ERROR;
+  //   }
+  // } else {
+  //   MS_LOG(ERROR) << "Unsupported elementwise_mode_" << param->elementwise_mode_;
+  //   return RET_ERROR;
+  // }
   if (in_tensors_.at(0)->shape().size() != 4 || out_tensors_.size() != 1) {
     MS_LOG(ERROR) << "UnSupported in_tensors_.shape.size: " << in_tensors_.at(0)->shape().size()
                   << " out_tensors_.size(): " << out_tensors_.size();

@@ -184,7 +184,7 @@ int LayerNormOpenCLKernel::Initweight() {
 int LayerNormOpenCLKernel::Prepare() {
   use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
   auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_);
-  elementwise_affine_ = param->elementwise_mode_;
+  elementwise_affine_ = true; // param->elementwise_mode_;
   normalized_dims_ = param->normalized_dims_;
   epsilon_ = param->epsilon_;
   if (elementwise_affine_) {


+4 -4  mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc

@@ -48,6 +48,10 @@ int MatMulOpenCLKernel::CheckSpecs() {
     MS_LOG(ERROR) << "matmul only support input shape size= 2, 3 or 4.";
     return mindspore::lite::RET_ERROR;
   }
+  if (!in_tensors_.at(kWeightIndex)->IsConst()) {
+    MS_LOG(ERROR) << "Matmul don't support non-constant filter yet.";
+    return RET_ERROR;
+  }
   return RET_OK;
 }
 

@@ -80,10 +84,6 @@ int MatMulOpenCLKernel::Prepare() {
 
 int MatMulOpenCLKernel::InitWeights() {
   // ABMCI @ ABCICO = ABMCO
-  if (!in_tensors_.at(kWeightIndex)->IsConst()) {
-    MS_LOG(ERROR) << "Matmul don't support non-constant filter yet.";
-    return RET_ERROR;
-  }
   auto ret = DequantWeight();
   if (ret != RET_OK) {
     return ret;


+2 -38  mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc

@@ -48,40 +48,6 @@ int PowerOpenCLKernel::CheckSpecs() {
   return RET_OK;
 }
 
-int PowerOpenCLKernel::Initweight() {
-  auto allocator = ocl_runtime_->GetAllocator();
-  GpuTensorInfo img_info(in_tensors_.at(1));
-  auto weight_tensor = in_tensors_.at(1);
-  size_t weight_size = img_info.OriginSize;
-  weight_ = allocator->Malloc(weight_size);
-  allocator->MapBuffer(weight_, CL_MAP_WRITE, nullptr, true);
-  memset(weight_, 0x00, weight_size);
-
-  if (weight_tensor->data_type() == kNumberTypeFloat16) {
-    if (use_fp16_enable_) {
-      memcpy(weight_, weight_tensor->data_c(), weight_size);
-    } else {
-      auto weight_fp32 = reinterpret_cast<float *>(weight_);
-      auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c());
-      for (int i = 0; i < img_info.ElementsNum; ++i) {
-        weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]);
-      }
-    }
-  } else {
-    if (use_fp16_enable_) {
-      auto weight_fp16 = reinterpret_cast<float16_t *>(weight_);
-      auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c());
-      for (int i = 0; i < img_info.ElementsNum; ++i) {
-        weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]);
-      }
-    } else {
-      memcpy(weight_, weight_tensor->data_c(), weight_size);
-    }
-  }
-  allocator->UnmapBuffer(weight_);
-  return RET_OK;
-}
-
 void PowerGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
   const int max_divider = 8;
   const int max_x = 2, max_y = 8;

@@ -145,11 +111,9 @@ int PowerOpenCLKernel::Prepare() {
   std::string kernel_name = "power";
   std::string source = power_source;
   std::string program_name = "power";
-  if (broadcast_ && in_tensors_.size() == 1) {
+  if (broadcast_) {
     power_ = param->power_;
     kernel_name += "_broadcast";
-  } else {
-    Initweight();
   }
   scale_ = param->scale_;
   shift_ = param->shift_;

@@ -168,7 +132,7 @@ int PowerOpenCLKernel::Run() {
     ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c());
   } else {
     ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c());
-    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_, lite::opencl::MemType::BUF);
+    ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(1)->data_c());
   }
   ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c());
   ocl_runtime_->RunKernel(kernel_, global_range_, local_range_);
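
With Initweight() and the weight_ buffer gone, the exponent tensor is bound at Run() time like any other input. As a hedged recap of the hunks above, the non-broadcast path now lines up with the new kernel signature roughly as follows; comments are added for clarity, and the SetConstArgs() handling of output_shape and parameter is not part of this diff and is assumed unchanged.

// Non-broadcast path (second operand is a runtime tensor):
int arg_cn = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c());   // input0: image2d_t
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(1)->data_c());   // input1: exponent, now an image2d_t (was a BUF-mapped weight_ copy)
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c());  // output: image2d_t
// int4 output_shape and FLT4 parameter (shift in .y, scale in .z) are assumed
// to be set elsewhere (SetConstArgs), unchanged by this commit.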


+0 -4  mindspore/lite/src/runtime/kernel/opencl/kernel/power.h

@@ -37,14 +37,10 @@ class PowerOpenCLKernel : public OpenCLKernel {
   void SetGlobalLocal() override;
   int Run() override;
 
- private:
-  int Initweight();
-
  private:
   cl_int4 out_shape_{};
   bool broadcast_{false};
   bool use_fp16_enable_{false};
-  void *weight_{nullptr};
   float power_{1.0};
   float scale_{0.0};
   float shift_{1.0};


+8 -8  mindspore/lite/test/ut/src/runtime/kernel/opencl/power_tests.cc

@@ -48,8 +48,8 @@ TEST_F(TestPowerOpenCLCI, Int32CI) {
                                 100.0, 121.0, 1728.0, 1.0, 196.0, 225.0, 16.0, 289.0};
   for (auto fp16_enable : {false, true}) {
     auto *param = CreateParameter(broadcast_, shift_, scale_);
-    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
-             param, fp16_enable, fp16_enable ? 1e-3 : 1e-9);
+    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param,
+             fp16_enable, fp16_enable ? 1e-3 : 1e-9);
   }
 }
 

@@ -68,8 +68,8 @@ TEST_F(TestPowerOpenCLCI, Fp32CI) {
                                 3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001, 0.0542811};
   for (auto fp16_enable : {false, true}) {
     auto *param = CreateParameter(broadcast_, shift_, scale_);
-    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
-             param, fp16_enable, fp16_enable ? 1e-2 : 1e-6);
+    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param,
+             fp16_enable, fp16_enable ? 1e-2 : 1e-6);
   }
 }
 

@@ -87,8 +87,8 @@ TEST_F(TestPowerOpenCLCI, Fp32UnAlign) {
                                 3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001};
   for (auto fp16_enable : {false, true}) {
     auto *param = CreateParameter(broadcast_, shift_, scale_);
-    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
-             param, fp16_enable, fp16_enable ? 1e-2 : 1e-6);
+    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param,
+             fp16_enable, fp16_enable ? 1e-2 : 1e-6);
   }
 }
 

@@ -121,8 +121,8 @@ TEST_F(TestPowerOpenCLCI, Fp16CI) {
                                 0.4856, 1.014, 0.2025, -1.736, 0.2134, 0.489, -0.596, 0.7466};
   for (auto fp16_enable : {true}) {
     auto *param = CreateParameter(broadcast_, shift_, scale_);
-    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data},
-             param, fp16_enable, fp16_enable ? 1e-3 : 1e-6);
+    TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param,
+             fp16_enable, fp16_enable ? 1e-3 : 1e-6);
   }
 }
 



