Merge pull request !7067 from pengyongrong/padtags/v1.1.0
| @@ -516,7 +516,7 @@ __kernel void Concat4input_NC4HW4(__read_only image2d_t input0, __read_only imag | |||||
| DOConcat##Inputnum##Axis##ToFormat; \ | DOConcat##Inputnum##Axis##ToFormat; \ | ||||
| } | } | ||||
| // nc4hw4 ? | |||||
| // nc4hw4 | |||||
| CONCAT6(6input, axis1, _NC4HW4) | CONCAT6(6input, axis1, _NC4HW4) | ||||
| CONCAT6(6input, axis2, _NC4HW4) | CONCAT6(6input, axis2, _NC4HW4) | ||||
| CONCAT6(6input, axis3, _NC4HW4) | CONCAT6(6input, axis3, _NC4HW4) | ||||
| @@ -530,7 +530,7 @@ CONCAT2(2input, axis1, _NC4HW4) | |||||
| CONCAT2(2input, axis2, _NC4HW4) | CONCAT2(2input, axis2, _NC4HW4) | ||||
| CONCAT2(2input, axis3, _NC4HW4) | CONCAT2(2input, axis3, _NC4HW4) | ||||
| // nhwc4? | |||||
| // nhwc4 | |||||
| CONCAT6(6input, axis1, _NHWC4) | CONCAT6(6input, axis1, _NHWC4) | ||||
| CONCAT6(6input, axis2, _NHWC4) | CONCAT6(6input, axis2, _NHWC4) | ||||
| CONCAT6(6input, axis3, _NHWC4) | CONCAT6(6input, axis3, _NHWC4) | ||||
| @@ -0,0 +1,36 @@ | |||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

/* Sampler used for every input read: unnormalized coordinates, clamp addressing, nearest filtering. */
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

/*
 * Pad(dataformat, in_x, in_y, out_x, out_y) expands to the kernel Pad_<dataformat>.
 *
 * One work-item handles one output texel: global id 0 is the output row (oh), id 1 the
 * output column (ow) and id 2 the channel slice (co_slice). Work-items outside
 * output_shape.{y,z,w} return immediately. pad.x / pad.y are subtracted from oh / ow to
 * obtain the source position (ih, iw); when that position lies inside the input the texel
 * is copied, otherwise the output receives constant_value broadcast to all four lanes.
 *
 * in_x/in_y/out_x/out_y are expressions evaluated inside the kernel body, so they may use
 * the locals oh, ow, co_slice, ih, iw, IH, IW, OH, OW, CI_SLICES and CO_SLICES.
 *
 * NOTE(review): FLT4, READ_IMAGE and WRITE_IMAGE are not defined in this file; presumably
 * they are supplied by the program build - confirm.
 */
#define Pad(dataformat, in_x, in_y, out_x, out_y) \
  __kernel void Pad_##dataformat(__read_only image2d_t input, __write_only image2d_t output, int4 input_shape, \
                                 int4 output_shape, int2 pad, float constant_value) { \
    const int oh = get_global_id(0); \
    const int ow = get_global_id(1); \
    const int co_slice = get_global_id(2); \
    const int OH = output_shape.y; \
    const int OW = output_shape.z; \
    const int CO_SLICES = output_shape.w; \
    if (oh >= OH || ow >= OW || co_slice >= CO_SLICES) { \
      return; \
    } \
    \
    const int IH = input_shape.y; \
    const int IW = input_shape.z; \
    const int CI_SLICES = input_shape.w; \
    const int ih = oh - pad.x; \
    const int iw = ow - pad.y; \
    const bool inside_input = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); \
    \
    FLT4 result = (FLT4)(constant_value); \
    if (inside_input) { \
      result = READ_IMAGE(input, smp_zero, (int2)(in_x, in_y)); \
    } \
    WRITE_IMAGE(output, (int2)(out_x, out_y), result); \
  }

