Merge pull request !4637 from liuzhongkai/prelutags/v0.7.0-beta
| @@ -0,0 +1,130 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/kernel/opencl/kernel/prelu.h" | |||||
| #include "src/runtime/opencl/opencl_runtime.h" | |||||
| #include "src/runtime/kernel/opencl/cl/fp32/activation.cl.inc" | |||||
| #include "src/runtime/kernel/arm/nnacl/prelu_parameter.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PrimitiveType_Prelu; | |||||
| namespace mindspore::kernel { | |||||
| int PReluOpenCLKernel::Init() { | |||||
| if (in_tensors_[0]->shape().size() != 4) { | |||||
| MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << in_tensors_[0]->shape().size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::set<std::string> build_options; | |||||
| std::string source = activation_source_fp32; | |||||
| std::string program_name = "PRelu"; | |||||
| std::string kernel_name = "ReluScalar"; | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| ocl_runtime->LoadSource(program_name, source); | |||||
| ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||||
| ori_format_ = out_tensors_[0]->GetFormat(); | |||||
| out_tensors_[0]->SetFormat(schema::Format_NHWC4); | |||||
| MS_LOG(DEBUG) << program_name << " init Done!"; | |||||
| return RET_OK; | |||||
| } | |||||
| int PReluOpenCLKernel::Run() { | |||||
| MS_LOG(DEBUG) << op_parameter_->name_ << " Running!"; | |||||
| int N = in_tensors_[0]->shape()[0]; | |||||
| int H = in_tensors_[0]->shape()[1]; | |||||
| int W = in_tensors_[0]->shape()[2]; | |||||
| int C = in_tensors_[0]->shape()[3]; | |||||
| cl_int4 input_shape = {N, H, W, C}; | |||||
| if (in_tensors_[1]->ElementsNum() < 1) { | |||||
| MS_LOG(ERROR) << "PRelu weight size must be greater than 1! But your weight size is " | |||||
| << in_tensors_[1]->ElementsNum(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| int arg_idx = 0; | |||||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data()); | |||||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data()); | |||||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); | |||||
| ocl_runtime->SetKernelArg(kernel_, arg_idx++, reinterpret_cast<float *>(in_tensors_[1]->Data())[0]); | |||||
| std::vector<size_t> local = {1, 1}; | |||||
| std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)}; | |||||
| auto ret = ocl_runtime->RunKernel(kernel_, global, local, nullptr); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int PReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||||
| int H = in_tensors_[0]->shape()[1]; | |||||
| int W = in_tensors_[0]->shape()[2]; | |||||
| int C = in_tensors_[0]->shape()[3]; | |||||
| #ifdef ENABLE_FP16 | |||||
| size_t img_dtype = CL_HALF_FLOAT; | |||||
| #else | |||||
| size_t img_dtype = CL_FLOAT; | |||||
| #endif | |||||
| img_size->clear(); | |||||
| img_size->push_back(W * UP_DIV(C, C4NUM)); | |||||
| img_size->push_back(H); | |||||
| img_size->push_back(img_dtype); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs, | |||||
| OpParameter *opParameter, const lite::Context *ctx, | |||||
| const kernel::KernelKey &desc, const lite::PrimitiveC *primitive) { | |||||
| if (inputs.size() == 0) { | |||||
| MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size(); | |||||
| return nullptr; | |||||
| } | |||||
| if (inputs[0]->shape()[0] > 1) { | |||||
| MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch."; | |||||
| return nullptr; | |||||
| } | |||||
| auto *kernel = new (std::nothrow) PReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "kernel " << opParameter->name_ << "is nullptr."; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init PRelu kernel failed!"; | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Prelu, OpenCLPReluKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PRELU_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PRELU_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "src/ir/tensor.h" | |||||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/runtime/opencl/opencl_runtime.h" | |||||
| namespace mindspore::kernel { | |||||
| class PReluOpenCLKernel : public OpenCLKernel { | |||||
| public: | |||||
| explicit PReluOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, | |||||
| const std::vector<lite::tensor::Tensor *> &outputs) | |||||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||||
| ~PReluOpenCLKernel() override{}; | |||||
| int Init() override; | |||||
| int Run() override; | |||||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||||
| private: | |||||
| cl::Kernel kernel_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PRELU_H_ | |||||
| @@ -151,6 +151,7 @@ if (SUPPORT_GPU) | |||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/reshape.cc | ${LITE_DIR}/src/runtime/kernel/opencl/kernel/reshape.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc | ${LITE_DIR}/src/runtime/kernel/opencl/kernel/to_format.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/caffe_prelu.cc | ${LITE_DIR}/src/runtime/kernel/opencl/kernel/caffe_prelu.cc | ||||
| ${LITE_DIR}/src/runtime/kernel/opencl/kernel/prelu.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| ### minddata lite | ### minddata lite | ||||
| @@ -327,6 +328,7 @@ if (SUPPORT_GPU) | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/activation_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/activation_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/to_format_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/to_format_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/caffe_prelu_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/caffe_prelu_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/prelu_tests.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| @@ -0,0 +1,185 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/arm/nnacl/prelu_parameter.h" | |||||
| using mindspore::kernel::LiteKernel; | |||||
| using mindspore::kernel::PReluOpenCLKernel; | |||||
| using mindspore::kernel::SubGraphOpenCLKernel; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| namespace mindspore { | |||||
| class TestPReluOpenCL : public mindspore::CommonTest {}; | |||||
| void LoadDataPRelu(void *dst, size_t dst_size, const std::string &file_path) { | |||||
| if (file_path.empty()) { | |||||
| memset(dst, 0x00, dst_size); | |||||
| } else { | |||||
| auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); | |||||
| memcpy(dst, src_data, dst_size); | |||||
| } | |||||
| } | |||||
| void CompareOutPRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { | |||||
| auto *output_data = reinterpret_cast<float *>(output_tensor->Data()); | |||||
| size_t output_size = output_tensor->Size(); | |||||
| auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); | |||||
| constexpr float atol = 0.0002; | |||||
| for (int i = 0; i < output_tensor->ElementsNum(); ++i) { | |||||
| if (std::fabs(output_data[i] - expect_data[i]) > atol) { | |||||
| printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); | |||||
| printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); | |||||
| printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); | |||||
| return; | |||||
| } | |||||
| } | |||||
| printf("compare success!\n"); | |||||
| printf("compare success!\n"); | |||||
| printf("compare success!\n\n\n"); | |||||
| } | |||||
| TEST_F(TestPReluOpenCL, PReluFp32_dim4) { | |||||
| std::string in_file = "/data/local/tmp/in_data.bin"; | |||||
| std::string standard_answer_file = "/data/local/tmp/leaky_relu.bin"; | |||||
| MS_LOG(INFO) << "-------------------->> Begin test PRelu!"; | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| ocl_runtime->Init(); | |||||
| auto allocator = ocl_runtime->GetAllocator(); | |||||
| MS_LOG(INFO) << "Init tensors."; | |||||
| std::vector<int> input_shape = {1, 4, 3, 8}; | |||||
| auto data_type = kNumberTypeFloat32; | |||||
| auto tensor_type = schema::NodeType_ValueNode; | |||||
| auto input_tensor = | |||||
| new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); | |||||
| if (input_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new input_tensor error!"; | |||||
| return; | |||||
| } | |||||
| auto output_tensor = | |||||
| new (std::nothrow) lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); | |||||
| if (output_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new output_tensor error"; | |||||
| delete input_tensor; | |||||
| return; | |||||
| } | |||||
| auto weight_tensor = | |||||
| new (std::nothrow) lite::tensor::Tensor(data_type, std::vector<int>{1}, schema::Format_NHWC, tensor_type); | |||||
| if (weight_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "new weight_tensor error"; | |||||
| delete input_tensor; | |||||
| delete output_tensor; | |||||
| return; | |||||
| } | |||||
| std::vector<lite::tensor::Tensor *> inputs{input_tensor, weight_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||||
| // freamework to do!!! allocate memory by hand | |||||
| inputs[0]->MallocData(allocator); | |||||
| inputs[1]->MallocData(allocator); | |||||
| MS_LOG(INFO) << "initialize input data"; | |||||
| LoadDataPRelu(input_tensor->Data(), input_tensor->Size(), in_file); | |||||
| auto weight_data = reinterpret_cast<float *>(weight_tensor->Data()); | |||||
| weight_data[0] = 0.3; | |||||
| auto *input_data = reinterpret_cast<float *>(inputs[0]->Data()); | |||||
| PrintData("PRelu input data", input_data, inputs[0]->ElementsC4Num()); | |||||
| auto param = new (std::nothrow) PreluParameter(); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "new PreluParameter error"; | |||||
| delete input_tensor; | |||||
| delete output_tensor; | |||||
| delete weight_tensor; | |||||
| return; | |||||
| } | |||||
| auto prelu_kernel = | |||||
| new (std::nothrow) kernel::PReluOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| if (prelu_kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "new PReluOpenCLKernel error"; | |||||
| delete input_tensor; | |||||
| delete output_tensor; | |||||
| delete weight_tensor; | |||||
| delete param; | |||||
| return; | |||||
| } | |||||
| auto ret = prelu_kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init prelu kernel error"; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "initialize sub_graph"; | |||||
| std::vector<kernel::LiteKernel *> kernels{prelu_kernel}; | |||||
| auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel({input_tensor}, outputs, kernels, kernels, kernels); | |||||
| if (sub_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "Create kernel sub_graph error"; | |||||
| delete input_tensor; | |||||
| delete output_tensor; | |||||
| delete weight_tensor; | |||||
| delete param; | |||||
| delete prelu_kernel; | |||||
| return; | |||||
| } | |||||
| ret = sub_graph->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init sub graph error"; | |||||
| delete input_tensor; | |||||
| delete output_tensor; | |||||
| delete weight_tensor; | |||||
| delete param; | |||||
| delete prelu_kernel; | |||||
| delete sub_graph; | |||||
| return; | |||||
| } | |||||
| ret = sub_graph->Run(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Run sub graph error"; | |||||
| delete input_tensor; | |||||
| delete output_tensor; | |||||
| delete weight_tensor; | |||||
| delete param; | |||||
| delete prelu_kernel; | |||||
| delete sub_graph; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "PRelu==================output data================"; | |||||
| auto *output_data = reinterpret_cast<float *>(outputs[0]->Data()); | |||||
| PrintData("output_data", output_data, outputs[0]->ElementsC4Num()); | |||||
| CompareOutPRelu(output_tensor, standard_answer_file); | |||||
| delete input_tensor; | |||||
| delete output_tensor; | |||||
| delete weight_tensor; | |||||
| delete param; | |||||
| delete prelu_kernel; | |||||
| delete sub_graph; | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | |||||
| } // namespace mindspore | |||||