diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/batchnorm.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/batchnorm.cl
new file mode 100644
index 0000000000..a68141329b
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/batchnorm.cl
@@ -0,0 +1,27 @@
+#define FLT4 float4
+#define INT4 int4
+#define INT2 int2
+__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
+__kernel void batch_normalization(__read_only image2d_t input, __read_only image2d_t scale,
+                                  __read_only image2d_t offset, __read_only image2d_t mean,
+                                  __read_only image2d_t variance, __write_only image2d_t output, const INT4 input_shape,
+                                  float epsilon) {
+  int X = get_global_id(0);  // H
+  int Y = get_global_id(1);  // W
+  int Z = get_global_id(2);  // C/4
+  if (X >= input_shape.y || Y >= input_shape.z || Z >= input_shape.w) {
+    return;
+  }
+  FLT4 result = read_imagef(input, smp_none, (int2)((Y)*input_shape.w + Z, (X)));
+
+  FLT4 result_mean = read_imagef(mean, smp_none, (int2)((Z), (0)));
+  FLT4 result_var = read_imagef(variance, smp_none, (int2)((Z), (0)));
+  FLT4 result_scale = read_imagef(scale, smp_none, (int2)((Z), (0)));
+  FLT4 result_offset = read_imagef(offset, smp_none, (int2)((Z), (0)));
+
+  result.x = result_scale.x * ((result.x - result_mean.x) / sqrt(result_var.x + epsilon)) + result_offset.x;
+  result.y = result_scale.y * ((result.y - result_mean.y) / sqrt(result_var.y + epsilon)) + result_offset.y;
+  result.z = result_scale.z * ((result.z - result_mean.z) / sqrt(result_var.z + epsilon)) + result_offset.z;
+  result.w = result_scale.w * ((result.w - result_mean.w) / sqrt(result_var.w + epsilon)) + result_offset.w;
+  write_imagef(output, (int2)((Y)*input_shape.w + Z, (X)), result);
+}
diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/to_format.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/to_format.cl
index 0d877b778c..0811d7ce52 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/to_format.cl
+++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/to_format.cl
@@ -63,7 +63,7 @@ __kernel void to_format_NHWC4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only
   if (X >= size.x || Y >= size.y || Z >= size.z) {
     return;
   }
-  // WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
+  WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), src_data[(X * size.y + Y) * size.z + Z]);
 }
 __kernel void to_format_NC4HW4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
                                             int4 shape) {
@@ -231,5 +231,5 @@ __kernel void to_format_NHWC4_to_NHWC4_BUF(__read_only image2d_t src_data, __global
   if (X >= size.x || Y >= size.y || Z >= size.z) {
     return;
   }
-  dst_data[(Y * size.z + Z) * size.x + X] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X));
+  dst_data[(X * size.y + Y) * size.z + Z] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X));
 }
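Note: both files above use the NHWC4 image layout, where the UP_DIV(C, 4) channel slices are packed along the image x-axis (one image row per H position, W * C4 texels per row). A minimal host-side sketch of the index mapping, with hypothetical names, for cross-checking the coordinates in batchnorm.cl and to_format.cl:

    // Element (h, w, slice z) of an NHWC4 tensor of width W with w4 = UP_DIV(C, 4) slices.
    struct Nhwc4Index {
      int img_x;    // w * w4 + z, matches (Y)*input_shape.w + Z in batchnorm.cl
      int img_y;    // h, matches (X)
      int buf_idx;  // (h * W + w) * w4 + z, matches to_format_NHWC4_to_NHWC4_BUF
    };
    Nhwc4Index IndexNhwc4(int h, int w, int z, int W, int w4) {
      return {w * w4 + z, h, (h * W + w) * w4 + z};
    }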
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc
new file mode 100644
index 0000000000..93fa46d11f
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <set>
+#include <string>
+#include <algorithm>
+#include "src/kernel_registry.h"
+#include "src/runtime/opencl/opencl_runtime.h"
+#include "src/runtime/kernel/opencl/kernel/batchnorm.h"
+#include "src/runtime/kernel/opencl/cl/fp32/batchnorm.cl.inc"
+
+using mindspore::kernel::KERNEL_ARCH::kGPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::schema::PrimitiveType_BatchNorm;
+
+namespace mindspore::kernel {
+
+int BatchNormOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
+  size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
+  size_t im_dst_x, im_dst_y;
+  if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
+    im_dst_x = out_tensors_[0]->Width() * CO4;
+    im_dst_y = out_tensors_[0]->Height();
+  } else {
+    im_dst_y = out_tensors_[0]->Height() * CO4;
+    im_dst_x = out_tensors_[0]->Width();
+  }
+#ifdef ENABLE_FP16
+  size_t img_dtype = CL_HALF_FLOAT;
+#else
+  size_t img_dtype = CL_FLOAT;
+#endif
+  img_size->clear();
+  std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
+  *img_size = vec;
+  return RET_OK;
+}
+int BatchNormOpenCLKernel::Init() {
+  std::set<std::string> build_options;
+  std::string source = batchnorm_source_fp32;
+  std::string program_name = "batch_normalization";
+  std::string kernel_name = "batch_normalization";
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  ocl_runtime->LoadSource(program_name, source);
+  ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
+  ori_format_ = out_tensors_[0]->GetFormat();
+  out_tensors_[0]->SetFormat(schema::Format_NHWC4);
+
+  return RET_OK;
+}
+
+int BatchNormOpenCLKernel::ReSize() { return RET_OK; }
+
+int BatchnormGetBiggestDividerWithPriority(int number, int max_divider) {
+  if (number % 8 == 0 && 8 <= max_divider) {
+    return number / 8;
+  }
+  if (number % 4 == 0 && 4 <= max_divider) {
+    return number / 4;
+  }
+  if (number % 2 == 0 && 2 <= max_divider) {
+    return number / 2;
+  }
+
+  for (int i = max_divider; i != 0; i--) {
+    if (number % i == 0) {
+      return i;
+    }
+  }
+  return RET_OK;
+}
+
+void BatchNormGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
+  const int max_divider = 8;
+  const int max_x = 4, max_y = 8;
+  int x = std::min(BatchnormGetBiggestDividerWithPriority(global[0], max_divider), max_x);
+  int yz = max_size / x;
+  int y = std::min(std::min(BatchnormGetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
+  int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
+
+  local->clear();
+  local->push_back(x);
+  local->push_back(y);
+  local->push_back(z);
+}
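+
+// Worked trace of BatchNormGetWorkGroup above for the 1x256x256x48 test case,
+// i.e. global = {256, 256, 12}, assuming max_size = 256:
+//   x  = min(256 / 8, 4)            = 4
+//   yz = 256 / 4                    = 64
+//   y  = min(min(256 / 8, 64), 8)   = 8
+//   z  = min(64 / 8, UP_DIV(12, 2)) = min(8, 6) = 6
+// giving local = {4, 8, 6}, 192 work-items per group.
+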
+int BatchNormOpenCLKernel::Run() {
+  MS_LOG(DEBUG) << this->name() << " Running!";
+  auto param = reinterpret_cast<BatchNormParameter *>(this->op_parameter_);
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  auto input0_shape = in_tensors_[0]->shape();
+  auto output_shape = out_tensors_[0]->shape();
+  cl_int4 input_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], UP_DIV(input0_shape[3], C4NUM)};
+
+  uint32_t OH = output_shape[1];
+  uint32_t OW = output_shape[2];
+  uint32_t OC = UP_DIV(output_shape[3], C4NUM);
+
+  const std::vector<size_t> &max_global = ocl_runtime->GetWorkItemSize();
+  std::vector<size_t> local = {1, 1, 1};  // init local
+  std::vector<size_t> global = {OH, OW, OC};
+  BatchNormGetWorkGroup(global, &local, max_global[0]);
+  int arg_cn = 0;
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->Data());   // input tensor
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->Data());   // scale
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->Data());   // offset
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[3]->Data());   // mean
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[4]->Data());   // variance
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data());  // out tensor
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape_);
+  ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->epsilon_);
+  ocl_runtime->RunKernel(kernel_, global, local, nullptr);
+
+  return RET_OK;
+}
+
+kernel::LiteKernel *OpenCLBatchnormKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+                                                 const std::vector<lite::tensor::Tensor *> &outputs,
+                                                 OpParameter *opParameter, const lite::Context *ctx,
+                                                 const kernel::KernelKey &desc, const lite::Primitive *primitive) {
+  auto *kernel = new (std::nothrow) BatchNormOpenCLKernel(opParameter, inputs, outputs);
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "new BatchnormOpenCLKernel failed";
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (0 != ret) {
+    MS_LOG(ERROR) << "Init kernel failed, name: Batchnorm";
+    delete kernel;
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BatchNorm, OpenCLBatchnormKernelCreator);
+}  // namespace mindspore::kernel
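Note: the SetKernelArg sequence in Run() has to match the parameter order of batch_normalization in batchnorm.cl (input, scale, offset, mean, variance, output, input_shape, epsilon); arg_cn++ encodes that ordering implicitly. UP_DIV is ceiling division; assuming the usual nnacl definition:

    #define UP_DIV(x, y) (((x) + ((y)-1)) / (y))
    // e.g. UP_DIV(48, 4) == 12: a 48-channel tensor packs into 12 FLT4 slices,
    // which is the input_shape.w that the kernel bounds-checks against.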
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h
new file mode 100644
index 0000000000..f27556e060
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_BATCHNORM_H_
+#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_BATCHNORM_H_
+
+#include <vector>
+#include "ir/anf.h"
+#include "src/runtime/kernel/opencl/opencl_kernel.h"
+#include "src/runtime/opencl/opencl_runtime.h"
+#include "src/runtime/kernel/arm/nnacl/fp32/batchnorm.h"
+
+namespace mindspore::kernel {
+
+class BatchNormOpenCLKernel : public OpenCLKernel {
+ public:
+  explicit BatchNormOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
+                                 const std::vector<lite::tensor::Tensor *> &outputs)
+      : OpenCLKernel(parameter, inputs, outputs) {}
+
+  ~BatchNormOpenCLKernel() override {}
+
+  int Init() override;
+
+  int ReSize() override;
+
+  int Run() override;
+
+  int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
+
+ private:
+  cl::Kernel kernel_;
+};
+
+}  // namespace mindspore::kernel
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
index ae985412d7..b7dc8f1a86 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc
@@ -83,6 +83,8 @@ int ConcatOpenCLKernel::Init() {
     ocl_runtime->LoadSource(program_name, source);
     ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
   }
+  ori_format_ = out_tensors_[0]->GetFormat();
+  out_tensors_[0]->SetFormat(schema::Format_NHWC4);
   return RET_OK;
 }
 
@@ -116,8 +118,8 @@ int ConcatOpenCLKernel::Run_axis0() {
   return RET_OK;
 }
 
-int GetBiggestDividerWithPriority(int number, int max_divider) {
-  if (number % 8 == 0 && 8 <= max_divider) {
+int ConcatGetBiggestDividerWithPriority(int number, int max_divider) {
+  if (number % 8 == 0 && max_divider >= 8) {
     return number / 8;
   }
   if (number % 4 == 0 && 4 <= max_divider) {
@@ -138,9 +140,9 @@ int GetBiggestDividerWithPriority(int number, int max_divider) {
 void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
   const int max_divider = 8;
   const int max_x = 4, max_y = 8;
-  int x = std::min(GetBiggestDividerWithPriority(global[0], max_divider), max_x);
+  int x = std::min(ConcatGetBiggestDividerWithPriority(global[0], max_divider), max_x);
   int yz = max_size / x;
-  int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
+  int y = std::min(std::min(ConcatGetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
   int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
 
   local->clear();
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
index 44d1cf8dde..d4e9d79a1f 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc
@@ -103,16 +103,16 @@ int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size
   size_t im_dst_x, im_dst_y;
   std::vector<int> shapex = out_tensors_[0]->shape();
   if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
+    int c = shapex[1];
+    int h = shapex[2];
+    int w = shapex[3];
+    im_dst_y = h * UP_DIV(c, C4NUM);
+    im_dst_x = w;
+  } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
     int h = shapex[1];
     int w = shapex[2];
     int c = shapex[3];
-    im_dst_y = UP_DIV(h * c, C4NUM);
-    im_dst_x = w;
-  } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
-    int h = shapex[2];
-    int w = shapex[3];
-    int c = shapex[1];
-    im_dst_x = UP_DIV(w * c, C4NUM);
+    im_dst_x = w * UP_DIV(c, C4NUM);
     im_dst_y = h;
   } else {
     MS_LOG(ERROR) << "Unsupported format. " << out_tensors_[0]->GetFormat();
@@ -127,9 +127,9 @@ int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size
   *img_size = vec;
   return RET_OK;
 }
-
 int ToFormatOpenCLKernel::Run() {
-  MS_LOG(DEBUG) << "ToFormat" << " Running!";
+  MS_LOG(DEBUG) << "ToFormat"
+                << " Running!";
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   std::vector<size_t> local = {};
   std::vector<size_t> global;
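Note: the GetImageSize fix makes the 2D image footprint match each layout. NC4HW4 stacks the UP_DIV(c, 4) channel slices along y, NHWC4 packs them along x:

    // NC4HW4: (im_dst_x, im_dst_y) = (w, h * UP_DIV(c, 4))
    // NHWC4:  (im_dst_x, im_dst_y) = (w * UP_DIV(c, 4), h)

The old code also read h, w, and c from the wrong shape indices and computed UP_DIV(h * c, 4), which equals h * UP_DIV(c, 4) only when c is a multiple of 4.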
MS_LOG(ERROR) << "Unsupported format. " << out_tensors_[0]->GetFormat(); @@ -127,9 +127,9 @@ int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size *img_size = vec; return RET_OK; } - int ToFormatOpenCLKernel::Run() { - MS_LOG(DEBUG) << "ToFormat" << " Running!"; + MS_LOG(DEBUG) << "ToFormat" + << " Running!"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); std::vector local = {}; std::vector global; diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index c121537071..604ad6541a 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -31,7 +31,12 @@ int Scheduler::Schedule(const lite::Model *model, std::vector // 1. op ---> kernel // 2. sub graph // 3. kernels (kernels --> subGraph) - int ret = InitOp2Kernel(model, tensors, kernels); + int ret = InferShape(model, tensors); + if (ret != RET_OK) { + MS_LOG(ERROR) << "op infer shape failed."; + return RET_ERROR; + } + ret = InitOp2Kernel(model, tensors, kernels); if (ret != RET_OK) { MS_LOG(ERROR) << "init op to kernel failed."; return RET_ERROR; @@ -72,15 +77,12 @@ int Scheduler::ReSizeKernels(const std::vector &kernels) { return RET_OK; } -int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector *tensors, - std::vector *kernels) { +int Scheduler::InferShape(const lite::Model *model, std::vector *tensors) { MS_EXCEPTION_IF_NULL(model); MS_EXCEPTION_IF_NULL(tensors); - MS_EXCEPTION_IF_NULL(kernels); auto meta_graph = model->GetMetaGraph(); MS_EXCEPTION_IF_NULL(meta_graph); uint32_t kernelCount = meta_graph->nodes()->size(); - auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph); for (uint32_t i = 0; i < kernelCount; i++) { auto cNode = meta_graph->nodes()->GetAs(i); std::vector inputs; @@ -115,7 +117,31 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vectorSetInferFlag(false); } + } + return RET_OK; +} +int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector *tensors, + std::vector *kernels) { + MS_EXCEPTION_IF_NULL(model); + MS_EXCEPTION_IF_NULL(tensors); + auto meta_graph = model->GetMetaGraph(); + MS_EXCEPTION_IF_NULL(meta_graph); + uint32_t kernelCount = meta_graph->nodes()->size(); + auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph); + for (uint32_t i = 0; i < kernelCount; i++) { + auto cNode = meta_graph->nodes()->GetAs(i); + std::vector inputs; + std::vector outputs; + auto inIndexes = cNode->inputIndex(); + for (size_t j = 0; j < inIndexes->size(); j++) { + inputs.emplace_back(tensors->at(size_t(inIndexes->GetAs(j)))); + } + auto outIndexes = cNode->outputIndex(); + for (size_t j = 0; j < outIndexes->size(); j++) { + outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs(j)))); + } + auto *primitive = model->GetOp(cNode->name()->str()); auto *kernel = this->ScheduleNode(inputs, outputs, primitive); if (nullptr == kernel) { MS_LOG(ERROR) << "ScheduleNode return nullptr, name: " << cNode->name()->str() diff --git a/mindspore/lite/src/scheduler.h b/mindspore/lite/src/scheduler.h index 407dceafd4..ef1991c42c 100644 --- a/mindspore/lite/src/scheduler.h +++ b/mindspore/lite/src/scheduler.h @@ -38,6 +38,7 @@ class Scheduler { private: int InitOp2Kernel(const lite::Model *model, std::vector *tensors, std::vector *kernels); + int InferShape(const lite::Model *model, std::vector *tensors); // construct SubGraphKernel for each kernel-group in markedKernelGroup void ConstructSubgraphs(std::vector *kernels); diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt 
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
index 3ae5c32cd8..6c29040139 100644
--- a/mindspore/lite/test/CMakeLists.txt
+++ b/mindspore/lite/test/CMakeLists.txt
@@ -144,6 +144,7 @@ if (SUPPORT_GPU)
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
+        ${LITE_DIR}/src/runtime/kernel/opencl/kernel/batchnorm.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
         ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
@@ -314,6 +315,7 @@ if (SUPPORT_GPU)
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc
+        ${TEST_DIR}/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
         ${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
new file mode 100644
index 0000000000..47de2d1acd
--- /dev/null
+++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
@@ -0,0 +1,134 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <cstring>
+#include <cmath>
+#include "utils/log_adapter.h"
+#include "common/common_test.h"
+#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
+#include "mindspore/lite/src/common/file_utils.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
+#include "mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h"
+
+namespace mindspore {
+class TestBatchnormOpenCL : public mindspore::CommonTest {
+ public:
+  TestBatchnormOpenCL() {}
+};
+
+template <typename T>
+void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) {
+  for (int i = 0; i < size; i++) {
+    T abs = fabs(output_data[i] - correct_data[i]);
+    ASSERT_LE(abs, err_bound);
+  }
+}
+
+TEST_F(TestBatchnormOpenCL, Batchnorminput_dim4) {
+  MS_LOG(INFO) << "begin test";
+  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
+  ocl_runtime->Init();
+  auto allocator = ocl_runtime->GetAllocator();
+
+  MS_LOG(INFO) << "Read tensors from .bin";
+  std::vector<int> input_shape = {1, 256, 256, 48};
+  std::vector<int> output_shape = {1, 256, 256, 48};
+  auto data_type = kNumberTypeFloat32;
+  auto tensor_type = schema::NodeType_ValueNode;
+
+  // get the input from .bin
+  size_t input_size, output_size;
+  std::string input_path = "./test_data/in_data.bin";
+  std::string mean_path = "./test_data/mean.bin";
+  std::string var_path = "./test_data/var.bin";
+  std::string output_path = "./test_data/out_data.bin";
+  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
+  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
+  size_t mean_size, var_size;
+  auto mean_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(mean_path.c_str(), &mean_size));
+  auto var_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(var_path.c_str(), &var_size));
+
+  MS_LOG(INFO) << "construct tensors";
+  lite::tensor::Tensor *tensor_data =
+    new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
+  lite::tensor::Tensor *tensor_mean =
+    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
+  lite::tensor::Tensor *tensor_var =
+    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
+  lite::tensor::Tensor *tensor_scale =
+    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
+  lite::tensor::Tensor *tensor_offset =
+    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
+  if (tensor_data == nullptr || tensor_mean == nullptr || tensor_var == nullptr || tensor_scale == nullptr ||
+      tensor_offset == nullptr) {
+    MS_LOG(INFO) << "init tensor failed";
+    return;
+  }
+  auto *output_tensor =
+    new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type);
+  if (output_tensor == nullptr) {
+    MS_LOG(INFO) << "init tensor failed";
+    return;
+  }
+  std::vector<lite::tensor::Tensor *> inputs = {tensor_data, tensor_scale, tensor_offset, tensor_mean, tensor_var};
+  std::vector<lite::tensor::Tensor *> outputs{output_tensor};
+
+  MS_LOG(INFO) << "initialize tensors";
+  auto param = new (std::nothrow) BatchNormParameter();
+  if (param == nullptr) {
+    MS_LOG(INFO) << "new BatchNormParameter failed";
+    return;
+  }
+  param->epsilon_ = pow(10, -5);
+  auto *batchnorm_kernel =
+    new (std::nothrow) kernel::BatchNormOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
+  if (batchnorm_kernel == nullptr) {
+    MS_LOG(INFO) << "new kernel::BatchNorm_kernel failed";
+    return;
+  }
+  batchnorm_kernel->Init();
+
+  // allocate memory for inputs and outputs
+  for (auto &input_tensor : inputs) {
+    input_tensor->MallocData(allocator);
+  }
+
+  MS_LOG(INFO) << "initialize sub_graph";
+  std::vector<kernel::LiteKernel *> kernels{batchnorm_kernel};
+  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
+  if (sub_graph == nullptr) {
+    MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed";
+    return;
+  }
+  sub_graph->Init();
+  MS_LOG(INFO) << "init tensors";
+  std::cout << "init tensors" << std::endl;
+  memcpy(inputs[0]->Data(), input_data, input_size);
+  auto &temp = inputs[1];
+  auto tensor_temp = reinterpret_cast<float *>(temp->Data());
+  int UPDIV_tensor_scale = UP_DIV(tensor_scale->ElementsNum(), C4NUM) * 4;
+  for (int i = 0; i < UPDIV_tensor_scale; ++i) {
+    tensor_temp[i] = static_cast<float>(1);
+  }
+  memcpy(inputs[3]->Data(), mean_data, mean_size);
+  memcpy(inputs[4]->Data(), var_data, var_size);
+  std::cout << "==================output data================" << std::endl;
+  sub_graph->Run();
+
+  auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data());
+  CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
+  lite::opencl::OpenCLRuntime::DeleteInstance();
+}
+}  // namespace mindspore
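Note: the scale fill above writes UP_DIV(ElementsNum(), C4NUM) * 4 floats, the padded FLT4 length, rather than the logical channel count, since the kernel reads whole float4 texels. A hedged sketch of the rule (PaddedChannels is a hypothetical name):

    // Padded length of a per-channel parameter once packed into float4 slices;
    // for the test's 48 channels this is a no-op (48 -> 48), but e.g. 47 -> 48,
    // and the trailing lane would otherwise be read uninitialized.
    int PaddedChannels(int c) { return UP_DIV(c, 4) * 4; }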