Merge pull request !8122 from chenzupeng/master-litetags/v1.1.0
| @@ -295,7 +295,7 @@ kernel::LiteKernel *OpenCLArithmeticKernelCreator(const std::vector<lite::Tensor | |||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | const lite::InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| auto *kernel = | auto *kernel = | ||||
| new (std::nothrow) ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx); | |||||
| new (std::nothrow) ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!"; | MS_LOG(ERROR) << "Create OpenCL Arithmetic kernel failed!"; | ||||
| free(opParameter); | free(opParameter); | ||||
| @@ -26,7 +26,7 @@ namespace mindspore::kernel { | |||||
| class ArithmeticOpenCLKernel : public OpenCLKernel { | class ArithmeticOpenCLKernel : public OpenCLKernel { | ||||
| public: | public: | ||||
| ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ArithmeticOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| const std::vector<lite::Tensor *> &outputs) | |||||
| : OpenCLKernel(parameter, inputs, outputs) {} | : OpenCLKernel(parameter, inputs, outputs) {} | ||||
| ~ArithmeticOpenCLKernel() override = default; | ~ArithmeticOpenCLKernel() override = default; | ||||
| @@ -251,8 +251,7 @@ kernel::LiteKernel *OpenCLScaleKernelCreator(const std::vector<lite::Tensor *> & | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | ||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | const lite::InnerContext *ctx, const kernel::KernelKey &desc, | ||||
| const mindspore::lite::PrimitiveC *primitive) { | const mindspore::lite::PrimitiveC *primitive) { | ||||
| auto *kernel = | |||||
| new (std::nothrow) ScaleOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs, ctx); | |||||
| auto *kernel = new (std::nothrow) ScaleOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| if (kernel == nullptr) { | if (kernel == nullptr) { | ||||
| MS_LOG(ERROR) << "Create OpenCL Scale kernel failed!"; | MS_LOG(ERROR) << "Create OpenCL Scale kernel failed!"; | ||||
| free(opParameter); | free(opParameter); | ||||
| @@ -26,7 +26,7 @@ namespace mindspore::kernel { | |||||
| class ScaleOpenCLKernel : public OpenCLKernel { | class ScaleOpenCLKernel : public OpenCLKernel { | ||||
| public: | public: | ||||
| ScaleOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ScaleOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) | |||||
| const std::vector<lite::Tensor *> &outputs) | |||||
| : OpenCLKernel(parameter, inputs, outputs) {} | : OpenCLKernel(parameter, inputs, outputs) {} | ||||
| ~ScaleOpenCLKernel() override; | ~ScaleOpenCLKernel() override; | ||||
| @@ -13,233 +13,178 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "mindspore/lite/src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h" | #include "mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h" | ||||
| #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestArithmeticOpenCL : public mindspore::CommonTest { | |||||
| public: | |||||
| TestArithmeticOpenCL() {} | |||||
| }; | |||||
| template <class T> | |||||
| static void BoardcaseAdd(const T *a, const T b, T *c, const int size) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| c[i] = a[i] + b; | |||||
| } | |||||
| } | |||||
| template <class T> | |||||
| static void ElementAdd(const T *a, const T *b, T *c, const int size) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| c[i] = a[i] + b[i]; | |||||
| } | |||||
| } | |||||
| template <class T> | |||||
| static bool DataCompare(const T *a, const T *b, const int size, const float accuracy = 1e-4) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| auto diff = fabs(a[i] - b[i]); | |||||
| if (diff > accuracy) { | |||||
| MS_LOG(ERROR) << "compare failed at " << i << " exp " << a[i] << " bug got " << b[i]; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| template <class T> | |||||
| static void InitData(void *data, const int size) { | |||||
| T *data_float = reinterpret_cast<T *>(data); | |||||
| static unsigned int seed = 123; | |||||
| for (int i = 0; i < size; i++) { | |||||
| data_float[i] = static_cast<int>(rand_r(&seed)) % 100; | |||||
| } | |||||
| } | |||||
| template <class T> | |||||
| static void LogData(void *data, const int size, const std::string prefix) { | |||||
| std::cout << prefix; | |||||
| T *data_float = reinterpret_cast<T *>(data); | |||||
| for (int i = 0; i < size; i++) { | |||||
| std::cout << data_float[i] << ","; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| } | |||||
| template <class T> | |||||
| static void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b) { | |||||
| bool is_log_data = false; | |||||
| void RunTestCaseArithmetic(void *input_data0, const std::vector<int> &input_shape, void *input_data1, | |||||
| const std::vector<int> &weight_shape, void *output_data, const std::vector<int> &out_shape, | |||||
| bool enable_fp16, int op_type, int act_type = schema::ActivationType_NO_ACTIVATION) { | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | ||||
| ocl_runtime->Init(); | |||||
| size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); | |||||
| ocl_runtime->SetFp16Enable(enable_fp16); | |||||
| auto allocator = ocl_runtime->GetAllocator(); | auto allocator = ocl_runtime->GetAllocator(); | ||||
| bool is_bias_add = shape_b.empty(); | |||||
| auto data_type = kNumberTypeFloat32; | |||||
| if (sizeof(T) == 2) { | |||||
| data_type = kNumberTypeFloat16; | |||||
| ocl_runtime->SetFp16Enable(true); | |||||
| } | |||||
| lite::Tensor *tensor_a = new (std::nothrow) lite::Tensor(data_type, shape_a, schema::Format_NHWC4); | |||||
| lite::Tensor *tensor_b = new (std::nothrow) lite::Tensor(data_type, shape_b, schema::Format_NHWC4); | |||||
| lite::Tensor *tensor_c = new (std::nothrow) lite::Tensor(data_type, shape_a, schema::Format_NHWC4); | |||||
| if (tensor_a == nullptr || tensor_b == nullptr || tensor_c == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor failed!"; | |||||
| delete tensor_a; | |||||
| delete tensor_b; | |||||
| delete tensor_c; | |||||
| auto param = static_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "param_ptr create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| int64_t element_num = tensor_a->ElementsC4Num(); | |||||
| int64_t element_num_b = is_bias_add ? 1 : tensor_b->ElementsC4Num(); | |||||
| T *data_a = new (std::nothrow) T[element_num]; | |||||
| T *data_b = new (std::nothrow) T[element_num_b]; | |||||
| T *data_c_cpu = new (std::nothrow) T[element_num]; | |||||
| T *data_c_ocl = new (std::nothrow) T[element_num]; | |||||
| if (data_a == nullptr || data_b == nullptr || data_c_cpu == nullptr || data_c_ocl == nullptr) { | |||||
| MS_LOG(ERROR) << "Create buffer failed!"; | |||||
| delete tensor_a; | |||||
| delete tensor_b; | |||||
| delete tensor_c; | |||||
| delete[] data_a; | |||||
| delete[] data_b; | |||||
| delete[] data_c_cpu; | |||||
| delete[] data_c_ocl; | |||||
| int input0_size = std::accumulate(input_shape.begin(), input_shape.end(), 1LL, std::multiplies<int>()); | |||||
| int input1_size = std::accumulate(weight_shape.begin(), weight_shape.end(), 1LL, std::multiplies<int>()); | |||||
| if (input0_size != input1_size) { | |||||
| param->broadcasting_ = true; | |||||
| } | |||||
| param->op_parameter_.type_ = op_type; | |||||
| param->activation_type_ = act_type; | |||||
| auto tensor_x_ptr = | |||||
| std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape); | |||||
| auto tensor_x = tensor_x_ptr.get(); | |||||
| if (tensor_x == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_x create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| InitData<T>(data_a, element_num); | |||||
| InitData<T>(data_b, element_num_b); | |||||
| memset(data_c_ocl, 0, sizeof(T) * element_num); | |||||
| if (is_bias_add) { | |||||
| BoardcaseAdd(data_a, static_cast<T *>(data_b)[0], data_c_cpu, element_num); | |||||
| } else { | |||||
| ElementAdd(data_a, data_b, data_c_cpu, element_num); | |||||
| } | |||||
| std::vector<lite::Tensor *> inputs = {tensor_a}; | |||||
| if (!is_bias_add) { | |||||
| inputs.push_back(tensor_b); | |||||
| } else { | |||||
| tensor_b->MallocData(); | |||||
| memcpy(tensor_b->data_c(), data_b, sizeof(T)); | |||||
| } | |||||
| std::vector<lite::Tensor *> outputs = {tensor_c}; | |||||
| ArithmeticParameter *param = static_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); | |||||
| param->broadcasting_ = is_bias_add; | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "Create parameter failed!"; | |||||
| delete tensor_a; | |||||
| delete tensor_b; | |||||
| delete tensor_c; | |||||
| delete[] data_a; | |||||
| delete[] data_b; | |||||
| delete[] data_c_cpu; | |||||
| delete[] data_c_ocl; | |||||
| auto tensor_w_ptr = std::make_unique<lite::Tensor>( | |||||
| TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape, schema::Format_NHWC, | |||||
| input1_size != 1 ? lite::Tensor::Category::CONST_TENSOR : lite::Tensor::Category::CONST_SCALAR); | |||||
| auto tensor_w = tensor_w_ptr.get(); | |||||
| if (tensor_w == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_w create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| param->ndim_ = 4; | |||||
| param->op_parameter_.type_ = PrimitiveType_Add; | |||||
| std::vector<lite::Tensor *> arithmetic_inputs = {tensor_a, tensor_b}; | |||||
| lite::InnerContext ctx; | |||||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||||
| auto *arith_kernel = new (std::nothrow) | |||||
| kernel::ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(param), arithmetic_inputs, outputs, &ctx); | |||||
| if (arith_kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create ArithmeticOpenCLKernel failed!"; | |||||
| delete tensor_a; | |||||
| delete tensor_b; | |||||
| delete tensor_c; | |||||
| delete[] data_a; | |||||
| delete[] data_b; | |||||
| delete[] data_c_cpu; | |||||
| delete[] data_c_ocl; | |||||
| free(param); | |||||
| tensor_w->set_data(input_data1); | |||||
| auto tensor_out_ptr = | |||||
| std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape); | |||||
| auto tensor_out = tensor_out_ptr.get(); | |||||
| if (tensor_out == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_out create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| arith_kernel->Init(); | |||||
| tensor_a->MallocData(allocator); | |||||
| tensor_b->MallocData(allocator); | |||||
| std::vector<kernel::LiteKernel *> kernels{arith_kernel}; | |||||
| auto *kernel = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| if (arith_kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create SubGraphOpenCLKernel failed!"; | |||||
| delete tensor_a; | |||||
| delete tensor_b; | |||||
| delete tensor_c; | |||||
| delete[] data_a; | |||||
| delete[] data_b; | |||||
| delete[] data_c_cpu; | |||||
| delete[] data_c_ocl; | |||||
| delete arith_kernel; | |||||
| std::vector<lite::Tensor *> inputs{tensor_x, tensor_w}; | |||||
| std::vector<lite::Tensor *> outputs{tensor_out}; | |||||
| auto op_kernel_ptr = | |||||
| std::make_unique<kernel::ArithmeticOpenCLKernel>(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| auto op_kernel = op_kernel_ptr.release(); | |||||
| if (op_kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "op_kernel create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| kernel->Init(); | |||||
| op_kernel->Init(); | |||||
| inputs[0]->MallocData(allocator); | |||||
| memcpy(inputs[0]->data_c(), data_a, sizeof(T) * element_num); | |||||
| if (!is_bias_add) { | |||||
| memcpy(inputs[1]->data_c(), data_b, sizeof(T) * element_num_b); | |||||
| } | |||||
| kernel->Run(); | |||||
| memcpy(data_c_ocl, outputs[0]->data_c(), sizeof(T) * element_num); | |||||
| std::vector<kernel::LiteKernel *> kernels{op_kernel}; | |||||
| if (is_log_data) { | |||||
| LogData<T>(data_a, 10, "Data A : "); | |||||
| LogData<T>(data_b, tensor_b->shape().empty() ? 1 : 10, "Data B : "); | |||||
| LogData<T>(data_c_cpu, 10, "Expect compute : "); | |||||
| LogData<T>(outputs[0]->data_c(), 10, "OpenCL compute : "); | |||||
| std::vector<lite::Tensor *> inputs_g{tensor_x}; | |||||
| auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels); | |||||
| auto pGraph = pGraph_ptr.get(); | |||||
| if (pGraph == nullptr) { | |||||
| MS_LOG(ERROR) << "pGraph create error."; | |||||
| return; | |||||
| } | |||||
| pGraph->Init(); | |||||
| memcpy(inputs[0]->MutableData(), input_data0, tensor_x->ElementsNum() * dtype_size); | |||||
| pGraph->Run(); | |||||
| if (enable_fp16) { | |||||
| CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float16_t>(1e-3), | |||||
| 2e-2); | |||||
| } else { | |||||
| CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float>(1e-5)); | |||||
| } | } | ||||
| bool cmp = DataCompare(data_c_cpu, data_c_ocl, element_num); | |||||
| MS_LOG(INFO) << "Compare " << (cmp ? "success!" : "failed!"); | |||||
| EXPECT_EQ(true, cmp); | |||||
| // free | |||||
| delete[] data_a; | |||||
| delete[] data_b; | |||||
| delete[] data_c_cpu; | |||||
| delete[] data_c_ocl; | |||||
| delete kernel; | |||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| for (auto t : inputs) { | |||||
| t->set_data(nullptr); | |||||
| } | } | ||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| for (auto t : outputs) { | |||||
| t->set_data(nullptr); | |||||
| } | } | ||||
| MS_LOG(INFO) << "TestArithmetic passed"; | |||||
| } | } | ||||
| class TestArithmeticOpenCL : public mindspore::CommonTest { | |||||
| public: | |||||
| TestArithmeticOpenCL() {} | |||||
| }; | |||||
| TEST_F(TestArithmeticOpenCL, ArithmeticElementwiseAddFp32) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> in_shape1 = {n, h, w, c}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> weight_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> output_data = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f}; | |||||
| RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape, | |||||
| false, schema::PrimitiveType_Add); | |||||
| } | |||||
| TEST_F(TestArithmeticOpenCL, AddElementwiseFP32) { | |||||
| const std::vector<int> &shape_a = {1, 1024, 1024, 4}; | |||||
| const std::vector<int> &shape_b = {1, 1024, 1024, 4}; | |||||
| TestCase<float>(shape_a, shape_b); | |||||
| TEST_F(TestArithmeticOpenCL, ArithmeticScalarMulFp32) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> in_shape1 = {1}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> weight_data = {2.0f}; | |||||
| std::vector<float> output_data = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f}; | |||||
| RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape, | |||||
| false, schema::PrimitiveType_Mul); | |||||
| } | } | ||||
| TEST_F(TestArithmeticOpenCL, AddBroadcastFP32) { | |||||
| const std::vector<int> &shape_a = {1, 128, 128, 4}; | |||||
| const std::vector<int> &shape_b = {}; | |||||
| TestCase<float>(shape_a, shape_b); | |||||
| TEST_F(TestArithmeticOpenCL, ArithmeticBroadcastSubReLU6Fp32) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> in_shape1 = {c}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> weight_data = {1.0f, 2.0f, 3.0f}; | |||||
| std::vector<float> output_data = {0.0f, 0.0f, 0.0f, 3.0f, 3.0f, 3.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f}; | |||||
| RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape, | |||||
| false, schema::PrimitiveType_Sub, schema::ActivationType_RELU6); | |||||
| } | } | ||||
| TEST_F(TestArithmeticOpenCL, AddElementwiseFP16) { | |||||
| const std::vector<int> &shape_a = {1, 1024, 1024, 4}; | |||||
| const std::vector<int> &shape_b = {1, 1024, 1024, 4}; | |||||
| TestCase<float16_t>(shape_a, shape_b); | |||||
| TEST_F(TestArithmeticOpenCL, ArithmeticBroadcastSub2Fp32) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, c}; | |||||
| std::vector<int> in_shape1 = {n, h, w, c}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float> input_data = {1.0f, 2.0f, 3.0f}; | |||||
| std::vector<float> weight_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> output_data = {0.0f, 0.0f, 0.0f, -3.0f, -3.0f, -3.0f, -6.0f, -6.0f, -6.0f, -9.0f, -9.0f, -9.0f}; | |||||
| RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape, | |||||
| false, schema::PrimitiveType_Sub); | |||||
| } | } | ||||
| TEST_F(TestArithmeticOpenCL, AddBroadcastFP16) { | |||||
| const std::vector<int> &shape_a = {1, 128, 128, 4}; | |||||
| const std::vector<int> &shape_b = {}; | |||||
| TestCase<float16_t>(shape_a, shape_b); | |||||
| TEST_F(TestArithmeticOpenCL, ArithmeticElementwiseDivFp16) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> in_shape1 = {n, h, w, c}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float16_t> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float16_t> weight_data = {1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}; | |||||
| std::vector<float16_t> output_data = {1.0f, 2.0f, 3.0f, 2.0f, 2.5, 3.0f, 7.0f, 8.0f, 9.0f, 5.0f, 5.5, 6.0f}; | |||||
| RunTestCaseArithmetic(input_data.data(), in_shape0, weight_data.data(), in_shape1, output_data.data(), out_shape, | |||||
| true, schema::PrimitiveType_Div); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -13,253 +13,171 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "common/common_test.h" | #include "common/common_test.h" | ||||
| #include "mindspore/lite/src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | ||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h" | #include "mindspore/lite/src/runtime/kernel/opencl/kernel/scale.h" | ||||
| #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class TestScaleOpenCL : public mindspore::CommonTest { | |||||
| public: | |||||
| TestScaleOpenCL() {} | |||||
| }; | |||||
| template <class T> | |||||
| static void BoardcaseScale(const T *in, const T scale, const T offset, T *out, const int size) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| out[i] = in[i] * scale + offset; | |||||
| } | |||||
| } | |||||
| template <class T> | |||||
| static void Scale(const T *in, const T *scale, T *offset, T *out, const int size) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| out[i] = in[i] * scale[i] + offset[i]; | |||||
| } | |||||
| } | |||||
| template <class T> | |||||
| static bool DataCompare(const T *a, const T *b, const int size, const T accuracy = 1e-4) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| auto diff = fabs(a[i] - b[i]); | |||||
| if (diff > accuracy) { | |||||
| MS_LOG(ERROR) << "compare failed at " << i << " exp " << a[i] << " bug got " << b[i]; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| template <class T> | |||||
| static void InitData(void *data, const int size) { | |||||
| T *data_float = reinterpret_cast<T *>(data); | |||||
| static unsigned int seed = 123; | |||||
| for (int i = 0; i < size; i++) { | |||||
| data_float[i] = static_cast<int>(rand_r(&seed)) % 100; | |||||
| } | |||||
| } | |||||
| template <class T> | |||||
| static void LogData(void *data, const int size, const std::string prefix) { | |||||
| std::cout << prefix; | |||||
| T *data_float = reinterpret_cast<T *>(data); | |||||
| for (int i = 0; i < size; i++) { | |||||
| std::cout << data_float[i] << ","; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| } | |||||
| template <class T> | |||||
| static void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b) { | |||||
| bool is_log_data = false; | |||||
| void RunTestCaseScale(void *input_data0, const std::vector<int> &input_shape, void *scale_data, void *offset_data, | |||||
| const std::vector<int> &weight_shape, void *output_data, const std::vector<int> &out_shape, | |||||
| bool enable_fp16, int axis, int act_type = schema::ActivationType_NO_ACTIVATION) { | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | ||||
| ocl_runtime->Init(); | |||||
| size_t dtype_size = enable_fp16 ? sizeof(float16_t) : sizeof(float); | |||||
| ocl_runtime->SetFp16Enable(enable_fp16); | |||||
| auto allocator = ocl_runtime->GetAllocator(); | auto allocator = ocl_runtime->GetAllocator(); | ||||
| bool is_broadcast = shape_b.empty(); | |||||
| auto format = schema::Format_NHWC4; | |||||
| auto data_type = kNumberTypeFloat32; | |||||
| if (sizeof(T) == 2) { | |||||
| data_type = kNumberTypeFloat16; | |||||
| ocl_runtime->SetFp16Enable(true); | |||||
| } | |||||
| lite::Tensor *tensor_in = new (std::nothrow) lite::Tensor(data_type, shape_a, format); | |||||
| lite::Tensor *tensor_scale = new (std::nothrow) lite::Tensor(data_type, shape_b, format); | |||||
| lite::Tensor *tensor_offset = new (std::nothrow) lite::Tensor(data_type, shape_b, format); | |||||
| lite::Tensor *tensor_out = new (std::nothrow) lite::Tensor(data_type, shape_a, format); | |||||
| if (tensor_in == nullptr || tensor_scale == nullptr || tensor_offset == nullptr) { | |||||
| MS_LOG(ERROR) << "Create tensor failed!"; | |||||
| delete tensor_in; | |||||
| delete tensor_scale; | |||||
| delete tensor_offset; | |||||
| delete tensor_out; | |||||
| auto param = static_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "param_ptr create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| int64_t element_num = tensor_in->ElementsC4Num(); | |||||
| int64_t element_num_b = is_broadcast ? 1 : tensor_scale->ElementsC4Num(); | |||||
| T *data_in = new (std::nothrow) T[element_num]; | |||||
| T *data_scale = new (std::nothrow) T[element_num_b]; | |||||
| T *data_offset = new (std::nothrow) T[element_num_b]; | |||||
| T *data_out_cpu = new (std::nothrow) T[element_num]; | |||||
| T *data_out_ocl = new (std::nothrow) T[element_num]; | |||||
| if (data_in == nullptr || data_scale == nullptr || data_out_cpu == nullptr || data_out_ocl == nullptr) { | |||||
| MS_LOG(ERROR) << "Create buffer failed!"; | |||||
| delete tensor_in; | |||||
| delete tensor_scale; | |||||
| delete tensor_offset; | |||||
| delete tensor_out; | |||||
| delete[] data_in; | |||||
| delete[] data_scale; | |||||
| delete[] data_offset; | |||||
| delete[] data_out_cpu; | |||||
| delete[] data_out_ocl; | |||||
| param->axis_ = axis; | |||||
| param->activation_type_ = act_type; | |||||
| auto tensor_x_ptr = | |||||
| std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), input_shape); | |||||
| auto tensor_x = tensor_x_ptr.get(); | |||||
| if (tensor_x == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_x create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| InitData<T>(data_in, element_num); | |||||
| InitData<T>(data_scale, element_num_b); | |||||
| InitData<T>(data_offset, element_num_b); | |||||
| memset(data_out_ocl, 0, sizeof(T) * element_num); | |||||
| if (is_broadcast) { | |||||
| BoardcaseScale(data_in, static_cast<T *>(data_scale)[0], static_cast<T *>(data_offset)[0], data_out_cpu, | |||||
| element_num); | |||||
| } else { | |||||
| Scale(data_in, data_scale, data_offset, data_out_cpu, element_num); | |||||
| } | |||||
| std::vector<lite::Tensor *> inputs = {tensor_in}; | |||||
| if (!is_broadcast) { | |||||
| inputs.push_back(tensor_scale); | |||||
| inputs.push_back(tensor_offset); | |||||
| } else { | |||||
| tensor_scale->MallocData(); | |||||
| tensor_offset->MallocData(); | |||||
| memcpy(tensor_scale->data_c(), data_scale, sizeof(T)); | |||||
| memcpy(tensor_offset->data_c(), data_offset, sizeof(T)); | |||||
| } | |||||
| std::vector<lite::Tensor *> outputs = {tensor_out}; | |||||
| ScaleParameter *param = static_cast<ScaleParameter *>(malloc(sizeof(ScaleParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "Create parameter failed!"; | |||||
| delete tensor_in; | |||||
| delete tensor_scale; | |||||
| delete tensor_offset; | |||||
| delete tensor_out; | |||||
| delete[] data_in; | |||||
| delete[] data_scale; | |||||
| delete[] data_offset; | |||||
| delete[] data_out_cpu; | |||||
| delete[] data_out_ocl; | |||||
| auto tensor_scale_ptr = | |||||
| std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape, | |||||
| schema::Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | |||||
| auto tensor_scale = tensor_scale_ptr.get(); | |||||
| if (tensor_scale == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_scale create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| param->axis_ = 0; | |||||
| param->op_parameter_.type_ = schema::PrimitiveType_Scale; | |||||
| std::vector<lite::Tensor *> scale_inputs = {tensor_in, tensor_scale, tensor_offset}; | |||||
| lite::InnerContext ctx; | |||||
| ASSERT_EQ(lite::RET_OK, ctx.Init()); | |||||
| auto *scale_kernel = | |||||
| new (std::nothrow) kernel::ScaleOpenCLKernel(reinterpret_cast<OpParameter *>(param), scale_inputs, outputs, &ctx); | |||||
| if (scale_kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create ScaleOpenCLKernel failed!"; | |||||
| delete tensor_in; | |||||
| delete tensor_scale; | |||||
| delete tensor_offset; | |||||
| delete tensor_out; | |||||
| delete[] data_in; | |||||
| delete[] data_scale; | |||||
| delete[] data_offset; | |||||
| delete[] data_out_cpu; | |||||
| delete[] data_out_ocl; | |||||
| free(param); | |||||
| tensor_scale->set_data(scale_data); | |||||
| auto tensor_offset_ptr = | |||||
| std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), weight_shape, | |||||
| schema::Format_NHWC, lite::Tensor::Category::CONST_TENSOR); | |||||
| auto tensor_offset = tensor_offset_ptr.get(); | |||||
| if (tensor_offset == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_offset create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| scale_kernel->Init(); | |||||
| tensor_in->MallocData(allocator); | |||||
| tensor_scale->MallocData(allocator); | |||||
| tensor_offset->MallocData(allocator); | |||||
| std::vector<kernel::LiteKernel *> kernels{scale_kernel}; | |||||
| auto *kernel = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| if (scale_kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create SubGraphOpenCLKernel failed!"; | |||||
| delete tensor_in; | |||||
| delete tensor_scale; | |||||
| delete tensor_offset; | |||||
| delete tensor_out; | |||||
| delete[] data_in; | |||||
| delete[] data_scale; | |||||
| delete[] data_offset; | |||||
| delete[] data_out_cpu; | |||||
| delete[] data_out_ocl; | |||||
| delete scale_kernel; | |||||
| tensor_offset->set_data(offset_data); | |||||
| auto tensor_out_ptr = | |||||
| std::make_unique<lite::Tensor>(TypeId(enable_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32), out_shape); | |||||
| auto tensor_out = tensor_out_ptr.get(); | |||||
| if (tensor_out == nullptr) { | |||||
| MS_LOG(ERROR) << "tensor_out create error."; | |||||
| return; | return; | ||||
| } | } | ||||
| kernel->Init(); | |||||
| memcpy(inputs[0]->data_c(), data_in, sizeof(T) * element_num); | |||||
| if (!is_broadcast) { | |||||
| memcpy(inputs[1]->data_c(), data_scale, sizeof(T) * element_num_b); | |||||
| memcpy(inputs[2]->data_c(), data_offset, sizeof(T) * element_num_b); | |||||
| std::vector<lite::Tensor *> inputs{tensor_x, tensor_scale, tensor_offset}; | |||||
| std::vector<lite::Tensor *> outputs{tensor_out}; | |||||
| auto op_kernel_ptr = | |||||
| std::make_unique<kernel::ScaleOpenCLKernel>(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| auto op_kernel = op_kernel_ptr.release(); | |||||
| if (op_kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "op_kernel create error."; | |||||
| return; | |||||
| } | } | ||||
| op_kernel->Init(); | |||||
| inputs[0]->MallocData(allocator); | |||||
| kernel->Run(); | |||||
| memcpy(data_out_ocl, outputs[0]->data_c(), sizeof(T) * element_num); | |||||
| std::vector<kernel::LiteKernel *> kernels{op_kernel}; | |||||
| if (is_log_data) { | |||||
| LogData<T>(data_in, 10, "Data input : "); | |||||
| LogData<T>(data_scale, tensor_scale->shape().empty() ? 1 : 10, "Data scale : "); | |||||
| LogData<T>(data_offset, tensor_offset->shape().empty() ? 1 : 10, "Data offset : "); | |||||
| LogData<T>(data_out_cpu, 10, "Expect compute : "); | |||||
| LogData<T>(outputs[0]->data_c(), 10, "OpenCL compute : "); | |||||
| std::vector<lite::Tensor *> inputs_g{tensor_x}; | |||||
| auto pGraph_ptr = std::make_unique<kernel::SubGraphOpenCLKernel>(inputs_g, outputs, kernels, kernels, kernels); | |||||
| auto pGraph = pGraph_ptr.get(); | |||||
| if (pGraph == nullptr) { | |||||
| MS_LOG(ERROR) << "pGraph create error."; | |||||
| return; | |||||
| } | |||||
| pGraph->Init(); | |||||
| memcpy(inputs[0]->MutableData(), input_data0, tensor_x->ElementsNum() * dtype_size); | |||||
| pGraph->Run(); | |||||
| if (enable_fp16) { | |||||
| CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float16_t>(1e-3), | |||||
| 2e-2); | |||||
| } else { | |||||
| CompareOutput(outputs[0]->MutableData(), output_data, tensor_out->ElementsNum(), static_cast<float>(1e-5)); | |||||
| } | } | ||||
| bool cmp = DataCompare(data_out_cpu, data_out_ocl, element_num); | |||||
| MS_LOG(INFO) << "Compare " << (cmp ? "success!" : "failed!"); | |||||
| EXPECT_EQ(true, cmp); | |||||
| // free | |||||
| delete[] data_in; | |||||
| delete[] data_scale; | |||||
| delete[] data_offset; | |||||
| delete[] data_out_cpu; | |||||
| delete[] data_out_ocl; | |||||
| delete kernel; | |||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| for (auto t : inputs) { | |||||
| t->set_data(nullptr); | |||||
| } | } | ||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| for (auto t : outputs) { | |||||
| t->set_data(nullptr); | |||||
| } | } | ||||
| MS_LOG(INFO) << "TestScale passed"; | |||||
| } | } | ||||
| class TestScaleOpenCL : public mindspore::CommonTest { | |||||
| public: | |||||
| TestScaleOpenCL() {} | |||||
| }; | |||||
| TEST_F(TestScaleOpenCL, ElementFP32) { | |||||
| const std::vector<int> &shape_a = {1, 1024, 1024, 4}; | |||||
| const std::vector<int> &shape_b = {1, 1024, 1024, 4}; | |||||
| TestCase<float>(shape_a, shape_b); | |||||
| TEST_F(TestScaleOpenCL, ScaleAxis3Fp32) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> weight_shape = {c}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> scale_data = {1.0f, 2.0f, 3.0f}; | |||||
| std::vector<float> offset_data = {1.0f, 2.0f, 3.0f}; | |||||
| std::vector<float> output_data = {2.0f, 6.0f, 12.0f, 5.0f, 12.0f, 21.0f, 8.0f, 18.0f, 30.0f, 11.0f, 24.0f, 39.0f}; | |||||
| RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape, | |||||
| output_data.data(), out_shape, false, 3); | |||||
| } | } | ||||
| TEST_F(TestScaleOpenCL, BroadcastFP32) { | |||||
| const std::vector<int> &shape_a = {1, 128, 128, 4}; | |||||
| const std::vector<int> &shape_b = {}; | |||||
| TestCase<float>(shape_a, shape_b); | |||||
| TEST_F(TestScaleOpenCL, ScaleAxis1Fp32) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> weight_shape = {h}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> scale_data = {1.0f, 2.0f}; | |||||
| std::vector<float> offset_data = {1.0f, 2.0f}; | |||||
| std::vector<float> output_data = {2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f}; | |||||
| RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape, | |||||
| output_data.data(), out_shape, false, 1); | |||||
| } | } | ||||
| TEST_F(TestScaleOpenCL, ElementFP16) { | |||||
| const std::vector<int> &shape_a = {1, 1024, 1024, 4}; | |||||
| const std::vector<int> &shape_b = {1, 1024, 1024, 4}; | |||||
| TestCase<float16_t>(shape_a, shape_b); | |||||
| TEST_F(TestScaleOpenCL, ScaleAxis3ReLU6Fp32) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> weight_shape = {c}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float> scale_data = {1.0f, 2.0f, -1.0f}; | |||||
| std::vector<float> offset_data = {1.0f, 2.0f, 3.0f}; | |||||
| std::vector<float> output_data = {2.0f, 6.0f, 0.0f, 5.0f, 6.0f, 0.0f, 6.0f, 6.0f, 0.0f, 6.0f, 6.0f, 0.0f}; | |||||
| RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape, | |||||
| output_data.data(), out_shape, false, 3, schema::ActivationType_RELU6); | |||||
| } | } | ||||
| TEST_F(TestScaleOpenCL, BroadcastFP16) { | |||||
| const std::vector<int> &shape_a = {1, 128, 128, 4}; | |||||
| const std::vector<int> &shape_b = {}; | |||||
| TestCase<float16_t>(shape_a, shape_b); | |||||
| TEST_F(TestScaleOpenCL, ScaleAxis3Fp16) { | |||||
| int n = 1; | |||||
| int h = 2; | |||||
| int w = 2; | |||||
| int c = 3; | |||||
| std::vector<int> in_shape0 = {n, h, w, c}; | |||||
| std::vector<int> weight_shape = {c}; | |||||
| std::vector<int> out_shape = {n, h, w, c}; | |||||
| std::vector<float16_t> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; | |||||
| std::vector<float16_t> scale_data = {1.0f, 2.0f, 3.0f}; | |||||
| std::vector<float16_t> offset_data = {1.0f, 2.0f, 3.0f}; | |||||
| std::vector<float16_t> output_data = {2.0f, 6.0f, 12.0f, 5.0f, 12.0f, 21.0f, 8.0f, 18.0f, 30.0f, 11.0f, 24.0f, 39.0f}; | |||||
| RunTestCaseScale(input_data.data(), in_shape0, scale_data.data(), offset_data.data(), weight_shape, | |||||
| output_data.data(), out_shape, true, 3); | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||