Merge pull request !4337 from chenzhongming/litetags/v0.7.0-beta
| @@ -32,6 +32,10 @@ enum Format : int { | |||
| CKHW, | |||
| KHWC, | |||
| CHWK, | |||
| HW, | |||
| HW4, | |||
| NC, | |||
| NC4, | |||
| NC4HW4 = 100, | |||
| NUM_OF_FORMAT | |||
| } | |||
| @@ -104,7 +104,7 @@ int Executor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format d | |||
| allocator->Free(src_data); | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in float32"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -116,7 +116,7 @@ int Executor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format | |||
| MS_ASSERT(4 == tensor->shape().size()); | |||
| // auto src_format = tensor->GetFormat(); | |||
| // todo | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in uint8"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -104,8 +104,8 @@ bool Tensor::operator==(const Value &other) const { | |||
| } | |||
| int32_t Tensor::Batch() const { | |||
| if (this->shape_.size() != 4) { | |||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||
| return -1; | |||
| } | |||
| switch (this->format_) { | |||
| @@ -115,6 +115,8 @@ int32_t Tensor::Batch() const { | |||
| case schema::Format_NC4HW4: | |||
| case schema::Format_KCHW: | |||
| case schema::Format_KHWC: | |||
| case schema::Format_NC: | |||
| case schema::Format_NC4: | |||
| return this->shape_[0]; | |||
| case schema::Format_HWCK: | |||
| case schema::Format_CHWK: | |||
| @@ -124,19 +126,21 @@ int32_t Tensor::Batch() const { | |||
| case schema::Format_CKHW: | |||
| return this->shape_[1]; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); | |||
| MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_); | |||
| return -1; | |||
| } | |||
| } | |||
| int32_t Tensor::Channel() const { | |||
| if (this->shape_.size() != 4) { | |||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||
| return -1; | |||
| } | |||
| switch (this->format_) { | |||
| case schema::Format_NCHW: | |||
| case schema::Format_KCHW: | |||
| case schema::Format_NC: | |||
| case schema::Format_NC4: | |||
| return this->shape_[1]; | |||
| case schema::Format_HWCK: | |||
| return this->shape_[2]; | |||
| @@ -155,8 +159,8 @@ int32_t Tensor::Channel() const { | |||
| } | |||
| int32_t Tensor::Height() const { | |||
| if (this->shape_.size() != 4) { | |||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||
| return -1; | |||
| } | |||
| switch (this->format_) { | |||
| @@ -172,16 +176,18 @@ int32_t Tensor::Height() const { | |||
| return this->shape_[1]; | |||
| case schema::Format_HWCK: | |||
| case schema::Format_HWKC: | |||
| case schema::Format_HW: | |||
| case schema::Format_HW4: | |||
| return this->shape_[0]; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); | |||
| MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_); | |||
| return -1; | |||
| } | |||
| } | |||
| int32_t Tensor::Width() const { | |||
| if (this->shape_.size() != 4) { | |||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||
| return -1; | |||
| } | |||
| switch (this->format_) { | |||
| @@ -197,12 +203,24 @@ int32_t Tensor::Width() const { | |||
| return this->shape_[2]; | |||
| case schema::Format_HWCK: | |||
| case schema::Format_HWKC: | |||
| case schema::Format_HW: | |||
| case schema::Format_HW4: | |||
| return this->shape_[1]; | |||
| default: | |||
| return -1; | |||
| } | |||
| } | |||
| int32_t Tensor::ElementsC4Num() const { | |||
| int32_t result = 0; | |||
| if (this->shape_.size() == 4) { | |||
| result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); | |||
| } else if (this->shape_.size() == 2) { | |||
| result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4); | |||
| } | |||
| return result; | |||
| } | |||
| std::string Tensor::ToString() const { | |||
| std::ostringstream oss; | |||
| oss << "Format: " << schema::EnumNameFormat(this->format_); | |||
| @@ -235,7 +253,7 @@ std::string Tensor::ToString() const { | |||
| } | |||
| } break; | |||
| default: | |||
| oss << "Unsupport data type to print"; | |||
| oss << "Unsupported data type to print"; | |||
| break; | |||
| } | |||
| return oss.str(); | |||
| @@ -66,7 +66,7 @@ class Tensor : public mindspore::tensor::MetaTensor { | |||
| int32_t Width() const; | |||
| int32_t ElementsC4Num() const { return Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); } | |||
| int32_t ElementsC4Num() const; | |||
| int DataSize() const { return this->ElementsNum(); } | |||
| @@ -37,7 +37,7 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) { | |||
| MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); | |||
| MS_LOG(ERROR) << "Unsupported input data type " << input->data_type(); | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| } | |||
| if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) { | |||
| @@ -74,7 +74,7 @@ int CastCPUKernel::DoCast(int thread_id) { | |||
| Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset, | |||
| reinterpret_cast<int32_t *>(output_data) + offset, data_num); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type; | |||
| MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | |||
| return RET_ERROR; | |||
| } | |||
| } else { | |||
| @@ -88,7 +88,7 @@ int CastCPUKernel::DoCast(int thread_id) { | |||
| reinterpret_cast<float *>(output_data) + offset, data_num); | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; | |||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| @@ -1,21 +1,15 @@ | |||
| #define SLICES 4 | |||
| int DivideRoundUp(int n, int div) { | |||
| int q = n / div; | |||
| return n % div == 0 ? q : q + 1; | |||
| } | |||
| __kernel void SoftMax(__global float4 *input, __global float4 *output, const int4 input_shape) { | |||
| int X = get_global_id(0); // width | |||
| int Y = get_global_id(1); // height | |||
| int H = input_shape.y; | |||
| int W = input_shape.z; | |||
| int C = input_shape.w; | |||
| __kernel void SoftMax_BUF(__global float4 *input, __global float4 *output, const int4 input_shape) { | |||
| int X = get_global_id(0); | |||
| int Y = get_global_id(1); | |||
| int H = input_shape.x; | |||
| int W = input_shape.y; | |||
| int C = input_shape.z; | |||
| int S = input_shape.w; | |||
| if (X >= W || Y >= H) return; | |||
| float sum = 0.0f; | |||
| for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { | |||
| for (int d = 0; d < S; ++d) { | |||
| float4 t = input[(Y * W + X * H) * C + d]; | |||
| sum += exp(t.x); | |||
| if (d * 4 + 1 < C) sum += exp(t.y); | |||
| @@ -23,10 +17,34 @@ __kernel void SoftMax(__global float4 *input, __global float4 *output, const int | |||
| if (d * 4 + 3 < C) sum += exp(t.w); | |||
| } | |||
| for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { | |||
| for (int d = 0; d < S; ++d) { | |||
| float4 t = input[(Y * W + X * H) * C + d]; | |||
| t = exp(t) / sum; | |||
| float4 result = convert_float4(t); | |||
| output[(Y * W + X * H) * C + d] = result; | |||
| } | |||
| } | |||
| __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||
| __kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { | |||
| int X = get_global_id(0); | |||
| int Y = get_global_id(1); | |||
| if (X >= input_shape.x || Y >= input_shape.y) return; | |||
| float sum = 0.0f; | |||
| for (int d = 0; d < input_shape.w; ++d) { | |||
| float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); | |||
| sum += exp(t.x); | |||
| if (d * 4 + 1 < input_shape.z) sum += exp(t.y); | |||
| if (d * 4 + 2 < input_shape.z) sum += exp(t.z); | |||
| if (d * 4 + 3 < input_shape.z) sum += exp(t.w); | |||
| } | |||
| for (int d = 0; d < input_shape.w; ++d) { | |||
| float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); | |||
| t = exp(t) / sum; | |||
| float4 result = convert_float4(t); | |||
| write_imagef(output, (int2)(Y * input_shape.w + d, X), result); | |||
| } | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||
| // what is mask and args.slices_x32 | |||
| __kernel void SoftMax1x1_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, | |||
| const int slices, const int slices_x32) { | |||
| int tid = get_local_id(0); | |||
| int slices_count = 0; | |||
| int offset = 0; | |||
| float sum = 0.0f; | |||
| do { | |||
| int z = offset + tid; | |||
| if (z < slices) { | |||
| float4 mask_temp = z == slices - 1 ? mask : (float4)(1.0f); | |||
| float4 src = read_imagef(input, smp_none, (int2)(0, 0)); | |||
| sum += dot(mask_temp, exp(src)); | |||
| offset += 32; | |||
| } | |||
| slices_count++; | |||
| } while (slices_count < slices_x32); | |||
| __local float4 tmp[8]; | |||
| __local float *tmpx1 = (__local float *)tmp; | |||
| tmpx1[tid] = sum; | |||
| barrier(CLK_LOCAL_MEM_FENCE); | |||
| if (tid == 0) { | |||
| sum = dot((float4)(1.0f), tmp[0]); | |||
| sum += dot((float4)(1.0f), tmp[1]); | |||
| sum += dot((float4)(1.0f), tmp[2]); | |||
| sum += dot((float4)(1.0f), tmp[3]); | |||
| sum += dot((float4)(1.0f), tmp[4]); | |||
| sum += dot((float4)(1.0f), tmp[5]); | |||
| sum += dot((float4)(1.0f), tmp[6]); | |||
| sum += dot((float4)(1.0f), tmp[7]); | |||
| tmpx1[0] = 1.0f / sum; | |||
| } | |||
| barrier(CLK_LOCAL_MEM_FENCE); | |||
| sum = tmpx1[0]; | |||
| offset = 0; | |||
| slices_count = 0; | |||
| do { | |||
| int z = offset + tid; | |||
| if (z < slices) { | |||
| float4 res = convert_float4(exp(read_imagef(input, smp_none, (int2)(0, 0))) * sum); | |||
| write_imagef(output, (int2)(0, 0), res); | |||
| offset += 32; | |||
| } | |||
| slices_count++; | |||
| } while (slices_count < slices_x32); | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -17,69 +17,143 @@ | |||
| #include "src/runtime/kernel/opencl/kernel/softmax.h" | |||
| #include <string> | |||
| #include <set> | |||
| #include "include/errorcode.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| #ifndef PROGRAM_WITH_IL | |||
| #include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc" | |||
| #include "src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl.inc" | |||
| #endif | |||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::schema::PrimitiveType_SoftMax; | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace mindspore::kernel { | |||
| std::vector<float> SoftmaxOpenCLKernel::GetMaskForLastChannel(int channels) { | |||
| std::vector<float> mask{4, 0.0f}; | |||
| const int reminder = channels % 4 == 0 ? 4 : channels % 4; | |||
| for (int i = 0; i < reminder; ++i) { | |||
| mask[i] = 1.0f; | |||
| } | |||
| return mask; | |||
| } | |||
| int SoftmaxOpenCLKernel::InitGlobalSize() { | |||
| const size_t global_x = out_tensors_[0]->Height(); | |||
| const size_t global_y = out_tensors_[0]->Width(); | |||
| const size_t global_z = 1; | |||
| global_size_ = {global_x, global_y, global_z}; | |||
| return lite::RET_OK; | |||
| } | |||
| int SoftmaxOpenCLKernel::SetWorkGroupSize() { | |||
| // set work group size | |||
| InitGlobalSize(); | |||
| int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())()); | |||
| local_size_ = GetCommonLocalSize(global_size_, max_work_group_size); | |||
| global_size_ = GetCommonGlobalSize(local_size_, global_size_); | |||
| return lite::RET_OK; | |||
| } | |||
| int SoftmaxOpenCLKernel::SetWorkGroupSize1x1() { | |||
| local_size_ = {32, 1, 1}; | |||
| global_size_ = {32, 1, 1}; | |||
| return lite::RET_OK; | |||
| } | |||
| int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||
| size_t im_dst_x, im_dst_y; | |||
| if (onexone_flag_) { | |||
| im_dst_x = UP_DIV(in_tensors_[0]->shape()[1], C4NUM); | |||
| im_dst_y = 1; | |||
| } else { | |||
| size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||
| im_dst_x = out_tensors_[0]->Width() * CO4; | |||
| im_dst_y = out_tensors_[0]->Height(); | |||
| } | |||
| #ifdef ENABLE_FP16 | |||
| size_t img_dtype = CL_HALF_FLOAT; | |||
| #else | |||
| size_t img_dtype = CL_FLOAT; | |||
| #endif | |||
| img_size->clear(); | |||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||
| *img_size = vec; | |||
| return RET_OK; | |||
| } | |||
| int SoftmaxOpenCLKernel::Init() { | |||
| std::string kernel_name = "SoftMax"; | |||
| if (parameter_->axis_ != -1 && parameter_->axis_ != 3) { | |||
| MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; | |||
| return -1; | |||
| } | |||
| std::string program_name = "SoftMax"; | |||
| std::string source = softmax_source_fp32; | |||
| runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) { | |||
| // support 4d tensor | |||
| onexone_flag_ = false; | |||
| } else if (in_tensors_[0]->shape().size() == 2 && parameter_->axis_ == 1) { | |||
| // support 2d tensor | |||
| kernel_name += "1x1"; | |||
| program_name += "1x1"; | |||
| source = softmax1x1_source_fp32; | |||
| onexone_flag_ = true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; | |||
| } | |||
| #ifdef PROGRAM_WITH_IL | |||
| ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); | |||
| runtime_->CreateKernelFromIL(kernel_(), kernel_name); | |||
| #else | |||
| if (mem_type_ == MEM_TYPE::BUF) { | |||
| kernel_name += "_BUF"; | |||
| program_name += "_BUF"; | |||
| } else { | |||
| kernel_name += "_IMG"; | |||
| program_name += "_IMG"; | |||
| } | |||
| std::set<std::string> build_options; | |||
| std::string source = softmax_source_fp32; | |||
| std::string program_name = "SoftMax"; | |||
| ocl_runtime->LoadSource(program_name, source); | |||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| #endif | |||
| runtime_->LoadSource(program_name, source); | |||
| out_tensors_[0]->SetFormat(schema::Format_NHWC4); | |||
| runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||
| #endif | |||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | |||
| return 0; | |||
| return lite::RET_OK; | |||
| } | |||
| int SoftmaxOpenCLKernel::InitBuffer() { return 0; } | |||
| int SoftmaxOpenCLKernel::ReSize() { return 0; } | |||
| int SoftmaxOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << this->name() << " Running!"; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| // global and local workers | |||
| const uint32_t grid_x = in_tensors_[0]->shape()[2]; // W | |||
| const uint32_t grid_y = in_tensors_[0]->shape()[1]; // H | |||
| const uint32_t grid_z = 1; | |||
| std::vector<size_t> global = {grid_x, grid_y, grid_z}; | |||
| std::vector<size_t> local = {1, 1, 1}; | |||
| // input and output | |||
| cl::Buffer *input = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(in_tensors_[0]->Data())); | |||
| cl::Buffer *output = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(out_tensors_[0]->Data())); | |||
| cl_int4 input_size = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], | |||
| in_tensors_[0]->shape()[3]}; | |||
| std::cout << "run" << std::endl; | |||
| // attribute | |||
| int arg_idx = 0; | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, *input); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, *output); | |||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_size); | |||
| if (onexone_flag_) { | |||
| int channel_size = in_tensors_[0]->shape()[1]; | |||
| int slices = UP_DIV(channel_size, C4NUM); | |||
| cl_int slices_x32 = UP_DIV(slices, 32); | |||
| auto mask_ = GetMaskForLastChannel(channel_size); | |||
| cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]}; | |||
| // run opengl kernel | |||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||
| runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); | |||
| runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); | |||
| runtime_->SetKernelArg(kernel_, arg_idx++, mask); | |||
| runtime_->SetKernelArg(kernel_, arg_idx++, slices); | |||
| runtime_->SetKernelArg(kernel_, arg_idx, slices_x32); | |||
| SetWorkGroupSize1x1(); | |||
| } else { | |||
| int slices = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||
| cl_int4 input_shape = {in_tensors_[0]->Height(), in_tensors_[0]->Width(), in_tensors_[0]->Channel(), slices}; | |||
| return 0; | |||
| runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); | |||
| runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); | |||
| runtime_->SetKernelArg(kernel_, arg_idx, input_shape); | |||
| SetWorkGroupSize(); | |||
| } | |||
| // run opengl kernel | |||
| runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr); | |||
| return lite::RET_OK; | |||
| } | |||
| kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||
| @@ -104,5 +178,4 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::T | |||
| } | |||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SoftMax, OpenCLSoftMaxKernelCreator) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| } // namespace mindspore::kernel | |||
| @@ -23,29 +23,37 @@ | |||
| #include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class SoftmaxOpenCLKernel : public LiteKernel { | |||
| namespace mindspore::kernel { | |||
| class SoftmaxOpenCLKernel : public OpenCLKernel { | |||
| public: | |||
| explicit SoftmaxOpenCLKernel(OpParameter *parameter, | |||
| const std::vector<lite::tensor::Tensor *> &inputs, | |||
| explicit SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||
| : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) { | |||
| : OpenCLKernel(parameter, inputs, outputs) { | |||
| parameter_ = reinterpret_cast<SoftmaxParameter *>(parameter); | |||
| } | |||
| ~SoftmaxOpenCLKernel() override{}; | |||
| ~SoftmaxOpenCLKernel() override{}; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int InitBuffer(); | |||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||
| int InitGlobalSize(); | |||
| int SetWorkGroupSize1x1(); | |||
| int SetWorkGroupSize(); | |||
| std::vector<float> GetMaskForLastChannel(int channels); | |||
| private: | |||
| SoftmaxParameter *parameter_; | |||
| cl::Kernel kernel_; | |||
| SoftmaxParameter *parameter_; | |||
| lite::opencl::OpenCLRuntime *runtime_; | |||
| enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG}; | |||
| bool onexone_flag_{false}; | |||
| std::vector<size_t> local_size_; | |||
| std::vector<size_t> global_size_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ | |||
| @@ -175,4 +175,3 @@ std::string CLErrorCode(cl_int error_code) { | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -85,4 +85,3 @@ std::string CLErrorCode(cl_int error_code); | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ | |||
| @@ -19,6 +19,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "src/runtime/opencl/opencl_runtime.h" | |||
| #include "include/errorcode.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| namespace mindspore::lite::opencl { | |||
| @@ -128,7 +129,7 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t>& img_size) | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format, | |||
| img_size[0], img_size[1], 0, nullptr, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed!" << kernel::CLErrorCode(ret); | |||
| UnLock(); | |||
| delete buffer; | |||
| return nullptr; | |||
| @@ -187,7 +188,7 @@ void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::v | |||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, | |||
| image_format, img_size[0], img_size[1], 0, data, &ret); | |||
| if (ret != CL_SUCCESS) { | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed - " << kernel::CLErrorCode(ret); | |||
| UnLock(); | |||
| delete buffer; | |||
| return nullptr; | |||
| @@ -52,6 +52,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso | |||
| std::vector<size_t> img_size; | |||
| op_kernel->GetImageSize(i, &img_size); | |||
| auto data_ptr = op_allocator->Malloc(output->Size(), img_size); | |||
| output->SetData(data_ptr); | |||
| } else { | |||
| output->MallocData(allocator); | |||
| @@ -109,7 +110,7 @@ int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format | |||
| case kNumberTypeFloat32: | |||
| return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir); | |||
| default: | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -160,7 +161,7 @@ int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema | |||
| // TODO(wandongdong): add support !! | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in float32"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -194,7 +195,7 @@ int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema: | |||
| allocator_->Free(src_data); | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in float32"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -216,7 +217,7 @@ int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schem | |||
| allocator_->Free(src_data); | |||
| return RET_OK; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in float32"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -228,7 +229,7 @@ int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::F | |||
| MS_ASSERT(4 == tensor->shape().size()); | |||
| // auto src_format = tensor->GetFormat(); | |||
| // todo | |||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||
| << schema::EnumNameFormat(dst_format) << " in uint8"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -17,76 +17,90 @@ | |||
| #include <memory> | |||
| #include "mindspore/core/utils/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/common/file_utils.h" | |||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" | |||
| #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" | |||
| namespace mindspore { | |||
| class TestSoftmaxOpenCL : public mindspore::CommonTest {}; | |||
| void InitSoftaxParam(SoftmaxParameter *param) { param->axis_ = -1; } | |||
| TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) { | |||
| std::cout << "======" << std::endl; | |||
| MS_LOG(INFO) << "start TEST_F TestSoftmaxOpenCL"; | |||
| void RunTestCase(std::vector<int> input_shape, std::vector<int> output_shape, std::string input_file, | |||
| std::string expect_file, SoftmaxParameter *param, schema::Format format) { | |||
| std::cout << "runtime" << std::endl; | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| MS_LOG(INFO) << "create SoftmaxParameter"; | |||
| auto param = new SoftmaxParameter(); | |||
| InitSoftaxParam(param); | |||
| // define tensor | |||
| MS_LOG(INFO) << "defineTensor"; | |||
| std::cout << "defineTensor" << std::endl; | |||
| MS_LOG(INFO) << "create Tensors"; | |||
| std::vector<int> shape_in = {1, 2, 2, 1}; | |||
| std::vector<int> shape_out = {1, 2, 2, 1}; | |||
| auto data_type = kNumberTypeFloat32; | |||
| auto tensorType = schema::NodeType_ValueNode; | |||
| lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NCHW, tensorType); | |||
| lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NCHW, tensorType); | |||
| std::vector<lite::tensor::Tensor *> inputs{tensor_in}; | |||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | |||
| auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, format, tensorType); | |||
| auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, format, tensorType); | |||
| std::vector<lite::tensor::Tensor *> inputs{input_tensor}; | |||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||
| MS_LOG(INFO) << "create OpenCL Kernel"; | |||
| auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| Softmax_kernel->Init(); | |||
| std::vector<kernel::LiteKernel *> kernels{Softmax_kernel}; | |||
| // run | |||
| MS_LOG(INFO) << "NewOpenCLKernel"; | |||
| std::cout << "NewOpenCLKernel" << std::endl; | |||
| auto *kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||
| MS_LOG(INFO) << "KernelInit"; | |||
| std::cout << "KernelInit" << std::endl; | |||
| kernel->Init(); | |||
| MS_LOG(INFO) << "create SubGraphOpenCLKernel"; | |||
| std::cout << "LiteKernel" << std::endl; | |||
| std::vector<kernel::LiteKernel *> kernels{kernel}; | |||
| inputs[0]->MallocData(allocator); | |||
| std::cout << "SubGraphOpenCLKernel" << std::endl; | |||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||
| MS_LOG(INFO) << "pGraphinit"; | |||
| pGraph->Init(); | |||
| MS_LOG(INFO) << "initialize data"; | |||
| std::vector<lite::tensor::Tensor *> tensor_map = {tensor_in}; | |||
| for (auto &tensor_file : tensor_map) { | |||
| auto tensor = tensor_file; | |||
| size_t size = tensor->Size(); | |||
| const float data[4] = {std::log(1.0f), std::log(2.0f), std::log(3.0f), std::log(4.0f)}; | |||
| memcpy(tensor->Data(), data, size); | |||
| // load data | |||
| MS_LOG(INFO) << "load data1"; | |||
| LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); | |||
| auto *input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||
| printf("\ninput[0:10]:"); | |||
| for (int i = 0; i < 10; i++) { | |||
| printf("[%d]:%.3f ", i, input_data[i]); | |||
| } | |||
| printf("\n\n"); | |||
| MS_LOG(INFO) << "pGraph->Run()"; | |||
| MS_LOG(INFO) << "Run"; | |||
| pGraph->Run(); | |||
| MS_LOG(INFO) << "==================output data================="; | |||
| float *output_data = reinterpret_cast<float *>(tensor_out->Data()); | |||
| size_t output_size = tensor_out->Size(); | |||
| MS_LOG(INFO) << "compare result"; | |||
| std::cout << "compare result" << std::endl; | |||
| CompareOutput(output_tensor, expect_file); | |||
| } | |||
| printf("output:"); | |||
| for (int i = 0; i < 4; i++) { | |||
| printf("%.3f ", output_data[i]); | |||
| } | |||
| printf("\n"); | |||
| float expect[4] = {1.0f, 2.0f, 3.0f, 4.0f}; | |||
| TEST_F(TestSoftmaxOpenCL, Softmax_1) { | |||
| std::vector<int> input_shape = {1, 2, 2, 8}; | |||
| std::vector<int> output_shape = {1, 2, 2, 8}; | |||
| std::string input_file = "softmax_in.bin"; | |||
| std::string expect_file = "softmax_out.bin"; | |||
| auto param = new SoftmaxParameter; | |||
| param->axis_ = 3; | |||
| schema::Format format = schema::Format_NHWC4; | |||
| for (int i = 0; i < tensor_out->ElementsNum(); ++i) { | |||
| if (std::fabs(output_data[i] - expect[i]) > 1e-5) { | |||
| printf("idx[%d] except=%.3f output=%.3f .", i, expect[i], output_data[i]); | |||
| } | |||
| } | |||
| printf("\nTest all close OK for %zu!\n", output_size); | |||
| lite::CompareOutputData(output_data, expect, 4); | |||
| RunTestCase(input_shape, output_shape, input_file, expect_file, param, format); | |||
| } | |||
| // TEST_F(TestSoftmaxOpenCL, Softmax_1x1) { | |||
| // std::vector<int> input_shape = {1, 100}; | |||
| // std::vector<int> output_shape = {1, 100}; | |||
| // std::string input_file = "softmax1x1_in.bin"; | |||
| // std::string expect_file = "softmax1x1_out.bin"; | |||
| // auto param = new SoftmaxParameter; | |||
| // param->axis_ = 1; | |||
| // schema::Format format = schema::Format_NHWC4; | |||
| // | |||
| // RunTestCase(input_shape, output_shape, input_file, expect_file, param, format); | |||
| //} | |||
| } // namespace mindspore | |||
| @@ -40,13 +40,13 @@ void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_ | |||
| size_t output_size = output_tensor->Size(); | |||
| float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | |||
| printf("output[0:10]:"); | |||
| for (int i = 0; i < 10; i++) { | |||
| printf("output[0:12]:"); | |||
| for (int i = 0; i < 12; i++) { | |||
| printf("[%d]:%.3f ", i, output_data[i]); | |||
| } | |||
| printf("\n"); | |||
| printf("expect[0:10]:"); | |||
| for (int i = 0; i < 10; i++) { | |||
| printf("expect[0:12]:"); | |||
| for (int i = 0; i < 12; i++) { | |||
| printf("[%d]:%.3f ", i, expect_data[i]); | |||
| } | |||
| printf("\n"); | |||
| @@ -157,7 +157,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||
| } else if (opType == schema::PrimitiveType_DeConv2D) { | |||
| weightTensor->format = schema::Format_CHWK; | |||
| } else { | |||
| MS_LOG(ERROR) << "unsupport format"; | |||
| MS_LOG(ERROR) << "Unsupported format"; | |||
| return -1; | |||
| } | |||
| } break; | |||
| @@ -184,7 +184,7 @@ size_t GetDataTypeSize(const TypeId &data_type) { | |||
| return sizeof(int64_t); | |||
| default: | |||
| MS_LOG(ERROR) << data_type; | |||
| MS_LOG(ERROR) << "unsupport datatype"; | |||
| MS_LOG(ERROR) << "Unsupported datatype"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||