
!4603 [MS][LITE][Develop] add new op named batchnorm for opencl (GPU)

Merge pull request !4603 from pengyongrong/batchnorm
tags/v0.7.0-beta
Merged by mindspore-ci-bot 5 years ago
commit 1c1ee66008
10 changed files with 409 additions and 20 deletions

1. mindspore/lite/src/runtime/kernel/opencl/cl/fp32/batchnorm.cl (+27, -0)
2. mindspore/lite/src/runtime/kernel/opencl/cl/fp32/to_format.cl (+2, -2)
3. mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc (+148, -0)
4. mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h (+49, -0)
5. mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc (+6, -4)
6. mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc (+9, -9)
7. mindspore/lite/src/scheduler.cc (+31, -5)
8. mindspore/lite/src/scheduler.h (+1, -0)
9. mindspore/lite/test/CMakeLists.txt (+2, -0)
10. mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc (+134, -0)

mindspore/lite/src/runtime/kernel/opencl/cl/fp32/batchnorm.cl (+27, -0)

@@ -0,0 +1,27 @@
#define FLT4 float4
#define INT4 int4
#define INT2 int2
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;
__kernel void batch_normalization(__read_only image2d_t input, __read_only image2d_t scale,
                                  __read_only image2d_t offset, __read_only image2d_t mean,
                                  __read_only image2d_t variance, __write_only image2d_t output,
                                  const INT4 input_shape, float epsilon) {
  int X = get_global_id(0);  // H
  int Y = get_global_id(1);  // W
  int Z = get_global_id(2);  // C/4
  if (X >= input_shape.y || Y >= input_shape.z || Z >= input_shape.w) {
    return;
  }
  FLT4 result = read_imagef(input, smp_none, (int2)((Y)*input_shape.w + Z, (X)));

  FLT4 result_mean = read_imagef(mean, smp_none, (int2)((Z), (0)));
  FLT4 result_var = read_imagef(variance, smp_none, (int2)((Z), (0)));
  FLT4 result_scale = read_imagef(scale, smp_none, (int2)((Z), (0)));
  FLT4 result_offset = read_imagef(offset, smp_none, (int2)((Z), (0)));

  result.x = result_scale.x * ((result.x - result_mean.x) / sqrt(result_var.x + epsilon)) + result_offset.x;
  result.y = result_scale.y * ((result.y - result_mean.y) / sqrt(result_var.y + epsilon)) + result_offset.y;
  result.z = result_scale.z * ((result.z - result_mean.z) / sqrt(result_var.z + epsilon)) + result_offset.z;
  result.w = result_scale.w * ((result.w - result_mean.w) / sqrt(result_var.w + epsilon)) + result_offset.w;
  write_imagef(output, (int2)((Y)*input_shape.w + Z, (X)), result);
}
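The kernel addresses the NHWC4 image as (x, y) = (w * C4 + c4, h) and applies the standard inference-time formula y = scale * (x - mean) / sqrt(var + epsilon) + offset per channel. Below is a minimal host-side C++ reference of the same per-element math, handy for sanity-checking the kernel against CPU output; the helper name and the plain-NHWC layout are illustrative assumptions, not part of this patch:

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference: same math as batch_normalization above, on a plain NHWC float buffer.
std::vector<float> BatchNormRef(const std::vector<float> &in, const std::vector<float> &scale,
                                const std::vector<float> &offset, const std::vector<float> &mean,
                                const std::vector<float> &var, size_t H, size_t W, size_t C, float eps) {
  std::vector<float> out(in.size());
  for (size_t h = 0; h < H; ++h) {
    for (size_t w = 0; w < W; ++w) {
      for (size_t c = 0; c < C; ++c) {
        size_t i = (h * W + w) * C + c;  // NHWC linear index
        out[i] = scale[c] * ((in[i] - mean[c]) / std::sqrt(var[c] + eps)) + offset[c];
      }
    }
  }
  return out;
}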

mindspore/lite/src/runtime/kernel/opencl/cl/fp32/to_format.cl (+2, -2)

@@ -63,7 +63,7 @@ __kernel void to_format_NHWC4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only
   if (X >= size.x || Y >= size.y || Z >= size.z) {
     return;
   }
-  // WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X)));
+  WRITE_IMAGE(dst_data, (int2)(Y * size.z + Z, X), src_data[(X * size.y + Y) * size.z + Z]);
 }
 __kernel void to_format_NC4HW4_to_NHWC4_IMG(__global FLT4 *src_data, __write_only image2d_t dst_data, int4 size,
                                             int4 shape) {
@@ -231,5 +231,5 @@ __kernel void to_format_NHWC4_to_NHWC4_BUF(__read_only image2d_t src_data, __glo
   if (X >= size.x || Y >= size.y || Z >= size.z) {
     return;
   }
-  dst_data[(Y * size.z + Z) * size.x + X] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X));
+  dst_data[(X * size.y + Y) * size.z + Z] = READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + Z, X));
 }
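Both hunks fix the same addressing bug: the buffer side of the NHWC4 copy was linearized as (Y * size.z + Z) * size.x + X, which makes X (the H index) the innermost stride, instead of the row-major NHWC4 index (X * size.y + Y) * size.z + Z with X/Y/Z walking H/W/C4. A small standalone C++ check (illustrative, not from the patch) showing that the two orderings scatter elements differently:

#include <cstdio>

int main() {
  const int H = 2, W = 3, C4 = 2;  // size.x = H, size.y = W, size.z = C4
  for (int X = 0; X < H; ++X) {
    for (int Y = 0; Y < W; ++Y) {
      for (int Z = 0; Z < C4; ++Z) {
        int old_idx = (Y * C4 + Z) * H + X;  // pre-patch linearization: H innermost
        int new_idx = (X * W + Y) * C4 + Z;  // row-major NHWC4: C4 innermost
        if (old_idx != new_idx) {
          std::printf("(%d,%d,%d): old=%d new=%d\n", X, Y, Z, old_idx, new_idx);
        }
      }
    }
  }
  return 0;
}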

mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.cc (+148, -0)

@@ -0,0 +1,148 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/kernel/batchnorm.h"
#include "src/runtime/kernel/opencl/cl/fp32/batchnorm.cl.inc"

using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_BatchNorm;

namespace mindspore::kernel {

int BatchNormOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
  size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
  size_t im_dst_x, im_dst_y;
  if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
    im_dst_x = out_tensors_[0]->Width() * CO4;
    im_dst_y = out_tensors_[0]->Height();
  } else {
    im_dst_y = out_tensors_[0]->Height() * CO4;
    im_dst_x = out_tensors_[0]->Width();
  }
#ifdef ENABLE_FP16
  size_t img_dtype = CL_HALF_FLOAT;
#else
  size_t img_dtype = CL_FLOAT;
#endif
  img_size->clear();
  std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
  *img_size = vec;
  return RET_OK;
}

int BatchNormOpenCLKernel::Init() {
  std::set<std::string> build_options;
  std::string source = batchnorm_source_fp32;
  std::string program_name = "batch_normalization";
  std::string kernel_name = "batch_normalization";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->LoadSource(program_name, source);
  ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
  ori_format_ = out_tensors_[0]->GetFormat();
  out_tensors_[0]->SetFormat(schema::Format_NHWC4);

  return RET_OK;
}

int BatchNormOpenCLKernel::ReSize() { return RET_OK; }

int BatchnormGetBiggestDividerWithPriority(int number, int max_divider) {
  if (number % 8 == 0 && 8 <= max_divider) {
    return number / 8;
  }
  if (number % 4 == 0 && 4 <= max_divider) {
    return number / 4;
  }
  if (number % 2 == 0 && 2 <= max_divider) {
    return number / 2;
  }

  for (int i = max_divider; i != 0; i--) {
    if (number % i == 0) {
      return i;
    }
  }
  return RET_OK;
}

void BatchNormGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
  const int max_divider = 8;
  const int max_x = 4, max_y = 8;
  int x = std::min(BatchnormGetBiggestDividerWithPriority(global[0], max_divider), max_x);
  int yz = max_size / x;
  int y = std::min(std::min(BatchnormGetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
  int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));

  local->clear();
  local->push_back(x);
  local->push_back(y);
  local->push_back(z);
}

int BatchNormOpenCLKernel::Run() {
  MS_LOG(DEBUG) << this->name() << " Running!";
  auto param = reinterpret_cast<BatchNormParameter *>(this->op_parameter_);
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  auto input0_shape = in_tensors_[0]->shape();
  auto output_shape = out_tensors_[0]->shape();
  cl_int4 input_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], UP_DIV(input0_shape[3], C4NUM)};

  uint32_t OH = output_shape[1];
  uint32_t OW = output_shape[2];
  uint32_t OC = UP_DIV(output_shape[3], C4NUM);

  const std::vector<size_t> &max_global = ocl_runtime->GetWorkItemSize();
  std::vector<size_t> local = {1, 1, 1};  // init local
  std::vector<size_t> global = {OH, OW, OC};
  BatchNormGetWorkGroup(global, &local, max_global[0]);
  int arg_cn = 0;
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->Data());   // input tensor
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[1]->Data());   // scale
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[2]->Data());   // offset
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[3]->Data());   // mean
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, in_tensors_[4]->Data());   // variance
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->Data());  // output tensor
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, input_shape_);
  ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->epsilon_);
  ocl_runtime->RunKernel(kernel_, global, local, nullptr);

  return RET_OK;
}


kernel::LiteKernel *OpenCLBatchnormKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                 const std::vector<lite::tensor::Tensor *> &outputs,
                                                 OpParameter *opParameter, const lite::Context *ctx,
                                                 const kernel::KernelKey &desc, const lite::Primitive *primitive) {
  auto *kernel = new (std::nothrow) BatchNormOpenCLKernel(opParameter, inputs, outputs);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "new BatchNormOpenCLKernel failed";
    return nullptr;
  }
  auto ret = kernel->Init();
  if (0 != ret) {
    MS_LOG(ERROR) << "Init kernel failed, name: BatchNorm";
    delete kernel;
    return nullptr;
  }
  return kernel;
}

REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BatchNorm, OpenCLBatchnormKernelCreator);
}  // namespace mindspore::kernel
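Worked example of the work-group selection (assuming a maximum work-group size of 256, a common device limit): the 1x256x256x48 test input below gives global = {256, 256, 12}. BatchnormGetBiggestDividerWithPriority(256, 8) returns 256 / 8 = 32, so x = min(32, 4) = 4; then yz = 256 / 4 = 64 and y = min(min(32, 64), 8) = 8; finally z = min(64 / 8, UP_DIV(12, 2)) = min(8, 6) = 6, giving local = {4, 8, 6}, i.e. 192 work-items per group. Note that the trailing return RET_OK; in the divider helper is unreachable, since i = 1 always divides number.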

mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h (+49, -0)

@@ -0,0 +1,49 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_BATCHNORM_H_
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_BATCHNORM_H_

#include <vector>
#include "ir/anf.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/nnacl/fp32/batchnorm.h"

namespace mindspore::kernel {

class BatchNormOpenCLKernel : public OpenCLKernel {
 public:
  explicit BatchNormOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                                 const std::vector<lite::tensor::Tensor *> &outputs)
      : OpenCLKernel(parameter, inputs, outputs) {}

  ~BatchNormOpenCLKernel() override {}

  int Init() override;

  int ReSize() override;

  int Run() override;

  int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;

 private:
  cl::Kernel kernel_;
};

}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_BACKEND_OPENCL_BATCHNORM_H_

mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc (+6, -4)

@@ -83,6 +83,8 @@ int ConcatOpenCLKernel::Init() {
     ocl_runtime->LoadSource(program_name, source);
     ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
   }
+  ori_format_ = out_tensors_[0]->GetFormat();
+  out_tensors_[0]->SetFormat(schema::Format_NHWC4);

   return RET_OK;
 }
@@ -116,8 +118,8 @@ int ConcatOpenCLKernel::Run_axis0() {
   return RET_OK;
 }

-int GetBiggestDividerWithPriority(int number, int max_divider) {
-  if (number % 8 == 0 && 8 <= max_divider) {
+int ConcatGetBiggestDividerWithPriority(int number, int max_divider) {
+  if (number % 8 == 0 && max_divider >= 8) {
     return number / 8;
   }
   if (number % 4 == 0 && 4 <= max_divider) {
@@ -138,9 +140,9 @@ int GetBiggestDividerWithPriority(int number, int max_divider) {
 void ConcatGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
   const int max_divider = 8;
   const int max_x = 4, max_y = 8;
-  int x = std::min(GetBiggestDividerWithPriority(global[0], max_divider), max_x);
+  int x = std::min(ConcatGetBiggestDividerWithPriority(global[0], max_divider), max_x);
   int yz = max_size / x;
-  int y = std::min(std::min(GetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
+  int y = std::min(std::min(ConcatGetBiggestDividerWithPriority(global[1], max_divider), yz), max_y);
   int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));

   local->clear();


mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.cc (+9, -9)

@@ -103,16 +103,16 @@ int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size
   size_t im_dst_x, im_dst_y;
   std::vector<int> shapex = out_tensors_[0]->shape();
   if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
-    int h = shapex[1];
-    int w = shapex[2];
-    int c = shapex[3];
-    im_dst_y = UP_DIV(h * c, C4NUM);
+    int c = shapex[1];
+    int h = shapex[2];
+    int w = shapex[3];
+    im_dst_y = h * UP_DIV(c, C4NUM);
     im_dst_x = w;
   } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
-    int h = shapex[2];
-    int w = shapex[3];
-    int c = shapex[1];
-    im_dst_x = UP_DIV(w * c, C4NUM);
+    int h = shapex[1];
+    int w = shapex[2];
+    int c = shapex[3];
+    im_dst_x = w * UP_DIV(c, C4NUM);
     im_dst_y = h;
} else {
MS_LOG(ERROR) << "Unsupported format. " << out_tensors_[0]->GetFormat();
@@ -127,9 +127,9 @@ int ToFormatOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size
   *img_size = vec;
   return RET_OK;
 }

 int ToFormatOpenCLKernel::Run() {
-  MS_LOG(DEBUG) << "ToFormat" << " Running!";
+  MS_LOG(DEBUG) << "ToFormat"
+                << " Running!";
   auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
   std::vector<size_t> local = {};
   std::vector<size_t> global;
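Worked example: for the 1x256x256x48 NHWC4 tensors used by the batchnorm test below, the NHWC4 branch now yields im_dst_x = w * UP_DIV(c, C4NUM) = 256 * 12 = 3072 and im_dst_y = h = 256, matching BatchNormOpenCLKernel::GetImageSize above and the (w * C4 + c4, h) addressing in batchnorm.cl.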


mindspore/lite/src/scheduler.cc (+31, -5)

@@ -31,7 +31,12 @@ int Scheduler::Schedule(const lite::Model *model, std::vector<tensor::Tensor *>
// 1. op ---> kernel
// 2. sub graph
// 3. kernels (kernels --> subGraph)
-  int ret = InitOp2Kernel(model, tensors, kernels);
+  int ret = InferShape(model, tensors);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "op infer shape failed.";
+    return RET_ERROR;
+  }
+  ret = InitOp2Kernel(model, tensors, kernels);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "init op to kernel failed.";
     return RET_ERROR;
@@ -72,15 +77,12 @@ int Scheduler::ReSizeKernels(const std::vector<kernel::LiteKernel *> &kernels) {
   return RET_OK;
 }

-int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
-                             std::vector<kernel::LiteKernel *> *kernels) {
+int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *> *tensors) {
   MS_EXCEPTION_IF_NULL(model);
   MS_EXCEPTION_IF_NULL(tensors);
-  MS_EXCEPTION_IF_NULL(kernels);
   auto meta_graph = model->GetMetaGraph();
   MS_EXCEPTION_IF_NULL(meta_graph);
   uint32_t kernelCount = meta_graph->nodes()->size();
-  auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
   for (uint32_t i = 0; i < kernelCount; i++) {
     auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
     std::vector<tensor::Tensor *> inputs;
@@ -115,7 +117,31 @@ int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tenso
     } else {
       primitive->SetInferFlag(false);
     }
+  }
+  return RET_OK;
+}
+
+int Scheduler::InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
+                             std::vector<kernel::LiteKernel *> *kernels) {
+  MS_EXCEPTION_IF_NULL(model);
+  MS_EXCEPTION_IF_NULL(tensors);
+  auto meta_graph = model->GetMetaGraph();
+  MS_EXCEPTION_IF_NULL(meta_graph);
+  uint32_t kernelCount = meta_graph->nodes()->size();
+  auto graph_output_node_indexes = GetGraphOutputNodes(meta_graph);
+  for (uint32_t i = 0; i < kernelCount; i++) {
+    auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i);
+    std::vector<tensor::Tensor *> inputs;
+    std::vector<tensor::Tensor *> outputs;
+    auto inIndexes = cNode->inputIndex();
+    for (size_t j = 0; j < inIndexes->size(); j++) {
+      inputs.emplace_back(tensors->at(size_t(inIndexes->GetAs<uint32_t>(j))));
+    }
+    auto outIndexes = cNode->outputIndex();
+    for (size_t j = 0; j < outIndexes->size(); j++) {
+      outputs.emplace_back(tensors->at(size_t(outIndexes->GetAs<uint32_t>(j))));
+    }
     auto *primitive = model->GetOp(cNode->name()->str());
     auto *kernel = this->ScheduleNode(inputs, outputs, primitive);
     if (nullptr == kernel) {
       MS_LOG(ERROR) << "ScheduleNode return nullptr, name: " << cNode->name()->str()


mindspore/lite/src/scheduler.h (+1, -0)

@@ -38,6 +38,7 @@ class Scheduler {
  private:
   int InitOp2Kernel(const lite::Model *model, std::vector<tensor::Tensor *> *tensors,
                     std::vector<kernel::LiteKernel *> *kernels);
+  int InferShape(const lite::Model *model, std::vector<tensor::Tensor *> *tensors);

   // construct SubGraphKernel for each kernel-group in markedKernelGroup
   void ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels);


mindspore/lite/test/CMakeLists.txt (+2, -0)

@@ -144,6 +144,7 @@ if (SUPPORT_GPU)
     ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc
     ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
     ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
+    ${LITE_DIR}/src/runtime/kernel/opencl/kernel/batchnorm.cc
     ${LITE_DIR}/src/runtime/kernel/opencl/kernel/activation.cc
     ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
     ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
@@ -314,6 +315,7 @@ if (SUPPORT_GPU)
     ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc
     ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc
     ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc
+    ${TEST_DIR}/ut/src/runtime/kernel/opencl/batchnorm_tests.cc
     ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc
     ${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc
     ${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc


mindspore/lite/test/ut/src/runtime/kernel/opencl/batchnorm_tests.cc (+134, -0)

@@ -0,0 +1,134 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cmath>
#include <cstring>
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/batchnorm.h"

namespace mindspore {
class TestBatchnormOpenCL : public mindspore::CommonTest {
 public:
  TestBatchnormOpenCL() {}
};

template <typename T>
void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) {
  for (int i = 0; i < size; i++) {
    T abs = fabs(output_data[i] - correct_data[i]);
    ASSERT_LE(abs, err_bound);
  }
}

TEST_F(TestBatchnormOpenCL, Batchnorminput_dim4) {
  MS_LOG(INFO) << "begin test";
  auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
  ocl_runtime->Init();
  auto allocator = ocl_runtime->GetAllocator();

  MS_LOG(INFO) << "Read tensors from .bin";
  std::vector<int> input_shape = {1, 256, 256, 48};
  std::vector<int> output_shape = {1, 256, 256, 48};
  auto data_type = kNumberTypeFloat32;
  auto tensor_type = schema::NodeType_ValueNode;

  // get the input from .bin
  size_t input_size, output_size;
  std::string input_path = "./test_data/in_data.bin";
  std::string mean_path = "./test_data/mean.bin";
  std::string var_path = "./test_data/var.bin";
  std::string output_path = "./test_data/out_data.bin";
  auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size));
  auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size));
  size_t mean_size, var_size;
  auto mean_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(mean_path.c_str(), &mean_size));
  auto var_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(var_path.c_str(), &var_size));

  MS_LOG(INFO) << "construct tensors";
  lite::tensor::Tensor *tensor_data =
    new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC, tensor_type);
  lite::tensor::Tensor *tensor_mean =
    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  lite::tensor::Tensor *tensor_var =
    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  lite::tensor::Tensor *tensor_scale =
    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  lite::tensor::Tensor *tensor_offset =
    new (std::nothrow) lite::tensor::Tensor(data_type, {1, 1, 1, input_shape[3]}, schema::Format_NHWC, tensor_type);
  if (tensor_data == nullptr || tensor_mean == nullptr || tensor_var == nullptr || tensor_scale == nullptr ||
      tensor_offset == nullptr) {
    MS_LOG(INFO) << "init tensor failed";
    return;
  }
  auto *output_tensor =
    new (std::nothrow) lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensor_type);
  if (output_tensor == nullptr) {
    MS_LOG(INFO) << "init tensor failed";
    return;
  }
  std::vector<lite::tensor::Tensor *> inputs = {tensor_data, tensor_scale, tensor_offset, tensor_mean, tensor_var};
  std::vector<lite::tensor::Tensor *> outputs{output_tensor};

  MS_LOG(INFO) << "initialize tensors";
  auto param = new (std::nothrow) BatchNormParameter();
  if (param == nullptr) {
    MS_LOG(INFO) << "new BatchNormParameter failed";
    return;
  }
  param->epsilon_ = 1e-5f;
  auto *batchnorm_kernel =
    new (std::nothrow) kernel::BatchNormOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
  if (batchnorm_kernel == nullptr) {
    MS_LOG(INFO) << "new kernel::BatchNormOpenCLKernel failed";
    return;
  }
  batchnorm_kernel->Init();

  // allocate memory for inputs and outputs
  for (auto &input_tensor : inputs) {
    input_tensor->MallocData(allocator);
  }

  MS_LOG(INFO) << "initialize sub_graph";
  std::vector<kernel::LiteKernel *> kernels{batchnorm_kernel};
  auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
  if (sub_graph == nullptr) {
    MS_LOG(INFO) << "new kernel::SubGraphOpenCLKernel failed";
    return;
  }
  sub_graph->Init();
  MS_LOG(INFO) << "init tensors";
  memcpy(inputs[0]->Data(), input_data, input_size);
  // fill scale with ones, padded out to a multiple of four channels
  auto *tensor_temp = reinterpret_cast<float *>(inputs[1]->Data());
  int UPDIV_tensor_scale = UP_DIV(tensor_scale->ElementsNum(), C4NUM) * 4;
  for (int i = 0; i < UPDIV_tensor_scale; ++i) {
    tensor_temp[i] = static_cast<float>(1);
  }
  memcpy(inputs[3]->Data(), mean_data, mean_size);
  memcpy(inputs[4]->Data(), var_data, var_size);
  sub_graph->Run();

  auto *output_data_gpu = reinterpret_cast<float *>(output_tensor->Data());
  CompareOutputData1(output_data_gpu, correct_data, output_tensor->ElementsNum(), 0.0001);
  lite::opencl::OpenCLRuntime::DeleteInstance();
}
}  // namespace mindspore
