| @@ -11,23 +11,28 @@ | |||
| __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; | |||
| __kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global FLT *buf, __global int *ids, | |||
| int4 shape, int4 src_size, int4 cus_size, int4 strides, int4 flags) { | |||
| int X = get_global_id(0); // reduce len | |||
| int X = get_global_id(0); // lower reduce stride | |||
| int Y = get_global_id(1); // upper axis accumulation | |||
| if (X >= src_size.x || Y >= src_size.y) { | |||
| return; | |||
| } | |||
| int offset = X + Y * src_size.z; | |||
| int align_c4 = (flags.z != 3) ? (X / shape.w) * (shape.x) : 0; | |||
| int align_c4 = (flags.z != 3) ? (X / shape.w) * (C4NUM - shape.w & 0x00000003) : 0; | |||
| int align_in = 0; | |||
| int align_out = 0; | |||
| bool keep_dims = cus_size.y; | |||
| int width = shape.z * shape.w; | |||
| if (flags.z == 3) { | |||
| align_in = (Y / shape.z) * cus_size.z; | |||
| align_out = (Y / shape.z) * cus_size.w; | |||
| } | |||
| if (flags.z == 0) { | |||
| align_in = X / (shape.y) * cus_size.z; | |||
| align_in = X / (width)*cus_size.z; | |||
| align_out = align_in; | |||
| } | |||
| if (flags.z == 2 && !keep_dims) { | |||
| align_out = (Y / shape.y) * cus_size.w; | |||
| } | |||
| for (int k = 0; k < src_size.w; ++k) { | |||
| int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4 + align_in); | |||
| int idx1 = offset + k * src_size.x; | |||
| @@ -61,8 +61,6 @@ void ArgMinMaxOpenCLKernel::SetConstArgs() { | |||
| auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_); | |||
| cl_int4 in_shape{static_cast<int>(im_in_.N), static_cast<int>(im_in_.H), static_cast<int>(im_in_.W), | |||
| static_cast<int>(im_in_.C)}; | |||
| in_shape.s[0] = UP_ROUND(im_in_.C, C4NUM) - im_in_.C; | |||
| in_shape.s[1] = im_in_.W * im_in_.C; | |||
| cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_}; | |||
| int arg_cnt = 2; | |||
| ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF); | |||
| @@ -77,17 +75,20 @@ void ArgMinMaxOpenCLKernel::SetConstArgs() { | |||
| void ArgMinMaxOpenCLKernel::SetGlobalLocal() { | |||
| auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_); | |||
| im_in_ = GpuTensorInfo(in_tensors_[0]); | |||
| im_out_ = GpuTensorInfo(out_tensors_[0]); | |||
| std::vector<size_t> in_shape = {im_in_.N, im_in_.H, im_in_.W, im_in_.C}; | |||
| auto in_shape_align = in_shape; | |||
| in_shape_align[3] = UP_ROUND(in_shape[3], C4NUM); | |||
| auto out_shape_align = in_shape_align; | |||
| out_shape_align.at(param->axis_) = param->axis_ == 3 ? UP_ROUND(param->topk_, C4NUM) : param->topk_; | |||
| std::vector<size_t> out_shape = {im_out_.N, im_out_.H, im_out_.W, im_out_.C}; | |||
| auto out_shape_align = out_shape; | |||
| out_shape_align[3] = UP_ROUND(out_shape[3], C4NUM); | |||
| int reduce_len = GetUpPow2(in_shape.at(param->axis_)); | |||
| int dtype_size = in_tensors_[0]->data_type() == kNumberTypeFloat16 ? sizeof(int16_t) : sizeof(float); | |||
| cus_size_ = {reduce_len, static_cast<int>(im_in_.RowPitch() / dtype_size), 1, 1}; | |||
| cus_size_.s[2] = UP_ROUND(im_in_.width * C4NUM, cus_size_.s[1]) - im_in_.width * C4NUM; | |||
| cus_size_.s[3] = im_in_.W * UP_ROUND(param->topk_, C4NUM); | |||
| cus_size_.s[3] = UP_ROUND(cus_size_.s[3], cus_size_.s[1]) - cus_size_.s[3]; | |||
| int in_pitch = im_in_.RowPitch() / dtype_size; | |||
| int out_pitch = im_out_.RowPitch() / dtype_size; | |||
| cus_size_ = {reduce_len, param->keep_dims_, 1, 1}; | |||
| cus_size_.s[2] = in_pitch - im_in_.width * C4NUM; | |||
| cus_size_.s[3] = out_pitch - im_out_.width * C4NUM; | |||
| src_size_ = {std::accumulate(in_shape.begin() + param->axis_ + 1, in_shape.end(), 1, std::multiplies<int>()), | |||
| std::accumulate(in_shape.begin(), in_shape.begin() + param->axis_, 1, std::multiplies<int>()), | |||
| std::accumulate(in_shape.begin() + param->axis_, in_shape.end(), 1, std::multiplies<int>()), | |||
| @@ -100,22 +101,25 @@ void ArgMinMaxOpenCLKernel::SetGlobalLocal() { | |||
| }; | |||
| switch (param->axis_) { | |||
| case 0: | |||
| strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, cus_size_.s[1]) * im_in_.H; | |||
| strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, in_pitch) * im_in_.H; | |||
| strides_.s[1] = strides_.s[0] * im_in_.N; | |||
| strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, cus_size_.s[1]) * im_in_.H; | |||
| strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, out_pitch) * im_in_.H; | |||
| strides_.s[3] = strides_.s[2] * param->topk_; | |||
| break; | |||
| case 1: | |||
| strides_.s[0] = UP_ROUND(strides_.s[0], cus_size_.s[1]); | |||
| strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, cus_size_.s[1]) * im_in_.H; | |||
| strides_.s[2] = UP_ROUND(strides_.s[2], cus_size_.s[1]); | |||
| strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, cus_size_.s[1]) * param->topk_; | |||
| strides_.s[0] = UP_ROUND(strides_.s[0], in_pitch); | |||
| strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, in_pitch) * im_in_.H; | |||
| // org dim(4,3) org axis(1,0) | |||
| strides_.s[2] = UP_ROUND(strides_.s[2], out_pitch); | |||
| strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, out_pitch) * param->topk_; | |||
| break; | |||
| case 2: | |||
| strides_.s[1] = UP_ROUND(strides_.s[1], cus_size_.s[1]); | |||
| strides_.s[3] = UP_ROUND(strides_.s[3], cus_size_.s[1]); | |||
| strides_.s[1] = UP_ROUND(strides_.s[1], in_pitch); | |||
| // org dim(4,3,2) org axis(2,1,0) | |||
| strides_.s[3] = param->keep_dims_ ? UP_ROUND(strides_.s[3], out_pitch) : strides_.s[2]; | |||
| break; | |||
| default: // 3 | |||
| // org dim(4,3,2,1) org axis(3,2,1,0) | |||
| break; | |||
| } | |||
| local_size_ = {1, 1, 1}; | |||
| @@ -147,8 +151,10 @@ int ArgMinMaxOpenCLKernel::Prepare() { | |||
| auto *param = reinterpret_cast<ArgMinMaxParameter *>(this->op_parameter_); | |||
| param->dims_size_ = in_tensors_[0]->shape().size(); | |||
| param->axis_ = (param->axis_ + param->dims_size_) % param->dims_size_; | |||
| param->axis_ = (4 - param->dims_size_) + param->axis_; | |||
| param->axis_ = GetBroadcastGpuAxis(param->dims_size_, param->axis_); | |||
| param->get_max_ = (Type() == PrimitiveType_ArgMax); | |||
| param->keep_dims_ = | |||
| param->keep_dims_ || param->topk_ > 1 || in_tensors_[0]->shape().size() == out_tensors_[0]->shape().size(); | |||
| InitWeights(); | |||
| SetGlobalLocal(); | |||
| @@ -44,6 +44,7 @@ class ArgMinMaxOpenCLKernel : public OpenCLKernel { | |||
| void *buff_{nullptr}; | |||
| void *ids_{nullptr}; | |||
| GpuTensorInfo im_in_{GpuTensorInfo(nullptr)}; | |||
| GpuTensorInfo im_out_{GpuTensorInfo(nullptr)}; | |||
| cl_int4 src_size_; | |||
| cl_int4 cus_size_; | |||
| cl_int4 strides_; | |||
| @@ -105,6 +105,7 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() { | |||
| auto allocator = ocl_runtime_->GetAllocator(); | |||
| bool is_fp16 = ocl_runtime_->GetFp16Enable(); | |||
| size_t dtype_size = is_fp16 ? sizeof(int16_t) : sizeof(float); | |||
| auto out_info = GpuTensorInfo(out_tensors_[0]); | |||
| // weight: o, h, w, i; o == group, i == 1 | |||
| void *origin_weight = in_tensors_.at(kWeightIndex)->data_c(); | |||
| @@ -121,7 +122,7 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() { | |||
| size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT; | |||
| img_size = {(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype}; | |||
| } | |||
| pack_weight_size = is_fp16 ? pack_weight_size * sizeof(int16_t) : pack_weight_size * sizeof(float); | |||
| pack_weight_size = pack_weight_size * dtype_size; | |||
| auto ConvertFilter = [](void *src, void *dst, TypeId src_type, TypeId dst_type, size_t plane_in, size_t plane_out, | |||
| size_t channel) { | |||
| if (dst_type == kNumberTypeFloat16) { | |||
| @@ -173,18 +174,14 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() { | |||
| memcpy(dst, src, size * dtype_size); | |||
| } | |||
| }; | |||
| size_t dtype_size = sizeof(float); | |||
| if (is_fp16 && in_tensors_.at(kBiasIndex)->data_type() == kNumberTypeFloat16) { | |||
| dtype_size = sizeof(int16_t); | |||
| } | |||
| std::vector<char> temp_bias(pack_weight_size, 0); | |||
| size_t bias_size = C4NUM * CO4 * dtype_size; | |||
| std::vector<char> temp_bias(bias_size, 0); | |||
| if (in_tensors_.size() == 3) { | |||
| src_type = in_tensors_.at(kBiasIndex)->data_type(); | |||
| dst_type = is_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32; | |||
| auto element_size = in_tensors_.at(kBiasIndex)->ElementsNum(); | |||
| ConvertBias(in_tensors_.at(kBiasIndex)->data_c(), temp_bias.data(), element_size, dtype_size, src_type, dst_type); | |||
| } | |||
| size_t bias_size = C4NUM * CO4 * dtype_size; | |||
| bias_data_ = allocator->Malloc(bias_size, {}, temp_bias.data()); | |||
| if (bias_data_ == nullptr) { | |||
| return RET_ERROR; | |||
| @@ -538,7 +538,7 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, s | |||
| } // namespace | |||
| void OpenCLSubGraph::Fusion() { | |||
| int OpenCLSubGraph::FusionPass() { | |||
| MS_LOG(DEBUG) << "start Fusion"; | |||
| std::vector<LiteKernel *> input_nodes; | |||
| @@ -657,6 +657,7 @@ void OpenCLSubGraph::Fusion() { | |||
| std::remove_if(nodes_.begin(), nodes_.end(), [&](LiteKernel *node) { return AIsInB(node, &removed_set); }), | |||
| nodes_.end()); | |||
| MS_LOG(DEBUG) << "number of kernels(after fusion) : " << nodes_.size(); | |||
| return RET_OK; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -16,6 +16,8 @@ | |||
| #include "src/runtime/kernel/opencl/opencl_subgraph.h" | |||
| #include <set> | |||
| #include <map> | |||
| #include <string> | |||
| #include "src/runtime/opencl/opencl_executor.h" | |||
| #include "src/runtime/kernel/opencl/utils.h" | |||
| #include "include/errorcode.h" | |||
| @@ -189,19 +191,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors, | |||
| } | |||
| return RET_OK; | |||
| } | |||
| int OpenCLSubGraph::Init() { | |||
| allocator_ = ocl_runtime_->GetAllocator(); | |||
| MS_LOG(DEBUG) << "input num=" << in_tensors_.size() << ", output num=" << out_tensors_.size(); | |||
| for (const auto tensor : in_tensors_) { | |||
| MS_ASSERT(tensor); | |||
| tensor->set_allocator(allocator_); | |||
| } | |||
| for (const auto tensor : out_tensors_) { | |||
| MS_ASSERT(tensor); | |||
| tensor->set_allocator(allocator_); | |||
| } | |||
| int OpenCLSubGraph::InsertOpsPass() { | |||
| GetInOutNodes(); | |||
| std::vector<std::vector<kernel::LiteKernel *>> from_kernels_; | |||
| @@ -222,12 +212,34 @@ int OpenCLSubGraph::Init() { | |||
| } | |||
| nodes_.insert(nodes_.end(), out_convert_ops_.begin(), out_convert_ops_.end()); | |||
| GetInOutNodes(); | |||
| UpdateTensorDataType(); | |||
| Fusion(); | |||
| return RET_OK; | |||
| } | |||
| int OpenCLSubGraph::Init() { | |||
| allocator_ = ocl_runtime_->GetAllocator(); | |||
| MS_LOG(DEBUG) << "input num=" << in_tensors_.size() << ", output num=" << out_tensors_.size(); | |||
| for (const auto tensor : in_tensors_) { | |||
| MS_ASSERT(tensor); | |||
| tensor->set_allocator(allocator_); | |||
| } | |||
| for (const auto tensor : out_tensors_) { | |||
| MS_ASSERT(tensor); | |||
| tensor->set_allocator(allocator_); | |||
| } | |||
| std::map<std::string, std::function<int(void)>> pass_manager{ | |||
| {"InsertOpsPass", std::bind(&OpenCLSubGraph::InsertOpsPass, this)}, | |||
| {"UpdateTensorDataTypePass", std::bind(&OpenCLSubGraph::UpdateTensorDataTypePass, this)}, | |||
| {"FusionPass", std::bind(&OpenCLSubGraph::FusionPass, this)}}; | |||
| for (auto iv : pass_manager) { | |||
| auto ret = iv.second(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Run Pass: " << iv.first << " failed."; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void OpenCLSubGraph::UpdateTensorDataType() { | |||
| int OpenCLSubGraph::UpdateTensorDataTypePass() { | |||
| bool is_fp16 = ocl_runtime_->GetFp16Enable(); | |||
| MS_ASSERT(in_tensors_[0]); | |||
| if (is_fp16 && (in_tensors_[0]->data_type() == kNumberTypeFloat32)) { | |||
| @@ -245,6 +257,7 @@ void OpenCLSubGraph::UpdateTensorDataType() { | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| void OpenCLSubGraph::GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors, | |||
| @@ -46,10 +46,11 @@ class OpenCLSubGraph : public SubGraphKernel { | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int Run(const KernelCallBack &before, const KernelCallBack &after) override { return this->Run(); }; | |||
| int InsertOpsPass(); | |||
| private: | |||
| void UnInit(); | |||
| void UpdateTensorDataType(); | |||
| int UpdateTensorDataTypePass(); | |||
| void ReplaceOutTensorAndKernelToNull(const std::vector<lite::Tensor *> &in_tensors, | |||
| const std::vector<std::vector<kernel::LiteKernel *>> &in_kernels, | |||
| lite::opencl::MemType mem_type); | |||
| @@ -64,7 +65,10 @@ class OpenCLSubGraph : public SubGraphKernel { | |||
| void GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors, | |||
| const std::vector<kernel::LiteKernel *> &in_kernels, | |||
| std::vector<std::vector<kernel::LiteKernel *>> *out_kernels, bool is_from); | |||
| void Fusion(); | |||
| int FusionPass(); | |||
| public: | |||
| using PassFunc = int (OpenCLSubGraph::*)(void); | |||
| private: | |||
| lite::opencl::OpenCLAllocator *allocator_{nullptr}; | |||
| @@ -330,4 +330,22 @@ std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape | |||
| } | |||
| return {image_x, image_y}; | |||
| } | |||
| int GetBroadcastGpuAxis(int ndim, int ori_axis) { | |||
| if (ori_axis >= ndim) { | |||
| return ndim - 1; | |||
| } | |||
| int axis = 0; | |||
| if (ndim == 1) { | |||
| axis = 3; | |||
| } else if (ndim == 2) { | |||
| axis = ori_axis == 0 ? 0 : 3; | |||
| } else if (ndim == 3) { | |||
| axis = ori_axis == 0 ? 0 : ori_axis == 1 ? 2 : 3; | |||
| } else if (ndim == 4) { | |||
| axis = ori_axis; | |||
| } else if (ndim > 4) { | |||
| MS_LOG(ERROR) << "GPU doesn't support ndim>=" << ndim; | |||
| } | |||
| return axis; | |||
| } | |||
| } // namespace mindspore::kernel | |||
| @@ -61,6 +61,8 @@ std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape); | |||
| std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format); | |||
| int GetBroadcastGpuAxis(int ndim, int ori_axis); | |||
| template <class T1, class T2> | |||
| void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel, | |||
| const std::function<T2(T1)> &to_dtype) { | |||
| @@ -185,7 +185,7 @@ TEST_F(TestOpenCL_ArgMinMax, axis3topk2value) { | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_ArgMinMax, axis1topk1index) { | |||
| TEST_F(TestOpenCL_ArgMinMax, dim32axis1topk1index) { | |||
| schema::PrimitiveType type = schema::PrimitiveType_ArgMax; | |||
| int axis = 1; | |||
| int topk = 1; | |||
| @@ -200,4 +200,52 @@ TEST_F(TestOpenCL_ArgMinMax, axis1topk1index) { | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true); | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_ArgMinMax, dim43axis2topk1index) { | |||
| schema::PrimitiveType type = schema::PrimitiveType_ArgMax; | |||
| int axis = 2; | |||
| int topk = 1; | |||
| bool out_value = false; | |||
| std::vector<int> input_shape = {2, 2, 2, 14}; | |||
| std::vector<int> output_shape = {2, 2, 14}; | |||
| float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, | |||
| 1, 50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, | |||
| 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, | |||
| 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15, | |||
| 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25}; | |||
| float output_data[] = {1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, | |||
| 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(type, axis, topk, out_value); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true); | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_ArgMinMax, dim21axis2topk1index) { | |||
| schema::PrimitiveType type = schema::PrimitiveType_ArgMax; | |||
| int axis = 0; | |||
| int topk = 1; | |||
| bool out_value = false; | |||
| std::vector<int> input_shape = {2, 14}; | |||
| std::vector<int> output_shape = {14}; | |||
| float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, | |||
| 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25}; | |||
| float output_data[] = {1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(type, axis, topk, out_value); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true); | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_ArgMinMax, dim10axis2topk1index) { | |||
| schema::PrimitiveType type = schema::PrimitiveType_ArgMax; | |||
| int axis = 0; | |||
| int topk = 1; | |||
| bool out_value = false; | |||
| std::vector<int> input_shape = {14}; | |||
| std::vector<int> output_shape = {1}; | |||
| float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50}; | |||
| float output_data[] = {4}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(type, axis, topk, out_value); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true); | |||
| } | |||
| } | |||
| } // namespace mindspore::lite::opencl::test | |||