| @@ -24,27 +24,19 @@ FLT OptimizedPowerImpl(FLT x, int exponent) { | |||||
| return exponent >= 0 ? result : 1 / result; | return exponent >= 0 ? result : 1 / result; | ||||
| } | } | ||||
| __kernel void power(__read_only image2d_t input0, __global FLT *input1, __write_only image2d_t output, | |||||
| __kernel void power(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output, | |||||
| int4 output_shape, FLT4 parameter) { | int4 output_shape, FLT4 parameter) { | ||||
| CHECK_IDX; | CHECK_IDX; | ||||
| int n = X / output_shape.y; | int n = X / output_shape.y; | ||||
| int h = X % output_shape.y; | int h = X % output_shape.y; | ||||
| int unalign_w = (int)parameter.w; | |||||
| FLT4 result; | FLT4 result; | ||||
| FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h))); | FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h))); | ||||
| int index_weight = (n * output_shape.y + h) * output_shape.z * unalign_w + Y * unalign_w + Z * C4NUM; | |||||
| FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h))); | |||||
| FLT tmp_result[4]; | FLT tmp_result[4]; | ||||
| FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w}; | FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w}; | ||||
| FLT tmp_result1[4] = {0.0f, 0.0f, 0.0f, 0.0f}; | |||||
| if ((Z + 1) * C4NUM <= unalign_w) { | |||||
| for (int i = 0; i < C4NUM; ++i) { | |||||
| tmp_result1[i] = input1[index_weight + i]; | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < unalign_w % C4NUM; ++i) { | |||||
| tmp_result1[i] = input1[index_weight + i]; | |||||
| } | |||||
| } | |||||
| FLT tmp_result1[4] = {result1.x, result1.y, result1.z, result1.w}; | |||||
| for (int i = 0; i < 4; ++i) { | for (int i = 0; i < 4; ++i) { | ||||
| tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y; | tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y; | ||||
| if (floor(tmp_result1[i]) == tmp_result1[i]) { | if (floor(tmp_result1[i]) == tmp_result1[i]) { | ||||
| @@ -65,6 +65,14 @@ int Conv2DOpenCLKernel::CheckSpecs() { | |||||
| MS_LOG(ERROR) << "Conv2D only supports 4D output Tensor but get " << out_tensors_.front()->shape().size() << "D."; | MS_LOG(ERROR) << "Conv2D only supports 4D output Tensor but get " << out_tensors_.front()->shape().size() << "D."; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (!in_tensors_.at(1)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2D don't support non-constant filter yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (in_tensors_.size() == 3 && !in_tensors_.at(2)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2D don't support non-constant bias yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // for fusion: ActivationType_LEAKY_RELU ActivationType_TANH | // for fusion: ActivationType_LEAKY_RELU ActivationType_TANH | ||||
| switch (static_cast<int>(param_->act_type_)) { | switch (static_cast<int>(param_->act_type_)) { | ||||
| case ActType_No: | case ActType_No: | ||||
| @@ -302,16 +310,8 @@ int Conv2DOpenCLKernel::InitBias() { | |||||
| } | } | ||||
| int Conv2DOpenCLKernel::InitWeights() { | int Conv2DOpenCLKernel::InitWeights() { | ||||
| if (!in_tensors_.at(1)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2D don't support non-constant filter yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| InitFilter(); | InitFilter(); | ||||
| if (has_bias_) { | if (has_bias_) { | ||||
| if (!in_tensors_.at(2)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2D don't support non-constant bias yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| InitBias(); | InitBias(); | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -49,6 +49,14 @@ int Conv2dTransposeOpenCLKernel::CheckSpecs() { | |||||
| MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_; | MS_LOG(ERROR) << "Unsupported activation type " << param->act_type_; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (!in_tensors_.at(1)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant filter yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (in_tensors_.size() == 3 && !in_tensors_.at(2)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant bias yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -117,10 +125,6 @@ void Conv2dTransposeOpenCLKernel::SetConstArgs() { | |||||
| } | } | ||||
| int Conv2dTransposeOpenCLKernel::InitWeights() { | int Conv2dTransposeOpenCLKernel::InitWeights() { | ||||
| if (!in_tensors_.at(1)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant filter yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_); | ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameter_); | ||||
| int ci = in_tensors_[0]->shape()[3]; | int ci = in_tensors_[0]->shape()[3]; | ||||
| int co = out_tensors_[0]->shape()[3]; | int co = out_tensors_[0]->shape()[3]; | ||||
| @@ -189,11 +193,7 @@ int Conv2dTransposeOpenCLKernel::InitWeights() { | |||||
| bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size); | bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * data_size, img_size); | ||||
| bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true); | bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true); | ||||
| memset(bias_, 0x00, div_co * C4NUM * data_size); | memset(bias_, 0x00, div_co * C4NUM * data_size); | ||||
| if (in_tensors_.size() >= 3) { | |||||
| if (!in_tensors_.at(2)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Conv2dTranspose don't support non-constant bias yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (in_tensors_.size() == 3) { | |||||
| auto bias_dtype = in_tensors_[2]->data_type(); | auto bias_dtype = in_tensors_[2]->data_type(); | ||||
| if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) { | if (bias_dtype == kNumberTypeFloat32 && enable_fp16_) { | ||||
| for (int i = 0; i < co; i++) { | for (int i = 0; i < co; i++) { | ||||
| @@ -92,6 +92,10 @@ int FullConnectionOpenCLKernel::CheckSpecs() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| if (in_tensors_.size() == 3 && !in_tensors_.at(2)->IsConst()) { | |||||
| MS_LOG(ERROR) << "FullConnection don't support non-constant bias yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| CI_remainder_ = input_nhw / N_; | CI_remainder_ = input_nhw / N_; | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -211,11 +215,7 @@ int FullConnectionOpenCLKernel::InitBias() { | |||||
| bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * dtype_size, img_size); | bias_ = allocator->Malloc(im_dst_x * im_dst_y * C4NUM * dtype_size, img_size); | ||||
| bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true); | bias_ = allocator->MapBuffer(bias_, CL_MAP_WRITE, nullptr, true); | ||||
| memset(bias_, 0x00, co4 * C4NUM * dtype_size); | memset(bias_, 0x00, co4 * C4NUM * dtype_size); | ||||
| if (in_tensors_.size() >= 3) { | |||||
| if (!in_tensors_.at(2)->IsConst()) { | |||||
| MS_LOG(ERROR) << "FullConnection don't support non-constant bias yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (in_tensors_.size() == 3) { | |||||
| if (in_tensors_[2]->data_type() == kNumberTypeFloat32 && enable_fp16_) { | if (in_tensors_[2]->data_type() == kNumberTypeFloat32 && enable_fp16_) { | ||||
| for (int i = 0; i < CO_; i++) { | for (int i = 0; i < CO_; i++) { | ||||
| reinterpret_cast<float16_t *>(bias_)[i] = reinterpret_cast<float *>(in_tensors_[2]->data_c())[i]; | reinterpret_cast<float16_t *>(bias_)[i] = reinterpret_cast<float *>(in_tensors_[2]->data_c())[i]; | ||||
| @@ -33,24 +33,24 @@ namespace mindspore::kernel { | |||||
| int LayerNormOpenCLKernel::CheckSpecs() { | int LayerNormOpenCLKernel::CheckSpecs() { | ||||
| auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_); | auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_); | ||||
| if (param->elementwise_mode_ == ELEMENTWISE_PER_CHANNEL) { | |||||
| if (in_tensors_.size() != 3) { | |||||
| MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (param->normalized_dims_ > in_tensors_.at(0)->shape().size()) { | |||||
| MS_LOG(ERROR) << " invalid normalized_shape_ size" << param->normalized_dims_ << std::endl; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (param->elementwise_mode_ == ELEMENTWISE_NOT) { | |||||
| if (in_tensors_.size() != 1) { | |||||
| MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported elementwise_mode_" << param->elementwise_mode_; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // if (param->elementwise_mode_ == ELEMENTWISE_PER_CHANNEL) { | |||||
| // if (in_tensors_.size() != 3) { | |||||
| // MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl; | |||||
| // return RET_ERROR; | |||||
| // } | |||||
| // if (param->normalized_dims_ > in_tensors_.at(0)->shape().size()) { | |||||
| // MS_LOG(ERROR) << " invalid normalized_shape_ size" << param->normalized_dims_ << std::endl; | |||||
| // return RET_ERROR; | |||||
| // } | |||||
| // } else if (param->elementwise_mode_ == ELEMENTWISE_NOT) { | |||||
| // if (in_tensors_.size() != 1) { | |||||
| // MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl; | |||||
| // return RET_ERROR; | |||||
| // } | |||||
| // } else { | |||||
| // MS_LOG(ERROR) << "Unsupported elementwise_mode_" << param->elementwise_mode_; | |||||
| // return RET_ERROR; | |||||
| // } | |||||
| if (in_tensors_.at(0)->shape().size() != 4 || out_tensors_.size() != 1) { | if (in_tensors_.at(0)->shape().size() != 4 || out_tensors_.size() != 1) { | ||||
| MS_LOG(ERROR) << "UnSupported in_tensors_.shape.size: " << in_tensors_.at(0)->shape().size() | MS_LOG(ERROR) << "UnSupported in_tensors_.shape.size: " << in_tensors_.at(0)->shape().size() | ||||
| << " out_tensors_.size(): " << out_tensors_.size(); | << " out_tensors_.size(): " << out_tensors_.size(); | ||||
| @@ -184,7 +184,7 @@ int LayerNormOpenCLKernel::Initweight() { | |||||
| int LayerNormOpenCLKernel::Prepare() { | int LayerNormOpenCLKernel::Prepare() { | ||||
| use_fp16_enable_ = ocl_runtime_->GetFp16Enable(); | use_fp16_enable_ = ocl_runtime_->GetFp16Enable(); | ||||
| auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_); | auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_); | ||||
| elementwise_affine_ = param->elementwise_mode_; | |||||
| elementwise_affine_ = true; // param->elementwise_mode_; | |||||
| normalized_dims_ = param->normalized_dims_; | normalized_dims_ = param->normalized_dims_; | ||||
| epsilon_ = param->epsilon_; | epsilon_ = param->epsilon_; | ||||
| if (elementwise_affine_) { | if (elementwise_affine_) { | ||||
| @@ -48,6 +48,10 @@ int MatMulOpenCLKernel::CheckSpecs() { | |||||
| MS_LOG(ERROR) << "matmul only support input shape size= 2, 3 or 4."; | MS_LOG(ERROR) << "matmul only support input shape size= 2, 3 or 4."; | ||||
| return mindspore::lite::RET_ERROR; | return mindspore::lite::RET_ERROR; | ||||
| } | } | ||||
| if (!in_tensors_.at(kWeightIndex)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Matmul don't support non-constant filter yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -80,10 +84,6 @@ int MatMulOpenCLKernel::Prepare() { | |||||
| int MatMulOpenCLKernel::InitWeights() { | int MatMulOpenCLKernel::InitWeights() { | ||||
| // ABMCI @ ABCICO = ABMCO | // ABMCI @ ABCICO = ABMCO | ||||
| if (!in_tensors_.at(kWeightIndex)->IsConst()) { | |||||
| MS_LOG(ERROR) << "Matmul don't support non-constant filter yet."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto ret = DequantWeight(); | auto ret = DequantWeight(); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| return ret; | return ret; | ||||
| @@ -48,40 +48,6 @@ int PowerOpenCLKernel::CheckSpecs() { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int PowerOpenCLKernel::Initweight() { | |||||
| auto allocator = ocl_runtime_->GetAllocator(); | |||||
| GpuTensorInfo img_info(in_tensors_.at(1)); | |||||
| auto weight_tensor = in_tensors_.at(1); | |||||
| size_t weight_size = img_info.OriginSize; | |||||
| weight_ = allocator->Malloc(weight_size); | |||||
| allocator->MapBuffer(weight_, CL_MAP_WRITE, nullptr, true); | |||||
| memset(weight_, 0x00, weight_size); | |||||
| if (weight_tensor->data_type() == kNumberTypeFloat16) { | |||||
| if (use_fp16_enable_) { | |||||
| memcpy(weight_, weight_tensor->data_c(), weight_size); | |||||
| } else { | |||||
| auto weight_fp32 = reinterpret_cast<float *>(weight_); | |||||
| auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c()); | |||||
| for (int i = 0; i < img_info.ElementsNum; ++i) { | |||||
| weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| if (use_fp16_enable_) { | |||||
| auto weight_fp16 = reinterpret_cast<float16_t *>(weight_); | |||||
| auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c()); | |||||
| for (int i = 0; i < img_info.ElementsNum; ++i) { | |||||
| weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]); | |||||
| } | |||||
| } else { | |||||
| memcpy(weight_, weight_tensor->data_c(), weight_size); | |||||
| } | |||||
| } | |||||
| allocator->UnmapBuffer(weight_); | |||||
| return RET_OK; | |||||
| } | |||||
| void PowerGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) { | void PowerGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) { | ||||
| const int max_divider = 8; | const int max_divider = 8; | ||||
| const int max_x = 2, max_y = 8; | const int max_x = 2, max_y = 8; | ||||
| @@ -145,11 +111,9 @@ int PowerOpenCLKernel::Prepare() { | |||||
| std::string kernel_name = "power"; | std::string kernel_name = "power"; | ||||
| std::string source = power_source; | std::string source = power_source; | ||||
| std::string program_name = "power"; | std::string program_name = "power"; | ||||
| if (broadcast_ && in_tensors_.size() == 1) { | |||||
| if (broadcast_) { | |||||
| power_ = param->power_; | power_ = param->power_; | ||||
| kernel_name += "_broadcast"; | kernel_name += "_broadcast"; | ||||
| } else { | |||||
| Initweight(); | |||||
| } | } | ||||
| scale_ = param->scale_; | scale_ = param->scale_; | ||||
| shift_ = param->shift_; | shift_ = param->shift_; | ||||
| @@ -168,7 +132,7 @@ int PowerOpenCLKernel::Run() { | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c()); | ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c()); | ||||
| } else { | } else { | ||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c()); | ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c()); | ||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, weight_, lite::opencl::MemType::BUF); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(1)->data_c()); | |||||
| } | } | ||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c()); | ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c()); | ||||
| ocl_runtime_->RunKernel(kernel_, global_range_, local_range_); | ocl_runtime_->RunKernel(kernel_, global_range_, local_range_); | ||||
| @@ -37,14 +37,10 @@ class PowerOpenCLKernel : public OpenCLKernel { | |||||
| void SetGlobalLocal() override; | void SetGlobalLocal() override; | ||||
| int Run() override; | int Run() override; | ||||
| private: | |||||
| int Initweight(); | |||||
| private: | private: | ||||
| cl_int4 out_shape_{}; | cl_int4 out_shape_{}; | ||||
| bool broadcast_{false}; | bool broadcast_{false}; | ||||
| bool use_fp16_enable_{false}; | bool use_fp16_enable_{false}; | ||||
| void *weight_{nullptr}; | |||||
| float power_{1.0}; | float power_{1.0}; | ||||
| float scale_{0.0}; | float scale_{0.0}; | ||||
| float shift_{1.0}; | float shift_{1.0}; | ||||
| @@ -48,8 +48,8 @@ TEST_F(TestPowerOpenCLCI, Int32CI) { | |||||
| 100.0, 121.0, 1728.0, 1.0, 196.0, 225.0, 16.0, 289.0}; | 100.0, 121.0, 1728.0, 1.0, 196.0, 225.0, 16.0, 289.0}; | ||||
| for (auto fp16_enable : {false, true}) { | for (auto fp16_enable : {false, true}) { | ||||
| auto *param = CreateParameter(broadcast_, shift_, scale_); | auto *param = CreateParameter(broadcast_, shift_, scale_); | ||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data}, | |||||
| param, fp16_enable, fp16_enable ? 1e-3 : 1e-9); | |||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param, | |||||
| fp16_enable, fp16_enable ? 1e-3 : 1e-9); | |||||
| } | } | ||||
| } | } | ||||
| @@ -68,8 +68,8 @@ TEST_F(TestPowerOpenCLCI, Fp32CI) { | |||||
| 3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001, 0.0542811}; | 3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001, 0.0542811}; | ||||
| for (auto fp16_enable : {false, true}) { | for (auto fp16_enable : {false, true}) { | ||||
| auto *param = CreateParameter(broadcast_, shift_, scale_); | auto *param = CreateParameter(broadcast_, shift_, scale_); | ||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data}, | |||||
| param, fp16_enable, fp16_enable ? 1e-2 : 1e-6); | |||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param, | |||||
| fp16_enable, fp16_enable ? 1e-2 : 1e-6); | |||||
| } | } | ||||
| } | } | ||||
| @@ -87,8 +87,8 @@ TEST_F(TestPowerOpenCLCI, Fp32UnAlign) { | |||||
| 3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001}; | 3.20657016, 0.64395994, 0.01526405, 0.13275899, 5.85509388, 0.16177453, 0.07150001}; | ||||
| for (auto fp16_enable : {false, true}) { | for (auto fp16_enable : {false, true}) { | ||||
| auto *param = CreateParameter(broadcast_, shift_, scale_); | auto *param = CreateParameter(broadcast_, shift_, scale_); | ||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data}, | |||||
| param, fp16_enable, fp16_enable ? 1e-2 : 1e-6); | |||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param, | |||||
| fp16_enable, fp16_enable ? 1e-2 : 1e-6); | |||||
| } | } | ||||
| } | } | ||||
| @@ -121,8 +121,8 @@ TEST_F(TestPowerOpenCLCI, Fp16CI) { | |||||
| 0.4856, 1.014, 0.2025, -1.736, 0.2134, 0.489, -0.596, 0.7466}; | 0.4856, 1.014, 0.2025, -1.736, 0.2134, 0.489, -0.596, 0.7466}; | ||||
| for (auto fp16_enable : {true}) { | for (auto fp16_enable : {true}) { | ||||
| auto *param = CreateParameter(broadcast_, shift_, scale_); | auto *param = CreateParameter(broadcast_, shift_, scale_); | ||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data}, | |||||
| param, fp16_enable, fp16_enable ? 1e-3 : 1e-6); | |||||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, VAR}}, {output_shape, output_data}, param, | |||||
| fp16_enable, fp16_enable ? 1e-3 : 1e-6); | |||||
| } | } | ||||
| } | } | ||||