Merge pull request !4337 from chenzhongming/litetags/v0.7.0-beta
| @@ -32,6 +32,10 @@ enum Format : int { | |||||
| CKHW, | CKHW, | ||||
| KHWC, | KHWC, | ||||
| CHWK, | CHWK, | ||||
| HW, | |||||
| HW4, | |||||
| NC, | |||||
| NC4, | |||||
| NC4HW4 = 100, | NC4HW4 = 100, | ||||
| NUM_OF_FORMAT | NUM_OF_FORMAT | ||||
| } | } | ||||
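The four new entries cover 2-D tensors: Format_NC / Format_NC4 treat shape[0] as batch and shape[1] as channel, while Format_HW / Format_HW4 expose shape[0] as height and shape[1] as width (see the Tensor accessor changes below). A minimal usage sketch, borrowing the constructor signature from the test code later in this diff:

// Hypothetical sketch: tagging a 2-D (batch, channel) tensor with the new Format_NC.
std::vector<int> shape = {1, 100};
auto *t = new lite::tensor::Tensor(kNumberTypeFloat32, shape,
                                   schema::Format_NC, schema::NodeType_ValueNode);
// With the accessor changes below: t->Batch() == 1, t->Channel() == 100.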
| @@ -104,7 +104,7 @@ int Executor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format d | |||||
| allocator->Free(src_data); | allocator->Free(src_data); | ||||
| return RET_OK; | return RET_OK; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| << schema::EnumNameFormat(dst_format) << " in float32"; | << schema::EnumNameFormat(dst_format) << " in float32"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -116,7 +116,7 @@ int Executor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format | |||||
| MS_ASSERT(4 == tensor->shape().size()); | MS_ASSERT(4 == tensor->shape().size()); | ||||
| // auto src_format = tensor->GetFormat(); | // auto src_format = tensor->GetFormat(); | ||||
| // todo | // todo | ||||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| << schema::EnumNameFormat(dst_format) << " in uint8"; | << schema::EnumNameFormat(dst_format) << " in uint8"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -104,8 +104,8 @@ bool Tensor::operator==(const Value &other) const { | |||||
| } | } | ||||
| int32_t Tensor::Batch() const { | int32_t Tensor::Batch() const { | ||||
| if (this->shape_.size() != 4) { | |||||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| switch (this->format_) { | switch (this->format_) { | ||||
| @@ -115,6 +115,8 @@ int32_t Tensor::Batch() const { | |||||
| case schema::Format_NC4HW4: | case schema::Format_NC4HW4: | ||||
| case schema::Format_KCHW: | case schema::Format_KCHW: | ||||
| case schema::Format_KHWC: | case schema::Format_KHWC: | ||||
| case schema::Format_NC: | |||||
| case schema::Format_NC4: | |||||
| return this->shape_[0]; | return this->shape_[0]; | ||||
| case schema::Format_HWCK: | case schema::Format_HWCK: | ||||
| case schema::Format_CHWK: | case schema::Format_CHWK: | ||||
| @@ -124,19 +126,21 @@ int32_t Tensor::Batch() const { | |||||
| case schema::Format_CKHW: | case schema::Format_CKHW: | ||||
| return this->shape_[1]; | return this->shape_[1]; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); | |||||
| MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_); | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| } | } | ||||
| int32_t Tensor::Channel() const { | int32_t Tensor::Channel() const { | ||||
| if (this->shape_.size() != 4) { | |||||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| switch (this->format_) { | switch (this->format_) { | ||||
| case schema::Format_NCHW: | case schema::Format_NCHW: | ||||
| case schema::Format_KCHW: | case schema::Format_KCHW: | ||||
| case schema::Format_NC: | |||||
| case schema::Format_NC4: | |||||
| return this->shape_[1]; | return this->shape_[1]; | ||||
| case schema::Format_HWCK: | case schema::Format_HWCK: | ||||
| return this->shape_[2]; | return this->shape_[2]; | ||||
| @@ -155,8 +159,8 @@ int32_t Tensor::Channel() const { | |||||
| } | } | ||||
| int32_t Tensor::Height() const { | int32_t Tensor::Height() const { | ||||
| if (this->shape_.size() != 4) { | |||||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| switch (this->format_) { | switch (this->format_) { | ||||
| @@ -172,16 +176,18 @@ int32_t Tensor::Height() const { | |||||
| return this->shape_[1]; | return this->shape_[1]; | ||||
| case schema::Format_HWCK: | case schema::Format_HWCK: | ||||
| case schema::Format_HWKC: | case schema::Format_HWKC: | ||||
| case schema::Format_HW: | |||||
| case schema::Format_HW4: | |||||
| return this->shape_[0]; | return this->shape_[0]; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_); | |||||
| MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_); | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| } | } | ||||
| int32_t Tensor::Width() const { | int32_t Tensor::Width() const { | ||||
| if (this->shape_.size() != 4) { | |||||
| MS_LOG(ERROR) << "tensor should have 4 dim"; | |||||
| if (this->shape_.size() != 4 && this->shape_.size() != 2) { | |||||
| MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| switch (this->format_) { | switch (this->format_) { | ||||
| @@ -197,12 +203,24 @@ int32_t Tensor::Width() const { | |||||
| return this->shape_[2]; | return this->shape_[2]; | ||||
| case schema::Format_HWCK: | case schema::Format_HWCK: | ||||
| case schema::Format_HWKC: | case schema::Format_HWKC: | ||||
| case schema::Format_HW: | |||||
| case schema::Format_HW4: | |||||
| return this->shape_[1]; | return this->shape_[1]; | ||||
| default: | default: | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| } | } | ||||
| int32_t Tensor::ElementsC4Num() const { | |||||
| int32_t result = 0; | |||||
| if (this->shape_.size() == 4) { | |||||
| result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); | |||||
| } else if (this->shape_.size() == 2) { | |||||
| result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4); | |||||
| } | |||||
| return result; | |||||
| } | |||||
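ElementsC4Num pads the trailing channel dimension up to the next multiple of 4 before counting elements; a standalone sketch of the same arithmetic (hypothetical helper, not MindSpore API):

// C4 padding arithmetic used above.
int ElementsC4(int n, int h, int w, int c) {
  int c4 = (c + 3) / 4 * 4;   // round the channel count up to a multiple of 4
  return n * h * w * c4;
}
// e.g. a 1x3x3x5 NHWC tensor gives 1 * 3 * 3 * 8 = 72, and a 2-D {2, 6} tensor gives 2 * 8 = 16.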
| std::string Tensor::ToString() const { | std::string Tensor::ToString() const { | ||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| oss << "Format: " << schema::EnumNameFormat(this->format_); | oss << "Format: " << schema::EnumNameFormat(this->format_); | ||||
| @@ -235,7 +253,7 @@ std::string Tensor::ToString() const { | |||||
| } | } | ||||
| } break; | } break; | ||||
| default: | default: | ||||
| oss << "Unsupport data type to print"; | |||||
| oss << "Unsupported data type to print"; | |||||
| break; | break; | ||||
| } | } | ||||
| return oss.str(); | return oss.str(); | ||||
| @@ -66,7 +66,7 @@ class Tensor : public mindspore::tensor::MetaTensor { | |||||
| int32_t Width() const; | int32_t Width() const; | ||||
| int32_t ElementsC4Num() const { return Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); } | |||||
| int32_t ElementsC4Num() const; | |||||
| int DataSize() const { return this->ElementsNum(); } | int DataSize() const { return this->ElementsNum(); } | ||||
| @@ -37,7 +37,7 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:: | |||||
| return RET_INPUT_TENSOR_ERROR; | return RET_INPUT_TENSOR_ERROR; | ||||
| } | } | ||||
| if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) { | if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) { | ||||
| MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); | |||||
| MS_LOG(ERROR) << "Unsupported input data type " << input->data_type(); | |||||
| return RET_INPUT_TENSOR_ERROR; | return RET_INPUT_TENSOR_ERROR; | ||||
| } | } | ||||
| if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) { | if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) { | ||||
| @@ -74,7 +74,7 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset, | Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset, | ||||
| reinterpret_cast<int32_t *>(output_data) + offset, data_num); | reinterpret_cast<int32_t *>(output_data) + offset, data_num); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type; | |||||
| MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -88,7 +88,7 @@ int CastCPUKernel::DoCast(int thread_id) { | |||||
| reinterpret_cast<float *>(output_data) + offset, data_num); | reinterpret_cast<float *>(output_data) + offset, data_num); | ||||
| break; | break; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; | |||||
| MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,21 +1,15 @@ | |||||
| #define SLICES 4 | |||||
| int DivideRoundUp(int n, int div) { | |||||
| int q = n / div; | |||||
| return n % div == 0 ? q : q + 1; | |||||
| } | |||||
| __kernel void SoftMax(__global float4 *input, __global float4 *output, const int4 input_shape) { | |||||
| int X = get_global_id(0); // width | |||||
| int Y = get_global_id(1); // height | |||||
| int H = input_shape.y; | |||||
| int W = input_shape.z; | |||||
| int C = input_shape.w; | |||||
| __kernel void SoftMax_BUF(__global float4 *input, __global float4 *output, const int4 input_shape) { | |||||
| int X = get_global_id(0); | |||||
| int Y = get_global_id(1); | |||||
| int H = input_shape.x; | |||||
| int W = input_shape.y; | |||||
| int C = input_shape.z; | |||||
| int S = input_shape.w; | |||||
| if (X >= W || Y >= H) return; | if (X >= W || Y >= H) return; | ||||
| float sum = 0.0f; | float sum = 0.0f; | ||||
| for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { | |||||
| for (int d = 0; d < S; ++d) { | |||||
| float4 t = input[(Y * W + X * H) * C + d]; | float4 t = input[(Y * W + X * H) * C + d]; | ||||
| sum += exp(t.x); | sum += exp(t.x); | ||||
| if (d * 4 + 1 < C) sum += exp(t.y); | if (d * 4 + 1 < C) sum += exp(t.y); | ||||
| @@ -23,10 +17,34 @@ __kernel void SoftMax(__global float4 *input, __global float4 *output, const int | |||||
| if (d * 4 + 3 < C) sum += exp(t.w); | if (d * 4 + 3 < C) sum += exp(t.w); | ||||
| } | } | ||||
| for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) { | |||||
| for (int d = 0; d < S; ++d) { | |||||
| float4 t = input[(Y * W + X * H) * C + d]; | float4 t = input[(Y * W + X * H) * C + d]; | ||||
| t = exp(t) / sum; | t = exp(t) / sum; | ||||
| float4 result = convert_float4(t); | float4 result = convert_float4(t); | ||||
| output[(Y * W + X * H) * C + d] = result; | output[(Y * W + X * H) * C + d] = result; | ||||
| } | } | ||||
| } | } | ||||
| __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||||
| __kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) { | |||||
| int X = get_global_id(0); | |||||
| int Y = get_global_id(1); | |||||
| if (X >= input_shape.x || Y >= input_shape.y) return; | |||||
| float sum = 0.0f; | |||||
| for (int d = 0; d < input_shape.w; ++d) { | |||||
| float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); | |||||
| sum += exp(t.x); | |||||
| if (d * 4 + 1 < input_shape.z) sum += exp(t.y); | |||||
| if (d * 4 + 2 < input_shape.z) sum += exp(t.z); | |||||
| if (d * 4 + 3 < input_shape.z) sum += exp(t.w); | |||||
| } | |||||
| for (int d = 0; d < input_shape.w; ++d) { | |||||
| float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X)); | |||||
| t = exp(t) / sum; | |||||
| float4 result = convert_float4(t); | |||||
| write_imagef(output, (int2)(Y * input_shape.w + d, X), result); | |||||
| } | |||||
| } | |||||
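The in-kernel DivideRoundUp helper is gone; the slice count now arrives pre-computed in input_shape.w (set in Run() of softmax.cc below), i.e. UP_DIV(C, 4):

// Host-side sketch of the slice count now passed as input_shape.w.
int C = 10;            // channel count
int S = (C + 3) / 4;   // UP_DIV(C, 4) == 3 float4 slices; the d * 4 + k < C checks skip the padded lanes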
| @@ -0,0 +1,50 @@ | |||||
| __constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; | |||||
| // mask zeroes out the padded lanes of the last channel slice; slices_x32 = UP_DIV(slices, 32) is the per-work-item loop count | |||||
| __kernel void SoftMax1x1_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask, | |||||
| const int slices, const int slices_x32) { | |||||
| int tid = get_local_id(0); | |||||
| int slices_count = 0; | |||||
| int offset = 0; | |||||
| float sum = 0.0f; | |||||
| do { | |||||
| int z = offset + tid; | |||||
| if (z < slices) { | |||||
| float4 mask_temp = z == slices - 1 ? mask : (float4)(1.0f); | |||||
| float4 src = read_imagef(input, smp_none, (int2)(0, 0)); | |||||
| sum += dot(mask_temp, exp(src)); | |||||
| offset += 32; | |||||
| } | |||||
| slices_count++; | |||||
| } while (slices_count < slices_x32); | |||||
| __local float4 tmp[8]; | |||||
| __local float *tmpx1 = (__local float *)tmp; | |||||
| tmpx1[tid] = sum; | |||||
| barrier(CLK_LOCAL_MEM_FENCE); | |||||
| if (tid == 0) { | |||||
| sum = dot((float4)(1.0f), tmp[0]); | |||||
| sum += dot((float4)(1.0f), tmp[1]); | |||||
| sum += dot((float4)(1.0f), tmp[2]); | |||||
| sum += dot((float4)(1.0f), tmp[3]); | |||||
| sum += dot((float4)(1.0f), tmp[4]); | |||||
| sum += dot((float4)(1.0f), tmp[5]); | |||||
| sum += dot((float4)(1.0f), tmp[6]); | |||||
| sum += dot((float4)(1.0f), tmp[7]); | |||||
| tmpx1[0] = 1.0f / sum; | |||||
| } | |||||
| barrier(CLK_LOCAL_MEM_FENCE); | |||||
| sum = tmpx1[0]; | |||||
| offset = 0; | |||||
| slices_count = 0; | |||||
| do { | |||||
| int z = offset + tid; | |||||
| if (z < slices) { | |||||
| float4 res = convert_float4(exp(read_imagef(input, smp_none, (int2)(0, 0))) * sum); | |||||
| write_imagef(output, (int2)(0, 0), res); | |||||
| offset += 32; | |||||
| } | |||||
| slices_count++; | |||||
| } while (slices_count < slices_x32); | |||||
| } | |||||
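SoftMax1x1_IMG runs with a local size of 32 (see SetWorkGroupSize1x1 below): work-item tid accumulates the exponentials of slices tid, tid + 32, tid + 64, ..., the 32 partial sums are parked in the eight local float4s, and work-item 0 folds them into the reciprocal of the total. A CPU-side model of that partitioning, purely illustrative:

#include <vector>
// 'slice_sums[z]' stands for the masked sum of exp() over the four lanes of slice z.
float ReduceLikeSoftMax1x1(const std::vector<float> &slice_sums) {
  float partial[32] = {0.0f};
  for (int tid = 0; tid < 32; ++tid)
    for (size_t z = tid; z < slice_sums.size(); z += 32)   // slices tid, tid+32, tid+64, ...
      partial[tid] += slice_sums[z];
  float sum = 0.0f;   // the fold that work-item 0 performs via tmp[0..7]
  for (int tid = 0; tid < 32; ++tid) sum += partial[tid];
  return sum;
}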
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -17,69 +17,143 @@ | |||||
| #include "src/runtime/kernel/opencl/kernel/softmax.h" | #include "src/runtime/kernel/opencl/kernel/softmax.h" | ||||
| #include <string> | #include <string> | ||||
| #include <set> | #include <set> | ||||
| #include "include/errorcode.h" | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "src/runtime/opencl/opencl_runtime.h" | #include "src/runtime/opencl/opencl_runtime.h" | ||||
| #include "src/runtime/kernel/opencl/utils.h" | |||||
| #ifndef PROGRAM_WITH_IL | #ifndef PROGRAM_WITH_IL | ||||
| #include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc" | #include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc" | ||||
| #include "src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl.inc" | |||||
| #endif | #endif | ||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | using mindspore::kernel::KERNEL_ARCH::kGPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| using mindspore::schema::PrimitiveType_SoftMax; | using mindspore::schema::PrimitiveType_SoftMax; | ||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| namespace mindspore::kernel { | |||||
| std::vector<float> SoftmaxOpenCLKernel::GetMaskForLastChannel(int channels) { | |||||
| std::vector<float> mask(4, 0.0f); | |||||
| const int reminder = channels % 4 == 0 ? 4 : channels % 4; | |||||
| for (int i = 0; i < reminder; ++i) { | |||||
| mask[i] = 1.0f; | |||||
| } | |||||
| return mask; | |||||
| } | |||||
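The mask marks which lanes of the last float4 slice hold real channels, so dot(mask_temp, exp(src)) in the kernel ignores the padding:

// Illustrative values:
//   GetMaskForLastChannel(6) -> {1.0f, 1.0f, 0.0f, 0.0f}
//   GetMaskForLastChannel(8) -> {1.0f, 1.0f, 1.0f, 1.0f}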
| int SoftmaxOpenCLKernel::InitGlobalSize() { | |||||
| const size_t global_x = out_tensors_[0]->Height(); | |||||
| const size_t global_y = out_tensors_[0]->Width(); | |||||
| const size_t global_z = 1; | |||||
| global_size_ = {global_x, global_y, global_z}; | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int SoftmaxOpenCLKernel::SetWorkGroupSize() { | |||||
| // set work group size | |||||
| InitGlobalSize(); | |||||
| int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())()); | |||||
| local_size_ = GetCommonLocalSize(global_size_, max_work_group_size); | |||||
| global_size_ = GetCommonGlobalSize(local_size_, global_size_); | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int SoftmaxOpenCLKernel::SetWorkGroupSize1x1() { | |||||
| local_size_ = {32, 1, 1}; | |||||
| global_size_ = {32, 1, 1}; | |||||
| return lite::RET_OK; | |||||
| } | |||||
| int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||||
| size_t im_dst_x, im_dst_y; | |||||
| if (onexone_flag_) { | |||||
| im_dst_x = UP_DIV(in_tensors_[0]->shape()[1], C4NUM); | |||||
| im_dst_y = 1; | |||||
| } else { | |||||
| size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||||
| im_dst_x = out_tensors_[0]->Width() * CO4; | |||||
| im_dst_y = out_tensors_[0]->Height(); | |||||
| } | |||||
| #ifdef ENABLE_FP16 | |||||
| size_t img_dtype = CL_HALF_FLOAT; | |||||
| #else | |||||
| size_t img_dtype = CL_FLOAT; | |||||
| #endif | |||||
| img_size->clear(); | |||||
| std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype}; | |||||
| *img_size = vec; | |||||
| return RET_OK; | |||||
| } | |||||
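GetImageSize packs four channels per image pixel: the extents are {W * UP_DIV(C, 4), H} for the 4-D path and {UP_DIV(shape[1], 4), 1} for the 1x1 path. Worked numbers:

// Examples of the extents computed above (values only):
//   4-D output 1x16x16x8 (NHWC): CO4 = (8 + 3) / 4 = 2 -> im_dst = {16 * 2, 16} = {32, 16}
//   2-D input {1, 100}:          im_dst = {(100 + 3) / 4, 1} = {25, 1}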
| int SoftmaxOpenCLKernel::Init() { | int SoftmaxOpenCLKernel::Init() { | ||||
| std::string kernel_name = "SoftMax"; | std::string kernel_name = "SoftMax"; | ||||
| if (parameter_->axis_ != -1 && parameter_->axis_ != 3) { | |||||
| MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; | |||||
| return -1; | |||||
| } | |||||
| std::string program_name = "SoftMax"; | |||||
| std::string source = softmax_source_fp32; | |||||
| runtime_ = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) { | |||||
| // support 4d tensor | |||||
| onexone_flag_ = false; | |||||
| } else if (in_tensors_[0]->shape().size() == 2 && parameter_->axis_ == 1) { | |||||
| // support 2d tensor | |||||
| kernel_name += "1x1"; | |||||
| program_name += "1x1"; | |||||
| source = softmax1x1_source_fp32; | |||||
| onexone_flag_ = true; | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_; | |||||
| } | |||||
| #ifdef PROGRAM_WITH_IL | #ifdef PROGRAM_WITH_IL | ||||
| ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); | |||||
| runtime_->CreateKernelFromIL(kernel_(), kernel_name); | |||||
| #else | #else | ||||
| if (mem_type_ == MEM_TYPE::BUF) { | |||||
| kernel_name += "_BUF"; | |||||
| program_name += "_BUF"; | |||||
| } else { | |||||
| kernel_name += "_IMG"; | |||||
| program_name += "_IMG"; | |||||
| } | |||||
| std::set<std::string> build_options; | std::set<std::string> build_options; | ||||
| std::string source = softmax_source_fp32; | |||||
| std::string program_name = "SoftMax"; | |||||
| ocl_runtime->LoadSource(program_name, source); | |||||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||||
| #endif | |||||
| runtime_->LoadSource(program_name, source); | |||||
| out_tensors_[0]->SetFormat(schema::Format_NHWC4); | out_tensors_[0]->SetFormat(schema::Format_NHWC4); | ||||
| runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||||
| #endif | |||||
| MS_LOG(DEBUG) << kernel_name << " Init Done!"; | MS_LOG(DEBUG) << kernel_name << " Init Done!"; | ||||
| return 0; | |||||
| return lite::RET_OK; | |||||
| } | } | ||||
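Init() composes the kernel and program names from the tensor rank and the memory type; mem_type_ defaults to IMG (see softmax.h below), and only the image variant of the 1x1 kernel exists in softmax1x1.cl above.

// Name variants produced by Init():
//   4-D, axis 3, buffer memory -> "SoftMax_BUF"
//   4-D, axis 3, image memory  -> "SoftMax_IMG"
//   2-D, axis 1, image memory  -> "SoftMax1x1_IMG"   (the only kernel defined in softmax1x1.cl)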
| int SoftmaxOpenCLKernel::InitBuffer() { return 0; } | |||||
| int SoftmaxOpenCLKernel::ReSize() { return 0; } | |||||
| int SoftmaxOpenCLKernel::Run() { | int SoftmaxOpenCLKernel::Run() { | ||||
| MS_LOG(DEBUG) << this->name() << " Running!"; | MS_LOG(DEBUG) << this->name() << " Running!"; | ||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| auto allocator = ocl_runtime->GetAllocator(); | |||||
| // global and local workers | |||||
| const uint32_t grid_x = in_tensors_[0]->shape()[2]; // W | |||||
| const uint32_t grid_y = in_tensors_[0]->shape()[1]; // H | |||||
| const uint32_t grid_z = 1; | |||||
| std::vector<size_t> global = {grid_x, grid_y, grid_z}; | |||||
| std::vector<size_t> local = {1, 1, 1}; | |||||
| // input and output | |||||
| cl::Buffer *input = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(in_tensors_[0]->Data())); | |||||
| cl::Buffer *output = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(out_tensors_[0]->Data())); | |||||
| cl_int4 input_size = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2], | |||||
| in_tensors_[0]->shape()[3]}; | |||||
| std::cout << "run" << std::endl; | |||||
| // attribute | |||||
| int arg_idx = 0; | int arg_idx = 0; | ||||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, *input); | |||||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, *output); | |||||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_size); | |||||
| if (onexone_flag_) { | |||||
| int channel_size = in_tensors_[0]->shape()[1]; | |||||
| int slices = UP_DIV(channel_size, C4NUM); | |||||
| cl_int slices_x32 = UP_DIV(slices, 32); | |||||
| auto mask_ = GetMaskForLastChannel(channel_size); | |||||
| cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]}; | |||||
| // run opengl kernel | |||||
| ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||||
| runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); | |||||
| runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); | |||||
| runtime_->SetKernelArg(kernel_, arg_idx++, mask); | |||||
| runtime_->SetKernelArg(kernel_, arg_idx++, slices); | |||||
| runtime_->SetKernelArg(kernel_, arg_idx, slices_x32); | |||||
| SetWorkGroupSize1x1(); | |||||
| } else { | |||||
| int slices = UP_DIV(out_tensors_[0]->Channel(), C4NUM); | |||||
| cl_int4 input_shape = {in_tensors_[0]->Height(), in_tensors_[0]->Width(), in_tensors_[0]->Channel(), slices}; | |||||
| return 0; | |||||
| runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); | |||||
| runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); | |||||
| runtime_->SetKernelArg(kernel_, arg_idx, input_shape); | |||||
| SetWorkGroupSize(); | |||||
| } | |||||
| // run opengl kernel | |||||
| runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr); | |||||
| return lite::RET_OK; | |||||
| } | } | ||||
| kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | ||||
| @@ -104,5 +178,4 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::T | |||||
| } | } | ||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SoftMax, OpenCLSoftMaxKernelCreator) | REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SoftMax, OpenCLSoftMaxKernelCreator) | ||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| } // namespace mindspore::kernel | |||||
| @@ -23,29 +23,37 @@ | |||||
| #include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" | #include "src/runtime/kernel/arm/nnacl/fp32/softmax.h" | ||||
| #include "src/runtime/opencl/opencl_runtime.h" | #include "src/runtime/opencl/opencl_runtime.h" | ||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| class SoftmaxOpenCLKernel : public LiteKernel { | |||||
| namespace mindspore::kernel { | |||||
| class SoftmaxOpenCLKernel : public OpenCLKernel { | |||||
| public: | public: | ||||
| explicit SoftmaxOpenCLKernel(OpParameter *parameter, | |||||
| const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| explicit SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | const std::vector<lite::tensor::Tensor *> &outputs) | ||||
| : LiteKernel(parameter, inputs, outputs, nullptr, nullptr) { | |||||
| : OpenCLKernel(parameter, inputs, outputs) { | |||||
| parameter_ = reinterpret_cast<SoftmaxParameter *>(parameter); | parameter_ = reinterpret_cast<SoftmaxParameter *>(parameter); | ||||
| } | } | ||||
| ~SoftmaxOpenCLKernel() override{}; | |||||
| ~SoftmaxOpenCLKernel() override{}; | |||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | |||||
| int Run() override; | int Run() override; | ||||
| int InitBuffer(); | |||||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||||
| int InitGlobalSize(); | |||||
| int SetWorkGroupSize1x1(); | |||||
| int SetWorkGroupSize(); | |||||
| std::vector<float> GetMaskForLastChannel(int channels); | |||||
| private: | private: | ||||
| SoftmaxParameter *parameter_; | |||||
| cl::Kernel kernel_; | cl::Kernel kernel_; | ||||
| SoftmaxParameter *parameter_; | |||||
| lite::opencl::OpenCLRuntime *runtime_; | |||||
| enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG}; | |||||
| bool onexone_flag_{false}; | |||||
| std::vector<size_t> local_size_; | |||||
| std::vector<size_t> global_size_; | |||||
| }; | }; | ||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_ | |||||
| @@ -175,4 +175,3 @@ std::string CLErrorCode(cl_int error_code) { | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -85,4 +85,3 @@ std::string CLErrorCode(cl_int error_code); | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ | #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_ | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "src/runtime/opencl/opencl_runtime.h" | #include "src/runtime/opencl/opencl_runtime.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/kernel/opencl/utils.h" | |||||
| namespace mindspore::lite::opencl { | namespace mindspore::lite::opencl { | ||||
| @@ -128,7 +129,7 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t>& img_size) | |||||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format, | cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format, | ||||
| img_size[0], img_size[1], 0, nullptr, &ret); | img_size[0], img_size[1], 0, nullptr, &ret); | ||||
| if (ret != CL_SUCCESS) { | if (ret != CL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed!" << kernel::CLErrorCode(ret); | |||||
| UnLock(); | UnLock(); | ||||
| delete buffer; | delete buffer; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -187,7 +188,7 @@ void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::v | |||||
| cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, | cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, | ||||
| image_format, img_size[0], img_size[1], 0, data, &ret); | image_format, img_size[0], img_size[1], 0, data, &ret); | ||||
| if (ret != CL_SUCCESS) { | if (ret != CL_SUCCESS) { | ||||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; | |||||
| MS_LOG(ERROR) << "Create OpenCL Image2D failed - " << kernel::CLErrorCode(ret); | |||||
| UnLock(); | UnLock(); | ||||
| delete buffer; | delete buffer; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -52,6 +52,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso | |||||
| std::vector<size_t> img_size; | std::vector<size_t> img_size; | ||||
| op_kernel->GetImageSize(i, &img_size); | op_kernel->GetImageSize(i, &img_size); | ||||
| auto data_ptr = op_allocator->Malloc(output->Size(), img_size); | auto data_ptr = op_allocator->Malloc(output->Size(), img_size); | ||||
| output->SetData(data_ptr); | output->SetData(data_ptr); | ||||
| } else { | } else { | ||||
| output->MallocData(allocator); | output->MallocData(allocator); | ||||
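For image-backed outputs the allocator needs the 2-D extents up front: GetImageSize fills img_size as {image width, image height, pixel data type}, e.g. {25, 1, CL_FLOAT} for a {1, 100} softmax output, and Malloc(size, img_size) then builds the cl::Image2D shown in opencl_allocator.cc above.

// img_size layout consumed by OpenCLAllocator::Malloc(size, img_size):
//   {image_width, image_height, image_channel_data_type}, e.g. {25, 1, CL_FLOAT}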
| @@ -109,7 +110,7 @@ int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format | |||||
| case kNumberTypeFloat32: | case kNumberTypeFloat32: | ||||
| return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir); | return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir); | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| << schema::EnumNameFormat(dst_format); | << schema::EnumNameFormat(dst_format); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -160,7 +161,7 @@ int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema | |||||
| // TODO(wandongdong): add support !! | // TODO(wandongdong): add support !! | ||||
| return RET_OK; | return RET_OK; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| << schema::EnumNameFormat(dst_format) << " in float32"; | << schema::EnumNameFormat(dst_format) << " in float32"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -194,7 +195,7 @@ int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema: | |||||
| allocator_->Free(src_data); | allocator_->Free(src_data); | ||||
| return RET_OK; | return RET_OK; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| << schema::EnumNameFormat(dst_format) << " in float32"; | << schema::EnumNameFormat(dst_format) << " in float32"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -216,7 +217,7 @@ int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schem | |||||
| allocator_->Free(src_data); | allocator_->Free(src_data); | ||||
| return RET_OK; | return RET_OK; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| << schema::EnumNameFormat(dst_format) << " in float32"; | << schema::EnumNameFormat(dst_format) << " in float32"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -228,7 +229,7 @@ int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::F | |||||
| MS_ASSERT(4 == tensor->shape().size()); | MS_ASSERT(4 == tensor->shape().size()); | ||||
| // auto src_format = tensor->GetFormat(); | // auto src_format = tensor->GetFormat(); | ||||
| // todo | // todo | ||||
| MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to " | |||||
| << schema::EnumNameFormat(dst_format) << " in uint8"; | << schema::EnumNameFormat(dst_format) << " in uint8"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -17,76 +17,90 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "mindspore/core/utils/log_adapter.h" | #include "mindspore/core/utils/log_adapter.h" | ||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "mindspore/lite/src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" | #include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h" | ||||
| #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestSoftmaxOpenCL : public mindspore::CommonTest {}; | class TestSoftmaxOpenCL : public mindspore::CommonTest {}; | ||||
| void InitSoftaxParam(SoftmaxParameter *param) { param->axis_ = -1; } | |||||
| TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) { | |||||
| std::cout << "======" << std::endl; | |||||
| MS_LOG(INFO) << "start TEST_F TestSoftmaxOpenCL"; | |||||
| void RunTestCase(std::vector<int> input_shape, std::vector<int> output_shape, std::string input_file, | |||||
| std::string expect_file, SoftmaxParameter *param, schema::Format format) { | |||||
| std::cout << "runtime" << std::endl; | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | ||||
| ocl_runtime->Init(); | ocl_runtime->Init(); | ||||
| auto allocator = ocl_runtime->GetAllocator(); | |||||
| MS_LOG(INFO) << "create SoftmaxParameter"; | |||||
| auto param = new SoftmaxParameter(); | |||||
| InitSoftaxParam(param); | |||||
| // define tensor | |||||
| MS_LOG(INFO) << "defineTensor"; | |||||
| std::cout << "defineTensor" << std::endl; | |||||
| MS_LOG(INFO) << "create Tensors"; | |||||
| std::vector<int> shape_in = {1, 2, 2, 1}; | |||||
| std::vector<int> shape_out = {1, 2, 2, 1}; | |||||
| auto data_type = kNumberTypeFloat32; | auto data_type = kNumberTypeFloat32; | ||||
| auto tensorType = schema::NodeType_ValueNode; | auto tensorType = schema::NodeType_ValueNode; | ||||
| lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NCHW, tensorType); | |||||
| lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NCHW, tensorType); | |||||
| std::vector<lite::tensor::Tensor *> inputs{tensor_in}; | |||||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | |||||
| auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, format, tensorType); | |||||
| auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, format, tensorType); | |||||
| std::vector<lite::tensor::Tensor *> inputs{input_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||||
| MS_LOG(INFO) << "create OpenCL Kernel"; | |||||
| auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| Softmax_kernel->Init(); | |||||
| std::vector<kernel::LiteKernel *> kernels{Softmax_kernel}; | |||||
| // run | |||||
| MS_LOG(INFO) << "NewOpenCLKernel"; | |||||
| std::cout << "NewOpenCLKernel" << std::endl; | |||||
| auto *kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| MS_LOG(INFO) << "KernelInit"; | |||||
| std::cout << "KernelInit" << std::endl; | |||||
| kernel->Init(); | |||||
| MS_LOG(INFO) << "create SubGraphOpenCLKernel"; | |||||
| std::cout << "LiteKernel" << std::endl; | |||||
| std::vector<kernel::LiteKernel *> kernels{kernel}; | |||||
| inputs[0]->MallocData(allocator); | |||||
| std::cout << "SubGraphOpenCLKernel" << std::endl; | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | ||||
| MS_LOG(INFO) << "pGraphinit"; | |||||
| pGraph->Init(); | pGraph->Init(); | ||||
| MS_LOG(INFO) << "initialize data"; | |||||
| std::vector<lite::tensor::Tensor *> tensor_map = {tensor_in}; | |||||
| for (auto &tensor_file : tensor_map) { | |||||
| auto tensor = tensor_file; | |||||
| size_t size = tensor->Size(); | |||||
| const float data[4] = {std::log(1.0f), std::log(2.0f), std::log(3.0f), std::log(4.0f)}; | |||||
| memcpy(tensor->Data(), data, size); | |||||
| // load data | |||||
| MS_LOG(INFO) << "load data1"; | |||||
| LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); | |||||
| auto *input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||||
| printf("\ninput[0:10]:"); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| printf("[%d]:%.3f ", i, input_data[i]); | |||||
| } | } | ||||
| printf("\n\n"); | |||||
| MS_LOG(INFO) << "pGraph->Run()"; | |||||
| MS_LOG(INFO) << "Run"; | |||||
| pGraph->Run(); | pGraph->Run(); | ||||
| MS_LOG(INFO) << "==================output data================="; | |||||
| float *output_data = reinterpret_cast<float *>(tensor_out->Data()); | |||||
| size_t output_size = tensor_out->Size(); | |||||
| MS_LOG(INFO) << "compare result"; | |||||
| std::cout << "compare result" << std::endl; | |||||
| CompareOutput(output_tensor, expect_file); | |||||
| } | |||||
| printf("output:"); | |||||
| for (int i = 0; i < 4; i++) { | |||||
| printf("%.3f ", output_data[i]); | |||||
| } | |||||
| printf("\n"); | |||||
| float expect[4] = {1.0f, 2.0f, 3.0f, 4.0f}; | |||||
| TEST_F(TestSoftmaxOpenCL, Softmax_1) { | |||||
| std::vector<int> input_shape = {1, 2, 2, 8}; | |||||
| std::vector<int> output_shape = {1, 2, 2, 8}; | |||||
| std::string input_file = "softmax_in.bin"; | |||||
| std::string expect_file = "softmax_out.bin"; | |||||
| auto param = new SoftmaxParameter; | |||||
| param->axis_ = 3; | |||||
| schema::Format format = schema::Format_NHWC4; | |||||
| for (int i = 0; i < tensor_out->ElementsNum(); ++i) { | |||||
| if (std::fabs(output_data[i] - expect[i]) > 1e-5) { | |||||
| printf("idx[%d] except=%.3f output=%.3f .", i, expect[i], output_data[i]); | |||||
| } | |||||
| } | |||||
| printf("\nTest all close OK for %zu!\n", output_size); | |||||
| lite::CompareOutputData(output_data, expect, 4); | |||||
| RunTestCase(input_shape, output_shape, input_file, expect_file, param, format); | |||||
| } | } | ||||
| // TEST_F(TestSoftmaxOpenCL, Softmax_1x1) { | |||||
| // std::vector<int> input_shape = {1, 100}; | |||||
| // std::vector<int> output_shape = {1, 100}; | |||||
| // std::string input_file = "softmax1x1_in.bin"; | |||||
| // std::string expect_file = "softmax1x1_out.bin"; | |||||
| // auto param = new SoftmaxParameter; | |||||
| // param->axis_ = 1; | |||||
| // schema::Format format = schema::Format_NHWC4; | |||||
| // | |||||
| // RunTestCase(input_shape, output_shape, input_file, expect_file, param, format); | |||||
| //} | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
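The expected data presumably corresponds to a plain softmax over the channel axis of softmax_in.bin; a reference sketch for generating such data offline (hypothetical, not part of the test harness; flat float32 layout assumed):

#include <cmath>
#include <vector>
// Reference softmax over the innermost (channel) axis, matching what the kernels compute.
std::vector<float> SoftmaxLastAxis(const std::vector<float> &in, int channels) {
  std::vector<float> out(in.size());
  for (size_t base = 0; base < in.size(); base += channels) {
    float sum = 0.0f;
    for (int c = 0; c < channels; ++c) sum += std::exp(in[base + c]);
    for (int c = 0; c < channels; ++c) out[base + c] = std::exp(in[base + c]) / sum;
  }
  return out;
}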
| @@ -40,13 +40,13 @@ void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_ | |||||
| size_t output_size = output_tensor->Size(); | size_t output_size = output_tensor->Size(); | ||||
| float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | ||||
| printf("output[0:10]:"); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| printf("output[0:12]:"); | |||||
| for (int i = 0; i < 12; i++) { | |||||
| printf("[%d]:%.3f ", i, output_data[i]); | printf("[%d]:%.3f ", i, output_data[i]); | ||||
| } | } | ||||
| printf("\n"); | printf("\n"); | ||||
| printf("expect[0:10]:"); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| printf("expect[0:12]:"); | |||||
| for (int i = 0; i < 12; i++) { | |||||
| printf("[%d]:%.3f ", i, expect_data[i]); | printf("[%d]:%.3f ", i, expect_data[i]); | ||||
| } | } | ||||
| printf("\n"); | printf("\n"); | ||||
| @@ -157,7 +157,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { | |||||
| } else if (opType == schema::PrimitiveType_DeConv2D) { | } else if (opType == schema::PrimitiveType_DeConv2D) { | ||||
| weightTensor->format = schema::Format_CHWK; | weightTensor->format = schema::Format_CHWK; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "unsupport format"; | |||||
| MS_LOG(ERROR) << "Unsupported format"; | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| } break; | } break; | ||||
| @@ -184,7 +184,7 @@ size_t GetDataTypeSize(const TypeId &data_type) { | |||||
| return sizeof(int64_t); | return sizeof(int64_t); | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << data_type; | MS_LOG(ERROR) << data_type; | ||||
| MS_LOG(ERROR) << "unsupport datatype"; | |||||
| MS_LOG(ERROR) << "Unsupported datatype"; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| } | } | ||||