From: @pengyongrong Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -14,20 +14,35 @@ __kernel void SoftMaxAxis3_NHWC4(__read_only image2d_t input, __write_only image | |||||
| if (X >= H || Y >= W) return; | if (X >= H || Y >= W) return; | ||||
| // get max | |||||
| float4 last = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); | |||||
| float input_max = last.x; | |||||
| if (mask.y > 0.5f) input_max = max(input_max, last.y); | |||||
| if (mask.z > 0.5f) input_max = max(input_max, last.z); | |||||
| if (mask.w > 0.5f) input_max = max(input_max, last.w); | |||||
| for (int d = 0; d < C4 - 1; ++d) { | |||||
| float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); | |||||
| input_max = max(input_max, t.x); | |||||
| input_max = max(input_max, t.y); | |||||
| input_max = max(input_max, t.z); | |||||
| input_max = max(input_max, t.w); | |||||
| } | |||||
| float4 input_max_f4 = (float4)(input_max, input_max, input_max, input_max); | |||||
| float sum = 0.0f; | float sum = 0.0f; | ||||
| for (int d = 0; d < C4 - 1; ++d) { | for (int d = 0; d < C4 - 1; ++d) { | ||||
| float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); | float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); | ||||
| sum += dot(exp(t), (float4)(1.f)); | |||||
| sum += dot(exp(t - input_max_f4), (float4)(1.f)); | |||||
| } | } | ||||
| float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); | float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); | ||||
| sum += dot(exp(t), mask); | |||||
| sum += dot(exp(min(t - input_max_f4, 0)), mask); | |||||
| for (int d = 0; d < C4 - 1; ++d) { | for (int d = 0; d < C4 - 1; ++d) { | ||||
| float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); | float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, X))); | ||||
| result = exp(result) / sum; | |||||
| result = exp(result - input_max_f4) / sum; | |||||
| WRITE_IMAGE(output, (int2)(Y * C4 + d, X), TO_FLT4(result)); | WRITE_IMAGE(output, (int2)(Y * C4 + d, X), TO_FLT4(result)); | ||||
| } | } | ||||
| float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); | float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, X))); | ||||
| result = exp(result) / sum; | |||||
| result = exp(min(result - input_max_f4, 0)) / sum; | |||||
| result = result * mask; | result = result * mask; | ||||
| WRITE_IMAGE(output, (int2)(Y * C4 + C4 - 1, X), TO_FLT4(result)); | WRITE_IMAGE(output, (int2)(Y * C4 + C4 - 1, X), TO_FLT4(result)); | ||||
| } | } | ||||
| @@ -16,14 +16,14 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/runtime/kernel/opencl/kernel/prelu.h" | |||||
| #include <mindspore/lite/nnacl/prelu_parameter.h> | |||||
| #include <set> | #include <set> | ||||
| #include <vector> | #include <vector> | ||||
| #include "src/runtime/kernel/opencl/cl/prelu.cl.inc" | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "nnacl/fp32/common_func_fp32.h" | #include "nnacl/fp32/common_func_fp32.h" | ||||
| #include "src/runtime/kernel/opencl/kernel/prelu.h" | |||||
| #include "src/runtime/kernel/opencl/cl/prelu.cl.inc" | |||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | using mindspore::kernel::KERNEL_ARCH::kGPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -36,7 +36,6 @@ namespace mindspore::kernel { | |||||
| int PReluOpenCLKernel::InitWeights() { | int PReluOpenCLKernel::InitWeights() { | ||||
| auto allocator = ocl_runtime_->GetAllocator(); | auto allocator = ocl_runtime_->GetAllocator(); | ||||
| auto weight_tensor = in_tensors_.at(1); | auto weight_tensor = in_tensors_.at(1); | ||||
| int C_ = weight_shape_.s[3]; | |||||
| if (weight_is_scalar) { | if (weight_is_scalar) { | ||||
| if (weight_tensor->data_type() == kNumberTypeFloat16) { | if (weight_tensor->data_type() == kNumberTypeFloat16) { | ||||
| weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c())); | weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c())); | ||||
| @@ -44,6 +43,7 @@ int PReluOpenCLKernel::InitWeights() { | |||||
| weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data_c()); | weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data_c()); | ||||
| } | } | ||||
| } else { | } else { | ||||
| int C_ = weight_tensor->ElementsNum(); | |||||
| auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float); | auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float); | ||||
| size_t weight_size = UP_ROUND(C_, C4NUM) * sizeof_FLT; | size_t weight_size = UP_ROUND(C_, C4NUM) * sizeof_FLT; | ||||
| weight_vector_ = allocator->Malloc(weight_size); | weight_vector_ = allocator->Malloc(weight_size); | ||||
| @@ -123,7 +123,8 @@ int PReluOpenCLKernel::Prepare() { | |||||
| } | } | ||||
| Broadcast2GpuShape(out_shape_.s, output_shape.s, out_tensors_.at(0)->shape().size(), 1); | Broadcast2GpuShape(out_shape_.s, output_shape.s, out_tensors_.at(0)->shape().size(), 1); | ||||
| Broadcast2GpuShape(weight_shape_.s, weight_shape.s, in_tensors_.at(1)->shape().size(), 1); | Broadcast2GpuShape(weight_shape_.s, weight_shape.s, in_tensors_.at(1)->shape().size(), 1); | ||||
| weight_is_scalar = weight_shape_.s[3] == 1; | |||||
| auto param = reinterpret_cast<PReluParameter *>(op_parameter_); | |||||
| weight_is_scalar = param->channelShared; | |||||
| enable_fp16_ = ocl_runtime_->GetFp16Enable(); | enable_fp16_ = ocl_runtime_->GetFp16Enable(); | ||||
| std::string source = prelu_source; | std::string source = prelu_source; | ||||
| std::string program_name = "PRelu"; | std::string program_name = "PRelu"; | ||||
| @@ -321,7 +321,7 @@ void TryMergeArithmeticAct(LiteKernel *act, std::set<LiteKernel *> *removed_set) | |||||
| // FullConnection(NO_ACTIVATION) + Activation(RELU/RELU6/TANH) | // FullConnection(NO_ACTIVATION) + Activation(RELU/RELU6/TANH) | ||||
| template <typename ParamType> | template <typename ParamType> | ||||
| void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) { | void TryMergeXxxActivation(LiteKernel *act, std::set<LiteKernel *> *removed_set) { | ||||
| MS_ASSERT(node); | |||||
| MS_ASSERT(act); | |||||
| MS_ASSERT(removed_set); | MS_ASSERT(removed_set); | ||||
| auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); | auto *act_param = reinterpret_cast<ActivationParameter *>(reinterpret_cast<OpenCLKernel *>(act)->GetParameter()); | ||||
| LiteKernel *node = act->in_kernels().front(); | LiteKernel *node = act->in_kernels().front(); | ||||
| @@ -534,7 +534,7 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::set<LiteKernel *> *removed_set | |||||
| if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) { | if (AIsInB(pred, nodes) && IsEltwiseAndOperatorSupported(pred) && pred->out_kernels().size() == 1) { | ||||
| auto *tensor = pred->out_tensors().front(); | auto *tensor = pred->out_tensors().front(); | ||||
| MS_ASSERT(pred->out_kernels().front() == node); | MS_ASSERT(pred->out_kernels().front() == node); | ||||
| MS_ASSERT(AIsInB(tensor, node.in_tensors())); | |||||
| MS_ASSERT(AIsInB(tensor, &node->in_tensors())); | |||||
| pred_eltwises.insert(pred); | pred_eltwises.insert(pred); | ||||
| // create FusionEltwiseParameter for this pred eltwise | // create FusionEltwiseParameter for this pred eltwise | ||||
| auto param = CreateFusionEltwiseParameter(pred); | auto param = CreateFusionEltwiseParameter(pred); | ||||
| @@ -93,7 +93,6 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, voi | |||||
| cl_int ret = CL_SUCCESS; | cl_int ret = CL_SUCCESS; | ||||
| MS_ASSERT(buffer); | MS_ASSERT(buffer); | ||||
| MS_ASSERT(image); | MS_ASSERT(image); | ||||
| MS_ASSERT(img_size.size() == 3); | |||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| // copy from cl2.hpp | // copy from cl2.hpp | ||||
| cl_image_desc desc = {CL_MEM_OBJECT_IMAGE2D, img_size.width, img_size.height, 0, 0, 0, 0, 0, 0, (**buffer).get()}; | cl_image_desc desc = {CL_MEM_OBJECT_IMAGE2D, img_size.width, img_size.height, 0, 0, 0, 0, 0, 0, (**buffer).get()}; | ||||
| @@ -136,23 +135,27 @@ void *OpenCLAllocator::CreateImage2D(size_t size, const ImageSize &img_size, voi | |||||
| return host_ptr; | return host_ptr; | ||||
| } | } | ||||
| int OpenCLAllocator::GetImgDtypeSize(const ImageSize &img_size) { | |||||
| size_t dtype_size = 0; | |||||
| if (img_size.dtype == CL_FLOAT) { | |||||
| dtype_size = sizeof(cl_float); | |||||
| } else if (img_size.dtype == CL_HALF_FLOAT) { | |||||
| dtype_size = sizeof(cl_half); | |||||
| } else if (img_size.dtype == CL_UNSIGNED_INT8) { | |||||
| dtype_size = sizeof(cl_uchar); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dtype " << img_size.dtype; | |||||
| return RET_ERROR; | |||||
| } | |||||
| uint32_t image_alignment = ocl_runtime_->GetImagePitchAlignment(); | |||||
| size_t size = UP_ROUND(img_size.width, image_alignment) * img_size.height * C4NUM * dtype_size; | |||||
| return size; | |||||
| } | |||||
| void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const ImageSize &img_size) { | void *OpenCLAllocator::_Malloc(MemType mem_type, void *data, size_t size, const ImageSize &img_size) { | ||||
| auto svm_capabilities = ocl_runtime_->GetSVMCapabilities(); | auto svm_capabilities = ocl_runtime_->GetSVMCapabilities(); | ||||
| MS_ASSERT(img_size.size() == 0 || img_size.size() == 3); | |||||
| if (mem_type == MemType::IMG) { | if (mem_type == MemType::IMG) { | ||||
| size_t dtype_size = 0; | |||||
| if (img_size.dtype == CL_FLOAT) { | |||||
| dtype_size = sizeof(cl_float); | |||||
| } else if (img_size.dtype == CL_HALF_FLOAT) { | |||||
| dtype_size = sizeof(cl_half); | |||||
| } else if (img_size.dtype == CL_UNSIGNED_INT8) { | |||||
| dtype_size = sizeof(cl_uchar); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported dtype " << img_size.dtype; | |||||
| return nullptr; | |||||
| } | |||||
| uint32_t image_alignment = ocl_runtime_->GetImagePitchAlignment(); | |||||
| size = UP_ROUND(img_size.width, image_alignment) * img_size.height * C4NUM * dtype_size; | |||||
| size = GetImgDtypeSize(img_size); | |||||
| } | } | ||||
| if (size > ocl_runtime_->GetMaxAllocSize()) { | if (size > ocl_runtime_->GetMaxAllocSize()) { | ||||
| MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; | MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; | ||||
| @@ -75,6 +75,7 @@ class OpenCLAllocator : public Allocator { | |||||
| void *CreateBuffer(size_t size, void *data, size_t flags, cl::Buffer **buffer); | void *CreateBuffer(size_t size, void *data, size_t flags, cl::Buffer **buffer); | ||||
| void *CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map, | void *CreateImage2D(size_t size, const ImageSize &img_size, void *data, size_t flags, bool is_map, | ||||
| cl::Buffer **buffer, cl::Image2D **image); | cl::Buffer **buffer, cl::Image2D **image); | ||||
| int GetImgDtypeSize(const ImageSize &img_size); | |||||
| template <typename T> | template <typename T> | ||||
| void ClearMemList(T *list); | void ClearMemList(T *list); | ||||