diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index a8ee4858ce..e6584dc059 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -32,6 +32,10 @@ enum Format : int { CKHW, KHWC, CHWK, + HW, + HW4, + NC, + NC4, NC4HW4 = 100, NUM_OF_FORMAT } diff --git a/mindspore/lite/src/executor.cc b/mindspore/lite/src/executor.cc index 3dc546a74a..bcd20cf07c 100644 --- a/mindspore/lite/src/executor.cc +++ b/mindspore/lite/src/executor.cc @@ -104,7 +104,7 @@ int Executor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format d allocator->Free(src_data); return RET_OK; } else { - MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " << schema::EnumNameFormat(dst_format) << " in float32"; return RET_ERROR; } @@ -116,7 +116,7 @@ int Executor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format MS_ASSERT(4 == tensor->shape().size()); // auto src_format = tensor->GetFormat(); // todo - MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " << schema::EnumNameFormat(dst_format) << " in uint8"; return RET_ERROR; } diff --git a/mindspore/lite/src/ir/tensor.cc b/mindspore/lite/src/ir/tensor.cc index cd9fb766d6..091ec942ae 100644 --- a/mindspore/lite/src/ir/tensor.cc +++ b/mindspore/lite/src/ir/tensor.cc @@ -104,8 +104,8 @@ bool Tensor::operator==(const Value &other) const { } int32_t Tensor::Batch() const { - if (this->shape_.size() != 4) { - MS_LOG(ERROR) << "tensor should have 4 dim"; + if (this->shape_.size() != 4 && this->shape_.size() != 2) { + MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); return -1; } switch (this->format_) { @@ -115,6 +115,8 @@ int32_t Tensor::Batch() const { case schema::Format_NC4HW4: case schema::Format_KCHW: case schema::Format_KHWC: + case schema::Format_NC: + case schema::Format_NC4: return this->shape_[0]; case schema::Format_HWCK: case schema::Format_CHWK: @@ -124,19 +126,21 @@ int32_t Tensor::Batch() const { case schema::Format_CKHW: return this->shape_[1]; default: - MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); + MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_); return -1; } } int32_t Tensor::Channel() const { - if (this->shape_.size() != 4) { - MS_LOG(ERROR) << "tensor should have 4 dim"; + if (this->shape_.size() != 4 && this->shape_.size() != 2) { + MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); return -1; } switch (this->format_) { case schema::Format_NCHW: case schema::Format_KCHW: + case schema::Format_NC: + case schema::Format_NC4: return this->shape_[1]; case schema::Format_HWCK: return this->shape_[2]; @@ -155,8 +159,8 @@ int32_t Tensor::Channel() const { } int32_t Tensor::Height() const { - if (this->shape_.size() != 4) { - MS_LOG(ERROR) << "tensor should have 4 dim"; + if (this->shape_.size() != 4 && this->shape_.size() != 2) { + MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); return -1; } switch (this->format_) { @@ -172,16 +176,18 @@ int32_t Tensor::Height() const { return this->shape_[1]; case schema::Format_HWCK: case schema::Format_HWKC: + case schema::Format_HW: + case schema::Format_HW4: return this->shape_[0]; default: - 
MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); + MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_); return -1; } } int32_t Tensor::Width() const { - if (this->shape_.size() != 4) { - MS_LOG(ERROR) << "tensor should have 4 dim"; + if (this->shape_.size() != 4 && this->shape_.size() != 2) { + MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); return -1; } switch (this->format_) { @@ -197,12 +203,24 @@ int32_t Tensor::Width() const { return this->shape_[2]; case schema::Format_HWCK: case schema::Format_HWKC: + case schema::Format_HW: + case schema::Format_HW4: return this->shape_[1]; default: return -1; } } +int32_t Tensor::ElementsC4Num() const { + int32_t result = 0; + if (this->shape_.size() == 4) { + result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); + } else if (this->shape_.size() == 2) { + result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4); + } + return result; +} + std::string Tensor::ToString() const { std::ostringstream oss; oss << "Format: " << schema::EnumNameFormat(this->format_); @@ -235,7 +253,7 @@ std::string Tensor::ToString() const { } } break; default: - oss << "Unsupport data type to print"; + oss << "Unsupported data type to print"; break; } return oss.str(); diff --git a/mindspore/lite/src/ir/tensor.h b/mindspore/lite/src/ir/tensor.h index 6dc21c4613..cc1ccafc33 100644 --- a/mindspore/lite/src/ir/tensor.h +++ b/mindspore/lite/src/ir/tensor.h @@ -66,7 +66,7 @@ class Tensor : public mindspore::tensor::MetaTensor { int32_t Width() const; - int32_t ElementsC4Num() const { return Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); } + int32_t ElementsC4Num() const; int DataSize() const { return this->ElementsNum(); } diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 565f8de767..2115b9e76e 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -37,7 +37,7 @@ int Cast::InferShape(std::vector inputs_, std::vectordata_type()) == kSupportDataType.end()) { - MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); + MS_LOG(ERROR) << "Unsupported input data type " << input->data_type(); return RET_INPUT_TENSOR_ERROR; } if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 3363ef6d51..b39a5ef407 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -74,7 +74,7 @@ int CastCPUKernel::DoCast(int thread_id) { Float32ToInt32(reinterpret_cast(input->Data()) + offset, reinterpret_cast(output_data) + offset, data_num); } else { - MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type; + MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; return RET_ERROR; } } else { @@ -88,7 +88,7 @@ int CastCPUKernel::DoCast(int thread_id) { reinterpret_cast(output_data) + offset, data_num); break; default: - MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; + MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; return RET_ERROR; } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl index f1a5c69d94..5ecff6a4d7 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl +++ 
b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax.cl @@ -1,21 +1,15 @@ -#define SLICES 4 - -int DivideRoundUp(int n, int div) { - int q = n / div; - return n % div == 0 ? q : q + 1; -} - -__kernel void SoftMax(__global float4 *input, __global float4 *output, const int4 input_shape) { - int X = get_global_id(0); // width - int Y = get_global_id(1); // height - int H = input_shape.y; - int W = input_shape.z; - int C = input_shape.w; +__kernel void SoftMax_BUF(__global float4 *input, __global float4 *output, const int4 input_shape) { + int X = get_global_id(0); + int Y = get_global_id(1); + int H = input_shape.x; + int W = input_shape.y; + int C = input_shape.z; + int S = input_shape.w; if (X >= W || Y >= H) return; float sum = 0.0f; - for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { + for (int d = 0; d < S; ++d) { float4 t = input[(Y * W + X * H) * C + d]; sum += exp(t.x); if (d * 4 + 1 < C) sum += exp(t.y); @@ -23,10 +17,34 @@ __kernel void SoftMax(__global float4 *input, __global float4 *output, const int if (d * 4 + 3 < C) sum += exp(t.w); } - for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { + for (int d = 0; d < S; ++d) { float4 t = input[(Y * W + X * H) * C + d]; t = exp(t) / sum; float4 result = convert_float4(t); output[(Y * W + X * H) * C + d] = result; } } + +__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; + +__kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { + int X = get_global_id(0); + int Y = get_global_id(1); + if (X >= input_shape.x || Y >= input_shape.y) return; + + float sum = 0.0f; + for (int d = 0; d < input_shape.w; ++d) { + float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); + sum += exp(t.x); + if (d * 4 + 1 < input_shape.z) sum += exp(t.y); + if (d * 4 + 2 < input_shape.z) sum += exp(t.z); + if (d * 4 + 3 < input_shape.z) sum += exp(t.w); + } + + for (int d = 0; d < input_shape.w; ++d) { + float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); + t = exp(t) / sum; + float4 result = convert_float4(t); + write_imagef(output, (int2)(Y * input_shape.w + d, X), result); + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl new file mode 100644 index 0000000000..f74197e7b6 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl @@ -0,0 +1,50 @@ +__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; + +// what is mask and args.slices_x32 +__kernel void SoftMax1x1_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, + const int slices, const int slices_x32) { + int tid = get_local_id(0); + int slices_count = 0; + int offset = 0; + float sum = 0.0f; + do { + int z = offset + tid; + if (z < slices) { + float4 mask_temp = z == slices - 1 ? 
mask : (float4)(1.0f); + float4 src = read_imagef(input, smp_none, (int2)(0, 0)); + sum += dot(mask_temp, exp(src)); + offset += 32; + } + slices_count++; + } while (slices_count < slices_x32); + + __local float4 tmp[8]; + __local float *tmpx1 = (__local float *)tmp; + tmpx1[tid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (tid == 0) { + sum = dot((float4)(1.0f), tmp[0]); + sum += dot((float4)(1.0f), tmp[1]); + sum += dot((float4)(1.0f), tmp[2]); + sum += dot((float4)(1.0f), tmp[3]); + sum += dot((float4)(1.0f), tmp[4]); + sum += dot((float4)(1.0f), tmp[5]); + sum += dot((float4)(1.0f), tmp[6]); + sum += dot((float4)(1.0f), tmp[7]); + tmpx1[0] = 1.0f / sum; + } + barrier(CLK_LOCAL_MEM_FENCE); + sum = tmpx1[0]; + + offset = 0; + slices_count = 0; + do { + int z = offset + tid; + if (z < slices) { + float4 res = convert_float4(exp(read_imagef(input, smp_none, (int2)(0, 0))) * sum); + write_imagef(output, (int2)(0, 0), res); + offset += 32; + } + slices_count++; + } while (slices_count < slices_x32); +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc index 0148ae35fb..5613f456fb 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,69 +17,143 @@ #include "src/runtime/kernel/opencl/kernel/softmax.h" #include #include +#include "include/errorcode.h" #include "src/kernel_registry.h" #include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/utils.h" #ifndef PROGRAM_WITH_IL #include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc" +#include "src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl.inc" #endif using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::lite::KernelRegistrar; using mindspore::schema::PrimitiveType_SoftMax; -namespace mindspore { -namespace kernel { +namespace mindspore::kernel { + +std::vector SoftmaxOpenCLKernel::GetMaskForLastChannel(int channels) { + std::vector mask{4, 0.0f}; + const int reminder = channels % 4 == 0 ? 
4 : channels % 4; + for (int i = 0; i < reminder; ++i) { + mask[i] = 1.0f; + } + return mask; +} + +int SoftmaxOpenCLKernel::InitGlobalSize() { + const size_t global_x = out_tensors_[0]->Height(); + const size_t global_y = out_tensors_[0]->Width(); + const size_t global_z = 1; + global_size_ = {global_x, global_y, global_z}; + return lite::RET_OK; +} + +int SoftmaxOpenCLKernel::SetWorkGroupSize() { + // set work group size + InitGlobalSize(); + int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())()); + local_size_ = GetCommonLocalSize(global_size_, max_work_group_size); + global_size_ = GetCommonGlobalSize(local_size_, global_size_); + return lite::RET_OK; +} + +int SoftmaxOpenCLKernel::SetWorkGroupSize1x1() { + local_size_ = {32, 1, 1}; + global_size_ = {32, 1, 1}; + return lite::RET_OK; +} + +int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { + size_t im_dst_x, im_dst_y; + if (onexone_flag_) { + im_dst_x = UP_DIV(in_tensors_[0]->shape()[1], C4NUM); + im_dst_y = 1; + } else { + size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); + im_dst_x = out_tensors_[0]->Width() * CO4; + im_dst_y = out_tensors_[0]->Height(); + } +#ifdef ENABLE_FP16 + size_t img_dtype = CL_HALF_FLOAT; +#else + size_t img_dtype = CL_FLOAT; +#endif + img_size->clear(); + std::vector vec{im_dst_x, im_dst_y, img_dtype}; + *img_size = vec; + return RET_OK; +} + int SoftmaxOpenCLKernel::Init() { std::string kernel_name = "SoftMax"; - if (parameter_->axis_ != -1 && parameter_->axis_ != 3) { - MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; - return -1; - } + std::string program_name = "SoftMax"; + std::string source = softmax_source_fp32; + runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) { + // support 4d tensor + onexone_flag_ = false; + } else if (in_tensors_[0]->shape().size() == 2 && parameter_->axis_ == 1) { + // support 2d tensor + kernel_name += "1x1"; + program_name += "1x1"; + source = softmax1x1_source_fp32; + onexone_flag_ = true; + } else { + MS_LOG(EXCEPTION) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; + } #ifdef PROGRAM_WITH_IL - ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); + runtime_->CreateKernelFromIL(kernel_(), kernel_name); #else + if (mem_type_ == MEM_TYPE::BUF) { + kernel_name += "_BUF"; + program_name += "_BUF"; + } else { + kernel_name += "_IMG"; + program_name += "_IMG"; + } std::set build_options; - std::string source = softmax_source_fp32; - std::string program_name = "SoftMax"; - ocl_runtime->LoadSource(program_name, source); - ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); -#endif + runtime_->LoadSource(program_name, source); out_tensors_[0]->SetFormat(schema::Format_NHWC4); + runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); +#endif MS_LOG(DEBUG) << kernel_name << " Init Done!"; - return 0; + return lite::RET_OK; } -int SoftmaxOpenCLKernel::InitBuffer() { return 0; } -int SoftmaxOpenCLKernel::ReSize() { return 0; } - int SoftmaxOpenCLKernel::Run() { MS_LOG(DEBUG) << this->name() << " Running!"; - auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - auto allocator = ocl_runtime->GetAllocator(); - - // global and local workers - const uint32_t grid_x = in_tensors_[0]->shape()[2]; // W - const uint32_t grid_y = in_tensors_[0]->shape()[1]; // H 
- const uint32_t grid_z = 1; - std::vector global = {grid_x, grid_y, grid_z}; - std::vector local = {1, 1, 1}; - - // input and output - cl::Buffer *input = reinterpret_cast(allocator->GetDeviceBuffer(in_tensors_[0]->Data())); - cl::Buffer *output = reinterpret_cast(allocator->GetDeviceBuffer(out_tensors_[0]->Data())); - cl_int4 input_size = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], - in_tensors_[0]->shape()[3]}; + std::cout << "run" << std::endl; + + // attribute int arg_idx = 0; - ocl_runtime->SetKernelArg(kernel_, arg_idx++, *input); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, *output); - ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_size); + if (onexone_flag_) { + int channel_size = in_tensors_[0]->shape()[1]; + int slices = UP_DIV(channel_size, C4NUM); + cl_int slices_x32 = UP_DIV(slices, 32); + auto mask_ = GetMaskForLastChannel(channel_size); + cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]}; - // run opengl kernel - ocl_runtime->RunKernel(kernel_, global, local, nullptr); + runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); + runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + runtime_->SetKernelArg(kernel_, arg_idx++, mask); + runtime_->SetKernelArg(kernel_, arg_idx++, slices); + runtime_->SetKernelArg(kernel_, arg_idx, slices_x32); + SetWorkGroupSize1x1(); + } else { + int slices = UP_DIV(out_tensors_[0]->Channel(), C4NUM); + cl_int4 input_shape = {in_tensors_[0]->Height(), in_tensors_[0]->Width(), in_tensors_[0]->Channel(), slices}; - return 0; + runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); + runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); + runtime_->SetKernelArg(kernel_, arg_idx, input_shape); + SetWorkGroupSize(); + } + + // run opengl kernel + runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr); + return lite::RET_OK; } kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector &inputs, @@ -104,5 +178,4 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector &inputs, + explicit SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) { + : OpenCLKernel(parameter, inputs, outputs) { parameter_ = reinterpret_cast(parameter); } - ~SoftmaxOpenCLKernel() override{}; + ~SoftmaxOpenCLKernel() override{}; int Init() override; - int ReSize() override; int Run() override; - int InitBuffer(); + int GetImageSize(size_t idx, std::vector *img_size) override; + + int InitGlobalSize(); + int SetWorkGroupSize1x1(); + int SetWorkGroupSize(); + std::vector GetMaskForLastChannel(int channels); private: - SoftmaxParameter *parameter_; cl::Kernel kernel_; + SoftmaxParameter *parameter_; + lite::opencl::OpenCLRuntime *runtime_; + enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG}; + + bool onexone_flag_{false}; + std::vector local_size_; + std::vector global_size_; }; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc index 9f725d62e2..aab9795734 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/utils.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc @@ -175,4 +175,3 @@ std::string CLErrorCode(cl_int error_code) { } } // namespace kernel } // namespace 
mindspore - diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.h b/mindspore/lite/src/runtime/kernel/opencl/utils.h index 1593a93ee1..f34135c29c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/utils.h +++ b/mindspore/lite/src/runtime/kernel/opencl/utils.h @@ -85,4 +85,3 @@ std::string CLErrorCode(cl_int error_code); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ - diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc index 9e94ee89c8..7327f335ea 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc @@ -19,6 +19,7 @@ #include "utils/log_adapter.h" #include "src/runtime/opencl/opencl_runtime.h" #include "include/errorcode.h" +#include "src/runtime/kernel/opencl/utils.h" namespace mindspore::lite::opencl { @@ -128,7 +129,7 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector& img_size) cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format, img_size[0], img_size[1], 0, nullptr, &ret); if (ret != CL_SUCCESS) { - MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; + MS_LOG(ERROR) << "Create OpenCL Image2D failed!" << kernel::CLErrorCode(ret); UnLock(); delete buffer; return nullptr; @@ -187,7 +188,7 @@ void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::v cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, img_size[0], img_size[1], 0, data, &ret); if (ret != CL_SUCCESS) { - MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; + MS_LOG(ERROR) << "Create OpenCL Image2D failed - " << kernel::CLErrorCode(ret); UnLock(); delete buffer; return nullptr; diff --git a/mindspore/lite/src/runtime/opencl/opencl_executor.cc b/mindspore/lite/src/runtime/opencl/opencl_executor.cc index 1103281ba7..6a98a11c55 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_executor.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_executor.cc @@ -52,6 +52,7 @@ int OpenCLExecutor::Run(std::vector &inputs, std::vector img_size; op_kernel->GetImageSize(i, &img_size); auto data_ptr = op_allocator->Malloc(output->Size(), img_size); + output->SetData(data_ptr); } else { output->MallocData(allocator); @@ -109,7 +110,7 @@ int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format case kNumberTypeFloat32: return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir); default: - MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " << schema::EnumNameFormat(dst_format); return RET_ERROR; } @@ -160,7 +161,7 @@ int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema // TODO(wandongdong): add support !! 
return RET_OK; } else { - MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " << schema::EnumNameFormat(dst_format) << " in float32"; return RET_ERROR; } @@ -194,7 +195,7 @@ int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema: allocator_->Free(src_data); return RET_OK; } else { - MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " << schema::EnumNameFormat(dst_format) << " in float32"; return RET_ERROR; } @@ -216,7 +217,7 @@ int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schem allocator_->Free(src_data); return RET_OK; } else { - MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " << schema::EnumNameFormat(dst_format) << " in float32"; return RET_ERROR; } @@ -228,7 +229,7 @@ int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::F MS_ASSERT(4 == tensor->shape().size()); // auto src_format = tensor->GetFormat(); // todo - MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " + MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " << schema::EnumNameFormat(dst_format) << " in uint8"; return RET_ERROR; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc index 6d81702c07..5ba3394d5b 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/softmax_tests.cc @@ -17,76 +17,90 @@ #include #include "mindspore/core/utils/log_adapter.h" #include "common/common_test.h" -#include "mindspore/lite/src/common/file_utils.h" #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" +#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" namespace mindspore { class TestSoftmaxOpenCL : public mindspore::CommonTest {}; -void InitSoftaxParam(SoftmaxParameter *param) { param->axis_ = -1; } - -TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) { - std::cout << "======" << std::endl; - MS_LOG(INFO) << "start TEST_F TestSoftmaxOpenCL"; +void RunTestCase(std::vector input_shape, std::vector output_shape, std::string input_file, + std::string expect_file, SoftmaxParameter *param, schema::Format format) { + std::cout << "runtime" << std::endl; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); - MS_LOG(INFO) << "create SoftmaxParameter"; - auto param = new SoftmaxParameter(); - InitSoftaxParam(param); + // define tensor + MS_LOG(INFO) << "defineTensor"; + std::cout << "defineTensor" << std::endl; - MS_LOG(INFO) << "create Tensors"; - std::vector shape_in = {1, 2, 2, 1}; - std::vector shape_out = {1, 2, 2, 1}; auto data_type = kNumberTypeFloat32; auto tensorType = schema::NodeType_ValueNode; - lite::tensor::Tensor *tensor_in = new 
lite::tensor::Tensor(data_type, shape_in, schema::Format_NCHW, tensorType); - lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NCHW, tensorType); - std::vector inputs{tensor_in}; - std::vector outputs{tensor_out}; + auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, format, tensorType); + auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, format, tensorType); + std::vector inputs{input_tensor}; + std::vector outputs{output_tensor}; - MS_LOG(INFO) << "create OpenCL Kernel"; - auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast(param), inputs, outputs); - Softmax_kernel->Init(); - std::vector kernels{Softmax_kernel}; + // run + MS_LOG(INFO) << "NewOpenCLKernel"; + std::cout << "NewOpenCLKernel" << std::endl; + auto *kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast(param), inputs, outputs); + MS_LOG(INFO) << "KernelInit"; + std::cout << "KernelInit" << std::endl; + kernel->Init(); - MS_LOG(INFO) << "create SubGraphOpenCLKernel"; + std::cout << "LiteKernel" << std::endl; + std::vector kernels{kernel}; + inputs[0]->MallocData(allocator); + std::cout << "SubGraphOpenCLKernel" << std::endl; auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + MS_LOG(INFO) << "pGraphinit"; pGraph->Init(); - MS_LOG(INFO) << "initialize data"; - std::vector tensor_map = {tensor_in}; - for (auto &tensor_file : tensor_map) { - auto tensor = tensor_file; - size_t size = tensor->Size(); - const float data[4] = {std::log(1.0f), std::log(2.0f), std::log(3.0f), std::log(4.0f)}; - memcpy(tensor->Data(), data, size); + // load data + MS_LOG(INFO) << "load data1"; + + LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); + auto *input_data = reinterpret_cast(input_tensor->Data()); + printf("\ninput[0:10]:"); + for (int i = 0; i < 10; i++) { + printf("[%d]:%.3f ", i, input_data[i]); } + printf("\n\n"); - MS_LOG(INFO) << "pGraph->Run()"; + MS_LOG(INFO) << "Run"; pGraph->Run(); - MS_LOG(INFO) << "==================output data================="; - float *output_data = reinterpret_cast(tensor_out->Data()); - size_t output_size = tensor_out->Size(); + MS_LOG(INFO) << "compare result"; + std::cout << "compare result" << std::endl; + CompareOutput(output_tensor, expect_file); +} - printf("output:"); - for (int i = 0; i < 4; i++) { - printf("%.3f ", output_data[i]); - } - printf("\n"); - float expect[4] = {1.0f, 2.0f, 3.0f, 4.0f}; +TEST_F(TestSoftmaxOpenCL, Softmax_1) { + std::vector input_shape = {1, 2, 2, 8}; + std::vector output_shape = {1, 2, 2, 8}; + std::string input_file = "softmax_in.bin"; + std::string expect_file = "softmax_out.bin"; + auto param = new SoftmaxParameter; + param->axis_ = 3; + schema::Format format = schema::Format_NHWC4; - for (int i = 0; i < tensor_out->ElementsNum(); ++i) { - if (std::fabs(output_data[i] - expect[i]) > 1e-5) { - printf("idx[%d] except=%.3f output=%.3f .", i, expect[i], output_data[i]); - } - } - printf("\nTest all close OK for %zu!\n", output_size); - lite::CompareOutputData(output_data, expect, 4); + RunTestCase(input_shape, output_shape, input_file, expect_file, param, format); } +// TEST_F(TestSoftmaxOpenCL, Softmax_1x1) { +// std::vector input_shape = {1, 100}; +// std::vector output_shape = {1, 100}; +// std::string input_file = "softmax1x1_in.bin"; +// std::string expect_file = "softmax1x1_out.bin"; +// auto param = new SoftmaxParameter; +// param->axis_ = 1; +// schema::Format format = schema::Format_NHWC4; +// 
+// RunTestCase(input_shape, output_shape, input_file, expect_file, param, format); +//} + } // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.cc index e834e29b07..5c02760ca0 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.cc @@ -40,13 +40,13 @@ void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_ size_t output_size = output_tensor->Size(); float *expect_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); - printf("output[0:10]:"); - for (int i = 0; i < 10; i++) { + printf("output[0:12]:"); + for (int i = 0; i < 12; i++) { printf("[%d]:%.3f ", i, output_data[i]); } printf("\n"); - printf("expect[0:10]:"); - for (int i = 0; i < 10; i++) { + printf("expect[0:12]:"); + for (int i = 0; i < 12; i++) { printf("[%d]:%.3f ", i, expect_data[i]); } printf("\n"); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index 0861cda27d..7b59d6d215 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -157,7 +157,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { } else if (opType == schema::PrimitiveType_DeConv2D) { weightTensor->format = schema::Format_CHWK; } else { - MS_LOG(ERROR) << "unsupport format"; + MS_LOG(ERROR) << "Unsupported format"; return -1; } } break; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index 1ade0e5dfe..d8358a8e5e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -184,7 +184,7 @@ size_t GetDataTypeSize(const TypeId &data_type) { return sizeof(int64_t); default: MS_LOG(ERROR) << data_type; - MS_LOG(ERROR) << "unsupport datatype"; + MS_LOG(ERROR) << "Unsupported datatype"; return RET_ERROR; } }
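Note on the channel alignment this patch relies on: the OpenCL path works on channel-aligned layouts (NHWC4 for 4-D tensors, NC4 for the new 2-D case), so the channel dimension is padded up to a multiple of 4 and each work item consumes one float4 slice. The sketch below is a minimal, standalone reconstruction of that arithmetic, not repository code; UpDiv, ElementsC4Num and MaskForLastChannel are hypothetical stand-ins that mirror the UP_DIV/C4NUM macros, Tensor::ElementsC4Num and SoftmaxOpenCLKernel::GetMaskForLastChannel introduced above. It also answers the "what is mask and args.slices_x32" question left as a comment in softmax1x1.cl: the mask zeroes the padded lanes of the last slice so they do not contribute to the exp() sum, and slices_x32 = UP_DIV(slices, 32) bounds the strided loop executed by the 32 work items.

```cpp
// Standalone sketch (assumed helper names, not taken from the repository)
// of the C4 alignment arithmetic used by the patch.
#include <cassert>
#include <cstdio>
#include <vector>

constexpr int C4NUM = 4;
// Mirrors the UP_DIV macro: ceiling division.
inline int UpDiv(int n, int div) { return (n + div - 1) / div; }

// Channel-aligned element count, matching Tensor::ElementsC4Num in the patch:
// 4-D tensors pad the channel dimension to a multiple of 4, 2-D tensors pad
// the second dimension.
int ElementsC4Num(const std::vector<int> &shape) {
  if (shape.size() == 4) {  // NHWC
    return shape[0] * shape[1] * shape[2] * (UpDiv(shape[3], C4NUM) * C4NUM);
  }
  if (shape.size() == 2) {  // NC
    return shape[0] * (UpDiv(shape[1], C4NUM) * C4NUM);
  }
  return -1;  // unsupported rank, as in the patch
}

// Mirrors GetMaskForLastChannel: lanes of the last float4 slice that lie
// beyond the real channel count get weight 0 in the softmax sum.
std::vector<float> MaskForLastChannel(int channels) {
  std::vector<float> mask(4, 0.0f);
  const int remainder = channels % 4 == 0 ? 4 : channels % 4;
  for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
  return mask;
}

int main() {
  // Shape {1, 2, 2, 8}: 8 channels are already a multiple of 4 -> 32 elements.
  assert(ElementsC4Num({1, 2, 2, 8}) == 32);
  // Shape {1, 10}: 10 channels round up to 12 -> 3 float4 slices.
  assert(ElementsC4Num({1, 10}) == 12);
  assert(UpDiv(10, C4NUM) == 3);

  // 10 channels: the last slice holds channels 8..9, so mask = {1, 1, 0, 0},
  // and slices_x32 = UpDiv(3, 32) = 1 strided pass per work item.
  auto mask = MaskForLastChannel(10);
  printf("mask = {%.0f, %.0f, %.0f, %.0f}\n", mask[0], mask[1], mask[2], mask[3]);
  return 0;
}
```

With a {1, 100} input, as in the commented-out 1x1 test case, the host side would pass slices = 25 and slices_x32 = 1 to SoftMax1x1_IMG, and the mask degenerates to all ones because 100 is a multiple of 4.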
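For reference, a plain scalar softmax over a {1, C} tensor (axis 1) shows what the SoftMax1x1_IMG kernel is intended to compute and could be used to produce fixtures like the softmax1x1_in.bin / softmax1x1_out.bin files named in the still-disabled test. This is a hedged sketch rather than repository code; like the OpenCL kernels in this patch it does not subtract the per-row maximum before exponentiating, so very large inputs can overflow exp().

```cpp
// Scalar reference softmax for a {1, C} tensor (axis 1). Sketch only.
// Note: no max-subtraction before exp(), matching the kernels in this patch.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> SoftmaxRef(const std::vector<float> &in) {
  float sum = 0.0f;
  for (float v : in) sum += std::exp(v);
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    out[i] = std::exp(in[i]) / sum;
  }
  return out;
}

int main() {
  // Same input data as the original unit test: softmax of log(1..4)
  // is 0.1, 0.2, 0.3, 0.4.
  auto out = SoftmaxRef({std::log(1.0f), std::log(2.0f), std::log(3.0f), std::log(4.0f)});
  for (float v : out) printf("%.3f ", v);
  printf("\n");
  return 0;
}
```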