/* NHWC4: x = column * slices + slice, y = row. */
Pad(NHWC4, iw *CI_SLICES + co_slice, ih, ow *CO_SLICES + co_slice, oh);
/* NC4HW4: x = column, y = slice * height + row. */
Pad(NC4HW4, iw, co_slice *IH + ih, ow, co_slice *OH + oh);
| @@ -0,0 +1,157 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include <set> | |||||
| #include <algorithm> | |||||
| #include "src/common/utils.h" | |||||
| #include "src/runtime/kernel/opencl/kernel/pad.h" | |||||
| #include "src/runtime/kernel/opencl/utils.h" | |||||
| #include "src/kernel_registry.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/runtime/kernel/opencl/cl/pad.cl.inc" | |||||
| using mindspore::kernel::KERNEL_ARCH::kGPU; | |||||
| using mindspore::lite::KernelRegistrar; | |||||
| using mindspore::lite::RET_ERROR; | |||||
| using mindspore::lite::RET_OK; | |||||
| using mindspore::schema::PaddingMode_CONSTANT; | |||||
| using mindspore::schema::PrimitiveType_Pad; | |||||
| using mindspore::schema::Format::Format_NC4HW4; | |||||
| using mindspore::schema::Format::Format_NCHW; | |||||
| using mindspore::schema::Format::Format_NHWC; | |||||
| using mindspore::schema::Format::Format_NHWC4; | |||||
| namespace mindspore::kernel { | |||||
| int PadOpenCLKernel::Init() { | |||||
| auto param = reinterpret_cast<PadParameter *>(op_parameter_); | |||||
| std::set<std::string> build_options; | |||||
| if (op_format_ != Format_NHWC4 && op_format_ != Format_NC4HW4) { | |||||
| MS_LOG(ERROR) << "op_format_ " << op_format_ << " not support!"; | |||||
| } | |||||
| if (in_tensors_.empty()) { | |||||
| MS_LOG(ERROR) << "PadOpenCLKernel in_tensors is empty"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (out_tensors_.empty()) { | |||||
| MS_LOG(ERROR) << "PadOpenCLKernel out_tensors is empty"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (param->paddings_[0] || param->paddings_[1] || param->paddings_[6] || param->paddings_[7]) { | |||||
| MS_LOG(ERROR) << "PadOpenCLKernel not support pad at Batch/Channel axis"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (param->pad_mode_ != PaddingMode_CONSTANT) { | |||||
| MS_LOG(ERROR) << "PadOpenCLKernel only support CONSTANT MODE"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto input_tensor = in_tensors_[0]; | |||||
| auto output_tensor = out_tensors_[0]; | |||||
| in_ori_format_ = input_tensor->GetFormat(); | |||||
| out_ori_format_ = output_tensor->GetFormat(); | |||||
| input_tensor->SetFormat(op_format_); | |||||
| output_tensor->SetFormat(op_format_); | |||||
| CI_ = input_tensor->Channel(); | |||||
| IH_ = input_tensor->Height(); | |||||
| IW_ = input_tensor->Width(); | |||||
| CO_ = output_tensor->Channel(); | |||||
| OH_ = output_tensor->Height(); | |||||
| OW_ = output_tensor->Width(); | |||||
| CI_SLICES_ = UP_DIV(CI_, C4NUM); | |||||
| CO_SLICES_ = UP_DIV(CO_, C4NUM); | |||||
| const std::string source = pad_source; | |||||
| const std::string kernel_name = op_format_ == Format_NHWC4 ? "Pad_NHWC4" : "Pad_NC4HW4"; | |||||
| const std::string &program_name = kernel_name; | |||||
| ocl_runtime_->LoadSource(program_name, source); | |||||
| ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); | |||||
| MS_LOG(DEBUG) << "Pad Init Done!"; | |||||
| return RET_OK; | |||||
| } | |||||
| int PadOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) { | |||||
| size_t im_dst_x, im_dst_y; | |||||
| if (in_tensors_[0]->GetFormat() == Format_NHWC4) { | |||||
| if (OW_ * CO_SLICES_ <= MAX_IMAGE2D_SIZE) { | |||||
| { | |||||
| im_dst_x = OW_ * CO_SLICES_; | |||||
| im_dst_y = OH_; | |||||
| } | |||||
| } else { | |||||
| im_dst_x = OH_ * CO_SLICES_; | |||||
| im_dst_y = OW_; | |||||
| } | |||||
| } else { | |||||
| im_dst_y = OH_ * CO_SLICES_; | |||||
| im_dst_x = OW_; | |||||
| } | |||||
| size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT; | |||||
| img_size->clear(); | |||||
| img_size->push_back(im_dst_x); | |||||
| img_size->push_back(im_dst_y); | |||||
| img_size->push_back(img_dtype); | |||||
| return RET_OK; | |||||
| } | |||||
| int PadOpenCLKernel::Run() { | |||||
| MS_LOG(DEBUG) << this->name() << " Running!"; | |||||
| auto param = reinterpret_cast<PadParameter *>(op_parameter_); | |||||
| cl_int4 input_shape = {1, IH_, IW_, CI_SLICES_}; | |||||
| cl_int4 output_shape = {1, OH_, OW_, CO_SLICES_}; | |||||
| cl_int2 pad_top_left = {param->paddings_[2], param->paddings_[4]}; | |||||
| int arg_cn = 0; | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c(), lite::opencl::MemType::IMG); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::IMG); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, input_shape); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, output_shape); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, pad_top_left); | |||||
| ocl_runtime_->SetKernelArg(kernel_, arg_cn++, static_cast<cl_float>(param->constant_value_)); | |||||
| std::vector<size_t> global = {static_cast<size_t>(OH_), static_cast<size_t>(OW_), static_cast<size_t>(CO_SLICES_)}; | |||||
| std::vector<size_t> local = {8, 4, 1}; | |||||
| ocl_runtime_->RunKernel(kernel_, global, local, nullptr); | |||||
| return RET_OK; | |||||
| } | |||||
| kernel::LiteKernel *OpenCLPadKernelCreator(const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||||
| const mindspore::lite::PrimitiveC *primitive) { | |||||
| auto *kernel = new (std::nothrow) PadOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| MS_LOG(ERROR) << "Create OpenCL Pad kernel failed!"; | |||||
| return nullptr; | |||||
| } | |||||
| auto ret = kernel->Init(); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Init kernel failed, name: Pad"; | |||||
| delete kernel; | |||||
| return nullptr; | |||||
| } | |||||
| return kernel; | |||||
| } | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Pad, OpenCLPadKernelCreator) | |||||
| REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_Pad, OpenCLPadKernelCreator) | |||||
| } // namespace mindspore::kernel | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PAD_H_ | |||||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PAD_H_ | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "src/tensor.h" | |||||
| #include "src/runtime/kernel/opencl/opencl_kernel.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "nnacl/pad_parameter.h" | |||||
| namespace mindspore::kernel { | |||||
| class PadOpenCLKernel : public OpenCLKernel { | |||||
| public: | |||||
| explicit PadOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||||
| const std::vector<lite::Tensor *> &outputs) | |||||
| : OpenCLKernel(parameter, inputs, outputs) {} | |||||
| ~PadOpenCLKernel() override{}; | |||||
| int Init() override; | |||||
| int Run() override; | |||||
| int GetImageSize(size_t idx, std::vector<size_t> *img_size) override; | |||||
| private: | |||||
| int CI_{}; | |||||
| int IH_{}; | |||||
| int IW_{}; | |||||
| int CO_{}; | |||||
| int OH_{}; | |||||
| int OW_{}; | |||||
| int CI_SLICES_{}; | |||||
| int CO_SLICES_{}; | |||||
| cl::Kernel kernel_; | |||||
| }; | |||||
| } // namespace mindspore::kernel | |||||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_PAD_H_ | |||||
| @@ -396,7 +396,7 @@ TEST_F(TestConcatOpenCLfp32, ConcatFp32_3input_dim4_axis1) { | |||||
| TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) { | TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) { | ||||
| MS_LOG(INFO) << " begin test "; | MS_LOG(INFO) << " begin test "; | ||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | |||||
| ocl_runtime->SetFp16Enable(true); | ocl_runtime->SetFp16Enable(true); | ||||
| ocl_runtime->Init(); | ocl_runtime->Init(); | ||||
| auto allocator = ocl_runtime->GetAllocator(); | auto allocator = ocl_runtime->GetAllocator(); | ||||
| @@ -523,7 +523,6 @@ TEST_F(TestConcatOpenCLfp16, ConcatFp16_6input_dim4_axis1) { | |||||
| sub_graph->Run(); | sub_graph->Run(); | ||||
| auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->MutableData()); | auto *output_data_gpu = reinterpret_cast<float16_t *>(output_tensor->MutableData()); | ||||
| CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); | CompareOutputData1(output_data_gpu, correctOutput, output_tensor->ElementsNum(), 0.000001); | ||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| for (auto tensor : inputs) { | for (auto tensor : inputs) { | ||||
| tensor->SetData(nullptr); | tensor->SetData(nullptr); | ||||
| delete tensor; | delete tensor; | ||||
| @@ -0,0 +1,168 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
#include <cstring>
#include <iostream>
#include <memory>
#include <new>
#include <vector>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "include/errorcode.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/pad.h"
#include "nnacl/pack.h"
| using mindspore::kernel::LiteKernel; | |||||
| using mindspore::kernel::PadOpenCLKernel; | |||||
| using mindspore::kernel::SubGraphOpenCLKernel; | |||||
| using mindspore::lite::Tensor; | |||||
| using mindspore::schema::Format; | |||||
| using mindspore::schema::Format_NC4HW4; | |||||
| using mindspore::schema::Format_NHWC; | |||||
| using mindspore::schema::Format_NHWC4; | |||||
| using mindspore::schema::NodeType_ValueNode; | |||||
| using mindspore::schema::PaddingMode; | |||||
| using mindspore::schema::PaddingMode_CONSTANT; | |||||
| using mindspore::schema::PaddingMode_REFLECT; | |||||
| using mindspore::schema::PaddingMode_SYMMETRIC; | |||||
| namespace mindspore { | |||||
| class TestPadOpenCL : public mindspore::CommonTest {}; | |||||
| void TEST_MAIN(PadParameter *param, Format input_format, Format output_format, Format op_format, const TypeId data_type, | |||||
| const std::vector<int> &input_shape, const std::vector<int> &output_shape, const float *input_data, | |||||
| const float *expect_data) { | |||||
| auto ocl_runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper(); | |||||
| auto ocl_runtime = ocl_runtime_wrapper.GetInstance(); | |||||
| ocl_runtime->Init(); | |||||
| ocl_runtime->SetFp16Enable(data_type == kNumberTypeFloat16); | |||||
| auto allocator = ocl_runtime->GetAllocator(); | |||||
| MS_LOG(DEBUG) << "create Tensors"; | |||||
| auto input = Tensor(kNumberTypeFloat32, input_shape, input_format, lite::TensorCategory(NodeType_ValueNode)); | |||||
| auto output = Tensor(kNumberTypeFloat32, output_shape, output_format, lite::TensorCategory(NodeType_ValueNode)); | |||||
| MS_LOG(DEBUG) << "create OpenCL Kernel"; | |||||
| std::vector<lite::Tensor *> inputs{&input}; | |||||
| std::vector<lite::Tensor *> outputs{&output}; | |||||
| auto kernel = std::make_unique<PadOpenCLKernel>(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| if (kernel == nullptr) { | |||||
| return; | |||||
| } | |||||
| kernel->SetFormatType(op_format); | |||||
| kernel->Init(); | |||||
| MS_LOG(DEBUG) << "create SubGraph"; | |||||
| std::vector<kernel::LiteKernel *> kernels{kernel.release()}; | |||||
| auto sub_graph = new (std::nothrow) SubGraphOpenCLKernel({&input}, {&output}, kernels, kernels, kernels); | |||||
| input.MallocData(allocator); | |||||
| sub_graph->Init(); | |||||
| memcpy(input.data_c(), input_data, input.Size()); | |||||
| sub_graph->Run(); | |||||
| if (lite::CompareOutputData(reinterpret_cast<float *>(output.data_c()), output.ElementsNum(), | |||||
| const_cast<float *>(expect_data), output.ElementsNum())) { | |||||
| FAIL(); | |||||
| } else { | |||||
| std::cout << "COMPARE SUCCESS!\n"; | |||||
| } | |||||
| MS_LOG(DEBUG) << "release resources"; | |||||
| input.SetData(nullptr); | |||||
| output.SetData(nullptr); | |||||
| delete sub_graph; | |||||
| } | |||||
| TEST_F(TestPadOpenCL, TestPad3) { | |||||
| auto param = static_cast<PadParameter *>(malloc(sizeof(PadParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "PadParameter create error."; | |||||
| return; | |||||
| } | |||||
| param->pad_mode_ = PaddingMode_CONSTANT; | |||||
| param->constant_value_ = 0.0f; | |||||
| param->padding_length = MAX_PAD_SIZE; | |||||
| int paddings[MAX_PAD_SIZE] = {0, 0, 3, 3, 3, 3, 0, 0}; | |||||
| memcpy(param->paddings_, paddings, sizeof(paddings)); | |||||
| float input_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, | |||||
| 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, | |||||
| 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, | |||||
| 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0}; | |||||
| float expect_data[300] = { | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, | |||||
| 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 36.0, | |||||
| 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; | |||||
| TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NHWC4, kNumberTypeFloat32, {1, 4, 4, 3}, {1, 10, 10, 3}, input_data, | |||||
| expect_data); | |||||
| TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NC4HW4, kNumberTypeFloat32, {1, 4, 4, 3}, {1, 10, 10, 3}, | |||||
| input_data, expect_data); | |||||
| TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NHWC4, kNumberTypeFloat16, {1, 4, 4, 3}, {1, 10, 10, 3}, input_data, | |||||
| expect_data); | |||||
| TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NC4HW4, kNumberTypeFloat16, {1, 4, 4, 3}, {1, 10, 10, 3}, | |||||
| input_data, expect_data); | |||||
| } | |||||
| TEST_F(TestPadOpenCL, TestPad4) { | |||||
| auto param = static_cast<PadParameter *>(malloc(sizeof(PadParameter))); | |||||
| if (param == nullptr) { | |||||
| MS_LOG(ERROR) << "PadParameter create error."; | |||||
| return; | |||||
| } | |||||
| param->pad_mode_ = PaddingMode_CONSTANT; | |||||
| param->constant_value_ = 1.0f; | |||||
| param->padding_length = MAX_PAD_SIZE; | |||||
| int paddings[MAX_PAD_SIZE] = {0, 0, 3, 3, 3, 3, 0, 0}; | |||||
| memcpy(param->paddings_, paddings, sizeof(paddings)); | |||||
| float input_data[48] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, | |||||
| 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, | |||||
| 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, | |||||
| 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0}; | |||||
| float expect_data[300] = { | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 12.0, 13.0, 14.0, 15.0, | |||||
| 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 36.0, | |||||
| 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, | |||||
| 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; | |||||
| TEST_MAIN(param, Format_NHWC, Format_NHWC, Format_NHWC4, kNumberTypeFloat32, {1, 4, 4, 3}, {1, 10, 10, 3}, input_data, | |||||
| expect_data); | |||||
| } | |||||
| } // namespace mindspore | |||||