From: @yeyunpeng2020 Reviewed-by: @ddwsky,@HilbertDavid Signed-off-by: @ddwskypull/15440/MERGE
| @@ -79,4 +79,5 @@ IMG_to_BUF(float16, float32, half, float, read_imageh); | |||
| IMG_to_BUF(float16, float16, half, half, read_imageh); | |||
| IMG_to_BUF(int32, int32, int, int, read_imagei); | |||
| IMG_to_BUF(uint32, uint32, int, int, read_imagei); | |||
| IMG_to_BUF(int32, float32, int, float, read_imagei); | |||
| IMG_to_BUF(int8, int8, char, char, read_imagei); | |||
| @@ -40,7 +40,8 @@ int ArgMinMaxOpenCLKernel::CheckSpecs() { | |||
| } | |||
| if ((in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16) || | |||
| (out_tensors_[0]->data_type() != kNumberTypeFloat32 && out_tensors_[0]->data_type() != kNumberTypeFloat16)) { | |||
| MS_LOG(ERROR) << "Unsupported data type " << in_tensors_[0]->data_type(); | |||
| MS_LOG(ERROR) << "Unsupported input/output data type. input data type is " << in_tensors_[0]->data_type() | |||
| << " output data type is " << out_tensors_[0]->data_type(); | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_[0]->shape().size() > 4 && in_tensors_[0]->shape().size() == 0) { | |||
| @@ -35,7 +35,7 @@ int BatchNormOpenCLKernel::CheckSpecs() { | |||
| MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (in_tensors_.at(0)->shape().size() == 4) { | |||
| if (in_tensors_.at(0)->shape().size() != 4) { | |||
| MS_LOG(ERROR) << "The dim of in_tensors->shape must be 4 but your dim is : " << in_tensors_.at(0)->shape().size(); | |||
| return RET_ERROR; | |||
| } | |||
| @@ -37,7 +37,7 @@ int FillOpenCLKernel::RunFill() { | |||
| auto param = reinterpret_cast<FillParameter *>(this->op_parameter_); | |||
| default_ = param->num_dims_; | |||
| ImageSize img_size; | |||
| cl_float4 fill_value = {}; | |||
| cl_int4 fill_value = {}; | |||
| fill_value.s[0] = fill_value.s[1] = fill_value.s[2] = fill_value.s[3] = default_; | |||
| auto src_data = out_tensors_[0]->data_c(); | |||
| allocator_->GetImageSize(src_data, &img_size); | |||
| @@ -51,11 +51,11 @@ int FillOpenCLKernel::RunFill() { | |||
| int FillOpenCLKernel::RunShape() { | |||
| auto allocator_ = ocl_runtime_->GetAllocator(); | |||
| auto src_data = out_tensors_[0]->data_c(); | |||
| cl_float4 fill_value = {default_, default_, default_, default_}; | |||
| cl_int4 fill_value = {default_, default_, default_, default_}; | |||
| auto tensor_shape = in_tensors_[0]->shape(); | |||
| void *tensor_shape_data = tensor_shape.data(); | |||
| for (int i = 0; i < tensor_shape.size(); ++i) { | |||
| fill_value.s[i] = reinterpret_cast<float *>(tensor_shape_data)[i]; | |||
| fill_value.s[i] = reinterpret_cast<int *>(tensor_shape_data)[i]; | |||
| } | |||
| auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0}; | |||
| auto region = cl::array<cl::size_type, 3U>{1, 1, 1}; | |||
| @@ -39,7 +39,7 @@ class FillOpenCLKernel : public OpenCLKernel { | |||
| private: | |||
| int RunFill(); | |||
| int RunShape(); | |||
| float default_{0.0f}; | |||
| int default_{0}; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| @@ -193,6 +193,17 @@ int GatherOpenCLKernel::InitWeights() { | |||
| return RET_OK; | |||
| } | |||
| int GatherOpenCLKernel::PreProcess() { | |||
| if (!op_parameter_->infer_flag_) { | |||
| auto indices_tensor = in_tensors_[1]; | |||
| if (!indices_tensor->IsConst()) { | |||
| ocl_runtime_->SyncCommandQueue(); | |||
| indices_tensor->MutableData(); | |||
| } | |||
| } | |||
| return OpenCLKernel::PreProcess(); | |||
| } | |||
| int GatherOpenCLKernel::Run() { | |||
| MS_LOG(DEBUG) << this->name() << " Running! "; | |||
| if (intensor1_is_tensor) { | |||
| @@ -32,6 +32,7 @@ class GatherOpenCLKernel : public OpenCLKernel { | |||
| int Run() override; | |||
| int InitWeights() override; | |||
| int Prepare() override; | |||
| int PreProcess() override; | |||
| int CheckSpecs() override; | |||
| void SetConstArgs() override; | |||
| @@ -48,8 +48,6 @@ int StackOpenCLKernel::RunAxis0() { | |||
| return RET_OK; | |||
| } | |||
| int StackOpenCLKernel::ReSize() { return RET_OK; } | |||
| void StackGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) { | |||
| const int max_divider = 8; | |||
| const int max_x = 4, max_y = 8; | |||
| @@ -33,8 +33,6 @@ class StackOpenCLKernel : public OpenCLKernel { | |||
| void SetConstArgs() override; | |||
| void SetGlobalLocal() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| private: | |||
| @@ -0,0 +1,72 @@ | |||
| #!/bin/bash | |||
| basepath=$(pwd) | |||
| echo ${basepath} | |||
| # Example:sh run_benchmark_nets.sh -r /home/temp_test -d "8KE5T19620002408" | |||
| while getopts "r:d:" opt; do | |||
| case ${opt} in | |||
| r) | |||
| release_path=${OPTARG} | |||
| echo "release_path is ${OPTARG}" | |||
| ;; | |||
| d) | |||
| device_id=${OPTARG} | |||
| echo "device_id is ${OPTARG}" | |||
| ;; | |||
| ?) | |||
| echo "unknown para" | |||
| exit 1;; | |||
| esac | |||
| done | |||
| ut_test_path=${basepath}/ut_test | |||
| rm -rf ${ut_test_path} | |||
| mkdir -p ${ut_test_path} | |||
| run_ut_result_file=${basepath}/run_benchmark_result.txt | |||
| echo ' ' > ${run_ut_result_file} | |||
| run_gpu_ut_log_file=${basepath}/run_gpu_ut_log.txt | |||
| echo 'run gpu ut logs: ' > ${run_gpu_ut_log_file} | |||
| ut_gpu_config=${basepath}/ut_gpu.cfg | |||
| function Run_gpu_ut() { | |||
| cd ${release_path} || exit 1 | |||
| cp -a ${release_path}/lite-test ${ut_test_path}/lite-test || exit 1 | |||
| cp -r ${basepath}/ut/src/runtime/kernel/opencl/test_data ${ut_test_path} || exit 1 | |||
| # adb push all needed files to the phone | |||
| adb -s ${device_id} push ${ut_test_path} /data/local/tmp/ > adb_push_log.txt | |||
| # run adb ,run session ,check the result: | |||
| echo 'rm -rf /data/local/tmp/ut_test' > adb_cmd.txt | |||
| echo 'cd /data/local/tmp/ut_test' > adb_cmd.txt | |||
| echo 'cp /data/local/tmp/libc++_shared.so ./' >> adb_cmd.txt | |||
| echo 'cp /data/local/tmp/libgtest.so ./' >> adb_cmd.txt | |||
| echo 'chmod 777 lite-test' >> adb_cmd.txt | |||
| adb -s ${device_id} shell < adb_cmd.txt | |||
| # Run npu converted models: | |||
| while read line; do | |||
| echo 'cd /data/local/tmp/ut_test' > adb_run_cmd.txt | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/ut_test;./lite-test --gtest_filter='${line} >> "${run_gpu_ut_log_file}" | |||
| echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/ut_test;./lite-test --gtest_filter='${line} >> adb_run_cmd.txt | |||
| adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_gpu_ut_log_file}" | |||
| if [ $? = 0 ]; then | |||
| run_result='arm64_gpu_ut: '${line}' pass'; echo ${run_result} >> ${run_ut_result_file} | |||
| else | |||
| run_result='arm64_gpu_ut: '${line}' failed'; echo ${run_result} >> ${run_ut_result_file}; return 1 | |||
| fi | |||
| done < ${ut_gpu_config} | |||
| } | |||
| Run_gpu_ut | |||
| Run_gpu_ut_status=$? | |||
| if [[ $Run_gpu_ut_status == 1 ]]; then | |||
| exit 1 | |||
| fi | |||
| exit 0 | |||
| @@ -31,6 +31,7 @@ OpParameter *CreateParameter(schema::PrimitiveType type, int axis, int topk, boo | |||
| param->axis_type_ = axis_type; | |||
| param->out_value_ = out_value; | |||
| param->keep_dims_ = keep_dims; | |||
| reinterpret_cast<OpParameter *>(param)->infer_flag_ = true; | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| } // namespace | |||
| @@ -157,44 +157,6 @@ TEST_F(TestOpenCL_Arithmetic, FloorMod) { | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_Arithmetic, FloorModFile) { | |||
| std::vector<int> input0_shape = {1, 3, 4, 5}; | |||
| std::vector<int> input1_shape = {1, 3, 4, 5}; | |||
| std::vector<int> output_shape = {1, 3, 4, 5}; | |||
| size_t input1_size, input2_size, output_size; | |||
| std::string input1Ppath = "./test_data/FloodModfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/FloodModfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/FloodModfp32_output.bin"; | |||
| auto input0_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto input1_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| for (auto fp16_enable : {true}) { | |||
| auto *param = CreateParameter(schema::PrimitiveType_FloorMod, input0_shape, input1_shape); | |||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data}, | |||
| param, fp16_enable, fp16_enable ? 1e-2 : 1e-7); | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_Arithmetic, SquaredDifference) { | |||
| std::vector<int> input0_shape = {1, 512, 1, 5}; | |||
| std::vector<int> input1_shape = {1, 1, 1, 5}; | |||
| std::vector<int> output_shape = {1, 512, 1, 5}; | |||
| size_t input1_size, input2_size, output_size; | |||
| std::string input1Ppath = "./test_data/SquaredDifferencefp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/SquaredDifferencefp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/SquaredDifferencefp32_output.bin"; | |||
| auto input0_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto input1_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| for (auto fp16_enable : {true}) { | |||
| auto *param = CreateParameter(schema::PrimitiveType_SquaredDifference, input0_shape, input1_shape); | |||
| TestMain({{input0_shape, input0_data, VAR}, {input1_shape, input1_data, CONST_TENSOR}}, {output_shape, output_data}, | |||
| param, fp16_enable, fp16_enable ? 1e-2 : 1e-9); | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_Arithmetic, ElementwiseDiv) { | |||
| std::vector<int> input0_shape = {1, 2, 2, 3}; | |||
| std::vector<int> input1_shape = {1, 2, 2, 3}; | |||
| @@ -25,7 +25,7 @@ namespace { | |||
| // PrimitiveType_BatchToSpaceND: src/ops/populate/batch_to_space_populate.cc | |||
| OpParameter *CreateParameter(int block_shape[], int crops[], const std::vector<int> &input_shape, | |||
| std::vector<int> *output_shape) { | |||
| auto *param = test::CreateParameter<BatchToSpaceParameter>(schema::PrimitiveType_BatchToSpaceND); | |||
| auto *param = test::CreateParameter<BatchToSpaceParameter>(schema::PrimitiveType_BatchToSpace); | |||
| memcpy(param->block_shape_, block_shape, sizeof(param->block_shape_)); | |||
| memcpy(param->crops_, crops, sizeof(param->crops_)); | |||
| *output_shape = {input_shape[0] / param->block_shape_[0] / param->block_shape_[1], | |||
| @@ -38,7 +38,7 @@ void TestMain(const std::vector<ArgsTuple> &input_infos, const std::vector<ArgsT | |||
| TestMain(input_infos_new, output_info, op_parameter, fp16_enable, atol, rtol, print_data); | |||
| } | |||
| void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOut> &output_info, | |||
| void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOutWithDType> &output_info, | |||
| OpParameter *op_parameter, bool fp16_enable, float atol, float rtol, bool print_data) { | |||
| auto primitive_type = static_cast<schema::PrimitiveType>(op_parameter->type_); | |||
| #ifdef ENABLE_V0 | |||
| @@ -71,7 +71,7 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec | |||
| } | |||
| for (auto outout_info : output_info) { | |||
| const std::vector<int> &output_shape = std::get<0>(outout_info); | |||
| out_tensors.emplace_back(std::make_shared<Tensor>(kNumberTypeFloat32, output_shape, Format_NHWC, VAR)); | |||
| out_tensors.emplace_back(std::make_shared<Tensor>(std::get<2>(outout_info), output_shape, Format_NHWC, VAR)); | |||
| } | |||
| // secondly, init weight Tensor's data | |||
| std::vector<Tensor *> kernel_inputs; | |||
| @@ -180,6 +180,16 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec | |||
| } | |||
| delete sub_graph; | |||
| } | |||
| void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOut> &output_info, | |||
| OpParameter *op_parameter, bool fp16_enable, float atol, float rtol, bool print_data) { | |||
| std::vector<ArgsTupleOutWithDType> output_info_new; | |||
| auto transform_fun = [](ArgsTupleOut in) -> ArgsTupleOutWithDType { | |||
| return ArgsTupleOutWithDType(std::get<0>(in), std::get<1>(in), kNumberTypeFloat32); | |||
| }; | |||
| std::transform(output_info.begin(), output_info.end(), std::back_inserter(output_info_new), transform_fun); | |||
| TestMain(input_infos, output_info_new, op_parameter, fp16_enable, atol, rtol, print_data); | |||
| } | |||
| // single-output | |||
| void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std::vector<int>, float *> output_info, | |||
| @@ -32,6 +32,7 @@ | |||
| using Tensor = mindspore::lite::Tensor; | |||
| using ArgsTuple = std::tuple<std::vector<int>, void *, Tensor::Category>; | |||
| using ArgsTupleOut = std::tuple<std::vector<int>, void *>; | |||
| using ArgsTupleOutWithDType = std::tuple<std::vector<int>, void *, mindspore::TypeId>; | |||
| using ArgsTupleWithDtype = std::tuple<std::vector<int>, void *, Tensor::Category, mindspore::TypeId>; | |||
| constexpr Tensor::Category VAR = Tensor::VAR; | |||
| constexpr Tensor::Category CONST_TENSOR = Tensor::Category::CONST_TENSOR; | |||
| @@ -94,6 +95,10 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec | |||
| OpParameter *op_parameter, bool fp16_enable = false, float atol = 1e-9, float rtol = 1e-9, | |||
| bool print_output = false); | |||
| void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vector<ArgsTupleOutWithDType> &output_info, | |||
| OpParameter *op_parameter, bool fp16_enable = false, float atol = 1e-9, float rtol = 1e-9, | |||
| bool print_output = false); | |||
| void TestMain(const std::vector<ArgsTuple> &input_infos, const std::vector<ArgsTupleOut> &output_info, | |||
| OpParameter *op_parameter, bool fp16_enable = false, float atol = 1e-9, float rtol = 1e-9, | |||
| bool print_output = false); | |||
| @@ -127,6 +127,7 @@ TEST_F(TestOpenCL_Conv2D, test3_batch2) { | |||
| TestMain_Conv2D(attr, input_data, weight_data, bias_data, output_data, ActType_No, true, 1e-6f); | |||
| } | |||
| // Check and optimize | |||
| TEST_F(TestOpenCL_Conv2D, test4) { | |||
| std::vector<std::tuple<std::string, std::string, std::vector<float>, std::vector<float>, std::vector<float>, | |||
| std::vector<float>, ActType>> | |||
| @@ -19,7 +19,7 @@ | |||
| namespace mindspore::lite::opencl::test { | |||
| class TestOpenCL_Conv2dTranspose : public CommonTest {}; | |||
| // Check and optimize | |||
| namespace { | |||
| // PrimitiveType_DeConv2D: src/ops/populate/deconv2d_populate.cc | |||
| OpParameter *CreateParameter(int n, int h, int w, int ci, int co, int kh, int kw, std::vector<int> pad, int oh, int ow, | |||
| @@ -21,6 +21,7 @@ namespace mindspore::lite::opencl::test { | |||
| class TestOpenCL_DepthwiseConv2d : public CommonTest {}; | |||
| namespace { | |||
| // Check and optimize | |||
| // PrimitiveType_DepthwiseConv2D: src/ops/populate/depthwise_conv2d_populate.cc | |||
| OpParameter *CreateParameter(int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_u, int pad_d, int pad_l, | |||
| int pad_r, int dilation_h, int dilation_w, ActType act_type, int input_channel) { | |||
| @@ -41,10 +41,10 @@ TEST_F(TestOpenCL_LayerNorm, test1) { | |||
| std::vector<int> beta_shape = {1, 1, 1, 5}; | |||
| std::vector<int> output_shape = {2, 3, 4, 5}; | |||
| size_t input_size, gamma_size, beta_size, output_size; | |||
| std::string inputPpath = "./test_data/layernormfp32_input.bin"; | |||
| std::string gammaPpath = "./test_data/gammafp32_input.bin"; | |||
| std::string betaPpath = "./test_data/betafp32_input.bin"; | |||
| std::string correctOutputPath = "./test_data/layernormfp32_output.bin"; | |||
| std::string inputPpath = "./test_data/layer_norm/test1/layernormfp32_input.bin"; | |||
| std::string gammaPpath = "./test_data/layer_norm/test1/gammafp32_input.bin"; | |||
| std::string betaPpath = "./test_data/layer_norm/test1/betafp32_input.bin"; | |||
| std::string correctOutputPath = "./test_data/layer_norm/test1/layernormfp32_output.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(inputPpath.c_str(), &input_size)); | |||
| auto gamma_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(gammaPpath.c_str(), &gamma_size)); | |||
| auto beta_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(betaPpath.c_str(), &beta_size)); | |||
| @@ -32,25 +32,6 @@ OpParameter *CreateParameter(bool a_transpose = false, bool b_transpose = true) | |||
| } | |||
| } // namespace | |||
| TEST_F(TestOpenCL_MatMul, 2Dfile) { | |||
| std::vector<int> input_shape = {64, 64}; | |||
| std::vector<int> output_shape = {64, 64}; | |||
| std::vector<int> weight_shape = {64, 64}; | |||
| size_t input1_size, input2_size, output_size; | |||
| std::string input1Ppath = "./test_data/matmulfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/matmulfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/matmulfp32_output.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| for (auto fp16_enable : {false}) { | |||
| auto *param = CreateParameter(false, false); | |||
| TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, CONST_TENSOR}}, {output_shape, output_data}, | |||
| param, fp16_enable, fp16_enable ? 1e-3 : 1e-3); | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_MatMul, 2D) { | |||
| int ci = 5; | |||
| int co = 3; | |||
| @@ -22,18 +22,10 @@ class TestOpenCL_Pad : public CommonTest {}; | |||
| namespace { | |||
| // PrimitiveType_Pad: src/ops/populate/pad_populate.cc | |||
| OpParameter *CreateParameter(const std::vector<int> &paddings, float constant_value) { | |||
| OpParameter *CreateParameter(float constant_value) { | |||
| auto *param = test::CreateParameter<PadParameter>(schema::PrimitiveType_PadFusion); | |||
| param->pad_mode_ = schema::PaddingMode_CONSTANT; | |||
| param->constant_value_ = constant_value; | |||
| param->padding_length = MAX_PAD_SIZE; | |||
| int size = paddings.size(); | |||
| for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) { | |||
| param->paddings_[i] = 0; | |||
| } | |||
| for (size_t i = 0; i < size; i++) { | |||
| param->paddings_[MAX_PAD_SIZE - size + i] = paddings[i]; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| } // namespace | |||
| @@ -42,8 +34,10 @@ TEST_F(TestOpenCL_Pad, 1D) { | |||
| float input_data[] = {1, 1, 1, 1}; | |||
| float output_data[] = {2, 2, 2, 1, 1, 1, 1, 2, 2}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({3, 2}, 2); | |||
| TestMain({{{4}, input_data, VAR}}, {{9}, output_data}, param, fp16_enable); | |||
| auto *param = CreateParameter(2); | |||
| int padding[] = {3, 2}; | |||
| TestMain({{{4}, input_data, VAR, kNumberTypeFloat32}, {{1, 2}, padding, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{9}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -52,8 +46,10 @@ TEST_F(TestOpenCL_Pad, 2D) { | |||
| float output_data[] = {10, 10, 10, 10, 10, 10, 10, 10, 10, 1, 1, 1, 1, 1, 10, 10, | |||
| 10, 2, 2, 2, 2, 2, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({1, 1, 1, 2}, 10); | |||
| TestMain({{{2, 5}, input_data, VAR}}, {{4, 8}, output_data}, param, fp16_enable); | |||
| int padding[] = {1, 1, 1, 2}; | |||
| auto *param = CreateParameter(10); | |||
| TestMain({{{2, 5}, input_data, VAR, kNumberTypeFloat32}, {{2, 2}, padding, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{4, 8}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -73,8 +69,10 @@ TEST_F(TestOpenCL_Pad, 4D) { | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({0, 0, 3, 3, 3, 3, 0, 0}, 0); | |||
| TestMain({{{1, 4, 4, 3}, input_data, VAR}}, {{1, 10, 10, 3}, output_data}, param, fp16_enable); | |||
| auto *param = CreateParameter(0); | |||
| int padding[] = {0, 0, 3, 3, 3, 3, 0, 0}; | |||
| TestMain({{{1, 4, 4, 3}, input_data, VAR, kNumberTypeFloat32}, {{4, 2}, padding, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 10, 10, 3}, output_data}, param, fp16_enable); | |||
| } | |||
| float output_data1[] = { | |||
| @@ -89,8 +87,10 @@ TEST_F(TestOpenCL_Pad, 4D) { | |||
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |||
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({0, 0, 3, 3, 3, 3, 0, 0}, 1); | |||
| TestMain({{{1, 4, 4, 3}, input_data, VAR}}, {{1, 10, 10, 3}, output_data1}, param, fp16_enable); | |||
| auto *param = CreateParameter(1); | |||
| int padding[] = {0, 0, 3, 3, 3, 3, 0, 0}; | |||
| TestMain({{{1, 4, 4, 3}, input_data, VAR, kNumberTypeFloat32}, {{4, 2}, padding, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 10, 10, 3}, output_data1}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -224,8 +224,10 @@ TEST_F(TestOpenCL_Pad, test0) { | |||
| auto constant_value = std::get<6>(case_); | |||
| std::cout << name << std::endl; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(paddings, constant_value); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| auto *param = CreateParameter(constant_value); | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(paddings.size() / 2), 2}, paddings.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| } | |||
| @@ -21,6 +21,7 @@ namespace mindspore::lite::opencl::test { | |||
| class TestOpenCL_PRrelu : public CommonTest {}; | |||
| namespace { | |||
| // Check and optimize | |||
| // PrimitiveType_PReLU: src/ops/populate/p_relu_populate.cc | |||
| OpParameter *CreateParameter() { | |||
| auto *param = test::CreateParameter<PReluParameter>(schema::PrimitiveType_PReLUFusion); | |||
| @@ -30,10 +30,11 @@ TEST_F(TestOpenCL_Shape, test0) { | |||
| std::vector<int> input_shape = {2, 4}; | |||
| std::vector<int> output_shape = {2}; | |||
| float input_data[] = {-0.4045, -0.0924, -0.617, -0.10114, -0.9893, 0.3342, 2.445, -2.182}; | |||
| float output_data[] = {2, 4}; | |||
| int output_data[] = {2, 4}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}}, {{output_shape, output_data, kNumberTypeInt32}}, | |||
| param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -22,12 +22,10 @@ class TestOpenCL_Slice : public CommonTest {}; | |||
| namespace { | |||
| // PrimitiveType_Slice: src/ops/populate/slice_populate.cc | |||
| OpParameter *CreateParameter(const std::vector<int> &begin, const std::vector<int> &size) { | |||
| OpParameter *CreateParameter(const std::vector<int> &axis) { | |||
| auto *param = test::CreateParameter<SliceParameter>(schema::PrimitiveType_SliceFusion); | |||
| param->param_length_ = begin.size(); | |||
| for (int i = 0; i < begin.size(); ++i) { | |||
| param->begin_[i] = begin[i]; | |||
| param->size_[i] = size[i]; | |||
| for (int i = 0; i < axis.size(); ++i) { | |||
| param->axis_[i] = axis[i]; | |||
| } | |||
| return reinterpret_cast<OpParameter *>(param); | |||
| } | |||
| @@ -42,10 +40,16 @@ TEST_F(TestOpenCL_Slice, 4D) { | |||
| float output_data[] = {-0.9135602, -1.4002057, 1.1080881, 0.40712625, -0.28128958, -1.2808367, 0.1470597, | |||
| 0.03393711, -0.33282498, -1.0433807, 0.28965706, 0.5343769, 0.75480366, -1.9328151, | |||
| -0.48714373, -0.14000037, -0.080552, 0.95056856, -0.06886655, 0.5316237}; | |||
| auto param = CreateParameter({0, 0, 0, 2}, {1, 2, 2, 5}); | |||
| TestMain({{{1, 2, 2, 8}, input_data, VAR}}, {{1, 2, 2, 5}, output_data}, param, false); | |||
| auto param = CreateParameter({0, 1, 2, 3}); | |||
| std::vector<int> begin = {0, 0, 0, 2}; | |||
| std::vector<int> size = {1, 2, 2, 5}; | |||
| TestMain({{{1, 2, 2, 8}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(size.size())}, size.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 2, 2, 5}, output_data}, param, false); | |||
| } | |||
| // Check and optimize(fp16) | |||
| TEST_F(TestOpenCL_Slice, test0) { | |||
| std::vector<std::tuple<std::string, std::vector<int>, std::vector<int>, std::vector<float>, std::vector<float>, | |||
| std::vector<int>, std::vector<int>>> | |||
| @@ -148,11 +152,18 @@ TEST_F(TestOpenCL_Slice, test0) { | |||
| auto &size = std::get<6>(case_); | |||
| std::cout << name << std::endl; | |||
| auto *param = CreateParameter(begin, size); | |||
| TestMain({{input_shape, input_data.data(), VAR}}, {output_shape, output_data.data()}, param, false); | |||
| param = CreateParameter(begin, size); | |||
| TestMain({{input_shape, input_data.data(), VAR}}, {output_shape, output_data.data()}, param, true); | |||
| std::vector<int> axis(input_shape.size()); | |||
| for (int i = 0; i < input_shape.size(); ++i) { | |||
| axis[i] = i; | |||
| } | |||
| auto *param = CreateParameter(axis); | |||
| for (auto fp16_enable : {false}) { | |||
| TestMain({{input_shape, input_data.data(), VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(size.size())}, size.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data.data()}, param, fp16_enable); | |||
| } | |||
| } | |||
| } // namespace mindspore | |||
| } | |||
| } // namespace mindspore::lite::opencl::test | |||
| @@ -19,7 +19,7 @@ | |||
| namespace mindspore::lite::opencl::test { | |||
| class TestOpenCL_SparseToDense : public CommonTest {}; | |||
| // Check and optimize | |||
| namespace { | |||
| // PrimitiveType_SparseToDense: src/ops/populate/sparse_to_dense_populate.cc | |||
| OpParameter *CreateParameter() { | |||
| @@ -26,6 +26,7 @@ OpParameter *CreateParameter(int split_dim_, int num_split_, std::vector<int> sp | |||
| auto *param = test::CreateParameter<SplitParameter>(schema::PrimitiveType_Split); | |||
| param->split_dim_ = split_dim_; | |||
| param->num_split_ = num_split_; | |||
| param->split_count_ = num_split_; | |||
| param->split_sizes_ = reinterpret_cast<int *>(malloc(param->num_split_ * sizeof(int))); | |||
| for (int i = 0; i < param->num_split_; ++i) { | |||
| param->split_sizes_[i] = split_sizes_[i]; | |||
| @@ -34,6 +35,7 @@ OpParameter *CreateParameter(int split_dim_, int num_split_, std::vector<int> sp | |||
| } | |||
| } // namespace | |||
| // Check and optimize(No data file) | |||
| TEST_F(TestOpenCL_Split, input2_axis3) { | |||
| std::vector<int> input_shape = {2, 2, 2, 12}; | |||
| std::vector<int> output_shape1 = {2, 2, 2, 6}; | |||
| @@ -72,9 +72,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis1) { | |||
| std::vector<int> input_shapes[INPUT_NUM] = {{3, 4, 5}, {3, 4, 5}}; | |||
| std::vector<int> output_shape = {3, 2, 4, 5}; | |||
| size_t input1_size, input2_size, output_size; | |||
| std::string input1Ppath = "./test_data/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stackfp32_output.bin"; | |||
| std::string input1Ppath = "./test_data/stack/input2_ndim3_axis1/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stack/input2_ndim3_axis1/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stack/input2_ndim3_axis1/stackfp32_output.bin"; | |||
| auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| @@ -91,9 +91,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis2) { | |||
| std::vector<int> input_shapes[INPUT_NUM] = {{3, 4, 5}, {3, 4, 5}}; | |||
| std::vector<int> output_shape = {3, 4, 2, 5}; | |||
| size_t input1_size, input2_size, output_size; | |||
| std::string input1Ppath = "./test_data/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stackfp32_output.bin"; | |||
| std::string input1Ppath = "./test_data/stack/input2_ndim3_axis2/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stack/input2_ndim3_axis2/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stack/input2_ndim3_axis2/stackfp32_output.bin"; | |||
| auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| @@ -110,9 +110,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim2_axis2) { | |||
| std::vector<int> input_shapes[INPUT_NUM] = {{1, 96}, {1, 96}}; | |||
| std::vector<int> output_shape = {1, 96, 2}; | |||
| size_t input1_size, input2_size, output_size; | |||
| std::string input1Ppath = "./test_data/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stackfp32_output.bin"; | |||
| std::string input1Ppath = "./test_data/stack/input2_ndim2_axis2/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stack/input2_ndim2_axis2/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stack/input2_ndim2_axis2/stackfp32_output.bin"; | |||
| auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| @@ -129,9 +129,9 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis3) { | |||
| std::vector<int> input_shapes[INPUT_NUM] = {{3, 4, 6}, {3, 4, 6}}; | |||
| std::vector<int> output_shape = {3, 4, 6, 2}; | |||
| size_t input1_size, input2_size, output_size; | |||
| std::string input1Ppath = "./test_data/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stackfp32_output.bin"; | |||
| std::string input1Ppath = "./test_data/stack/input2_ndim3_axis3/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stack/input2_ndim3_axis3/stackfp32_input2.bin"; | |||
| std::string correctOutputPath = "./test_data/stack/input2_ndim3_axis3/stackfp32_output.bin"; | |||
| auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| @@ -142,44 +142,4 @@ TEST_F(TestOpenCL_Stack, input2_ndim3_axis3) { | |||
| } | |||
| } | |||
| TEST_F(TestOpenCL_Stack, input6_ndim3_axis0) { | |||
| constexpr int INPUT_NUM = 8; | |||
| int axis = 0; | |||
| std::vector<int> input_shapes[INPUT_NUM] = {{1, 17, 18}, {1, 17, 18}, {1, 17, 18}, {1, 17, 18}, | |||
| {1, 17, 18}, {1, 17, 18}, {1, 17, 18}, {1, 17, 18}}; | |||
| std::vector<int> output_shape = {8, 1, 17, 18}; | |||
| size_t input1_size, input2_size, input3_size, input4_size, input5_size, input6_size, input7_size, input8_size, | |||
| output_size; | |||
| std::string input1Ppath = "./test_data/stackfp32_input1.bin"; | |||
| std::string input2Ppath = "./test_data/stackfp32_input2.bin"; | |||
| std::string input3Ppath = "./test_data/stackfp32_input3.bin"; | |||
| std::string input4Ppath = "./test_data/stackfp32_input4.bin"; | |||
| std::string input5Ppath = "./test_data/stackfp32_input5.bin"; | |||
| std::string input6Ppath = "./test_data/stackfp32_input6.bin"; | |||
| std::string input7Ppath = "./test_data/stackfp32_input7.bin"; | |||
| std::string input8Ppath = "./test_data/stackfp32_input8.bin"; | |||
| std::string correctOutputPath = "./test_data/stackfp32_output.bin"; | |||
| auto input_data1 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input1Ppath.c_str(), &input1_size)); | |||
| auto input_data2 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input2Ppath.c_str(), &input2_size)); | |||
| auto input_data3 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input3Ppath.c_str(), &input3_size)); | |||
| auto input_data4 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input4Ppath.c_str(), &input4_size)); | |||
| auto input_data5 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input5Ppath.c_str(), &input5_size)); | |||
| auto input_data6 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input6Ppath.c_str(), &input6_size)); | |||
| auto input_data7 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input7Ppath.c_str(), &input7_size)); | |||
| auto input_data8 = reinterpret_cast<float *>(mindspore::lite::ReadFile(input8Ppath.c_str(), &input8_size)); | |||
| auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size)); | |||
| for (auto fp16_enable : {true}) { | |||
| auto *param = CreateParameter(axis); | |||
| TestMain({{input_shapes[0], input_data1, VAR}, | |||
| {input_shapes[1], input_data2, VAR}, | |||
| {input_shapes[2], input_data3, VAR}, | |||
| {input_shapes[3], input_data4, VAR}, | |||
| {input_shapes[4], input_data5, VAR}, | |||
| {input_shapes[5], input_data6, VAR}, | |||
| {input_shapes[6], input_data7, VAR}, | |||
| {input_shapes[7], input_data8, VAR}}, | |||
| {output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-3 : 1e-9); | |||
| } | |||
| } | |||
| } // namespace mindspore::lite::opencl::test | |||
| @@ -41,7 +41,14 @@ TEST_F(TestOpenCL_StridedSlice, 1D) { | |||
| float output_data[] = {3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({3}, {36}, {3}); | |||
| TestMain({{{36}, input_data, VAR}}, {{11}, output_data}, param, fp16_enable); | |||
| std::vector<int> begin = {3}; | |||
| std::vector<int> end = {36}; | |||
| std::vector<int> stride = {3}; | |||
| TestMain({{{36}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{11}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -50,8 +57,15 @@ TEST_F(TestOpenCL_StridedSlice, 2D) { | |||
| 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; | |||
| float output_data[] = {11, 14}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {1, 2}; | |||
| std::vector<int> end = {3, 8}; | |||
| std::vector<int> stride = {2, 3}; | |||
| auto *param = CreateParameter({1, 2}, {3, 8}, {2, 3}); | |||
| TestMain({{{4, 9}, input_data, VAR}}, {{1, 2}, output_data}, param, fp16_enable); | |||
| TestMain({{{4, 9}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 2}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -61,7 +75,14 @@ TEST_F(TestOpenCL_StridedSlice, 3D) { | |||
| float output_data[] = {11, 14}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({0, 1, 2}, {1, 3, 8}, {1, 2, 3}); | |||
| TestMain({{{1, 4, 9}, input_data, VAR}}, {{1, 1, 2}, output_data}, param, fp16_enable); | |||
| std::vector<int> begin = {0, 1, 2}; | |||
| std::vector<int> end = {1, 3, 8}; | |||
| std::vector<int> stride = {1, 2, 3}; | |||
| TestMain({{{1, 4, 9}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 1, 2}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -72,37 +93,79 @@ TEST_F(TestOpenCL_StridedSlice, 4D) { | |||
| float output_data0[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, | |||
| 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {0, 0, 0, 0}; | |||
| std::vector<int> end = {2, 2, 3, 3}; | |||
| std::vector<int> stride = {1, 1, 1, 1}; | |||
| auto *param = CreateParameter({0, 0, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1}); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{2, 2, 3, 3}, output_data0}, param, fp16_enable); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{2, 2, 3, 3}, output_data0}, param, fp16_enable); | |||
| } | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {0, 0, 0, 0}; | |||
| std::vector<int> end = {2, 2, 3, 3}; | |||
| std::vector<int> stride = {1, 1, 1, 1}; | |||
| auto *param = CreateParameter({0, 0, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1}); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{2, 2, 3, 3}, output_data0}, param, fp16_enable); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{2, 2, 3, 3}, output_data0}, param, fp16_enable); | |||
| } | |||
| float output_data1[] = {18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {1, 0, 0, 0}; | |||
| std::vector<int> end = {2, 2, 3, 3}; | |||
| std::vector<int> stride = {1, 1, 1, 1}; | |||
| auto *param = CreateParameter({1, 0, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1}); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 2, 3, 3}, output_data1}, param, fp16_enable); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 2, 3, 3}, output_data1}, param, fp16_enable); | |||
| } | |||
| float output_data2[] = {27, 28, 29, 30, 31, 32, 33, 34, 35}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {1, 1, 0, 0}; | |||
| std::vector<int> end = {2, 2, 3, 3}; | |||
| std::vector<int> stride = {1, 1, 1, 1}; | |||
| auto *param = CreateParameter({1, 1, 0, 0}, {2, 2, 3, 3}, {1, 1, 1, 1}); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 1, 3, 3}, output_data2}, param, fp16_enable); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 1, 3, 3}, output_data2}, param, fp16_enable); | |||
| } | |||
| float output_data3[] = {33, 34, 35}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {1, 1, 2, 0}; | |||
| std::vector<int> end = {2, 2, 3, 3}; | |||
| std::vector<int> stride = {1, 1, 1, 1}; | |||
| auto *param = CreateParameter({1, 1, 2, 0}, {2, 2, 3, 3}, {1, 1, 1, 1}); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 1, 1, 3}, output_data3}, param, fp16_enable); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 1, 1, 3}, output_data3}, param, fp16_enable); | |||
| } | |||
| float output_data4[] = {34}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {1, 1, 2, 1}; | |||
| std::vector<int> end = {2, 2, 3, 2}; | |||
| std::vector<int> stride = {1, 1, 1, 1}; | |||
| auto *param = CreateParameter({1, 1, 2, 1}, {2, 2, 3, 2}, {1, 1, 1, 1}); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{1, 1, 1, 1}, output_data4}, param, fp16_enable); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 1, 1, 1}, output_data4}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -111,8 +174,15 @@ TEST_F(TestOpenCL_StridedSlice, 4D_stride2) { | |||
| 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35}; | |||
| float output_data[] = {13, 14, 31, 32}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| std::vector<int> begin = {0, 1, 1, 1}; | |||
| std::vector<int> end = {1, 4, 3, 3}; | |||
| std::vector<int> stride = {2, 2, 2, 1}; | |||
| auto *param = CreateParameter({0, 1, 1, 1}, {1, 4, 3, 3}, {2, 2, 2, 1}); | |||
| TestMain({{{1, 4, 3, 3}, input_data, VAR}}, {{1, 2, 1, 2}, output_data}, param, fp16_enable); | |||
| TestMain({{{1, 4, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 2, 1, 2}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -122,19 +192,35 @@ TEST_F(TestOpenCL_StridedSlice, 4D_to_3D) { | |||
| float output_data[] = {18, 20, 21, 23, 27, 29, 30, 32}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({1, 0, 0, 0}, {2, 2, 2, 3}, {1, 1, 1, 2}); | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR}}, {{2, 2, 2}, output_data}, param, fp16_enable); | |||
| std::vector<int> begin = {1, 0, 0, 0}; | |||
| std::vector<int> end = {2, 2, 2, 3}; | |||
| std::vector<int> stride = {1, 1, 1, 2}; | |||
| TestMain({{{2, 2, 3, 3}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{2, 2, 2}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| // Check and optimize | |||
| TEST_F(TestOpenCL_StridedSlice, In1D_OutOfRangeBeginNegativeStride) { | |||
| float input_data[] = {1, 2, 3, 4}; | |||
| float output_data[] = {4, 3, 2}; | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({5}, {0}, {-1}); | |||
| TestMain({{{4}, input_data, VAR}}, {{3}, output_data}, param, fp16_enable); | |||
| std::vector<int> begin = {5}; | |||
| std::vector<int> end = {0}; | |||
| std::vector<int> stride = {-1}; | |||
| TestMain({{{4}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{3}, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| // Check and optimize | |||
| TEST_F(TestOpenCL_StridedSlice, test0) { | |||
| std::vector<float> values(32768); | |||
| for (int i = 0; i < values.size(); ++i) { | |||
| @@ -320,7 +406,12 @@ TEST_F(TestOpenCL_StridedSlice, test0) { | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(begin, end, stride); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| param->infer_flag_ = true; | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| } | |||
| @@ -332,8 +423,14 @@ TEST_F(TestOpenCL_StridedSlice, test1) { | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter({0, 1, 0, 1}, {1, 3, 2, 4}, {1, 1, 2, 2}); | |||
| TestMain({{{1, 3, 2, 4}, input_data, VAR}}, {{1, 2, 1, 2}, output_data}, param, fp16_enable, | |||
| fp16_enable ? 1e-2 : 1e-9); | |||
| std::vector<int> begin = {0, 1, 0, 1}; | |||
| std::vector<int> end = {1, 3, 2, 4}; | |||
| std::vector<int> stride = {1, 1, 2, 2}; | |||
| TestMain({{{1, 3, 2, 4}, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(begin.size())}, begin.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(end.size())}, end.data(), CONST_TENSOR, kNumberTypeInt32}, | |||
| {{static_cast<int>(stride.size())}, stride.data(), CONST_TENSOR, kNumberTypeInt32}}, | |||
| {{1, 2, 1, 2}, output_data}, param, fp16_enable, fp16_enable ? 1e-2 : 1e-9); | |||
| } | |||
| } | |||
| @@ -1,105 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include "src/common/log_adapter.h" | |||
| #include "common/common_test.h" | |||
| #include "mindspore/lite/src/common/file_utils.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/opencl_subgraph.h" | |||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/to_format.h" | |||
| namespace mindspore::lite::opencl::test { | |||
| class TestToFormatOpenCL : public CommonTest { | |||
| public: | |||
| TestToFormatOpenCL() {} | |||
| }; | |||
| TEST_F(TestToFormatOpenCL, ToFormatNHWC2NCHW) { | |||
| auto ocl_runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance(); | |||
| ocl_runtime->Init(); | |||
| auto allocator = ocl_runtime->GetAllocator(); | |||
| int h = 64; | |||
| int w = 1; | |||
| int c = 7360; | |||
| size_t input_size; | |||
| std::string input_path = "./test_data/transpose/transpose_fp32_input.bin"; | |||
| auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(input_path.c_str(), &input_size)); | |||
| if (input_data == nullptr) { | |||
| MS_LOG(ERROR) << "input_data load error."; | |||
| return; | |||
| } | |||
| std::vector<int> input_shape = {1, h, w, c}; | |||
| auto tensor_x_ptr = std::make_unique<lite::Tensor>(TypeId(kNumberTypeFloat32), input_shape, schema::Format_NHWC4); | |||
| auto tensor_x = tensor_x_ptr.get(); | |||
| if (tensor_x == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_x create error."; | |||
| return; | |||
| } | |||
| std::vector<int> out_shape = {1, c, h, w}; | |||
| auto tensor_out_ptr = std::make_unique<lite::Tensor>(TypeId(kNumberTypeFloat32), out_shape); | |||
| auto tensor_out = tensor_out_ptr.get(); | |||
| if (tensor_out == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_out create error."; | |||
| return; | |||
| } | |||
| std::vector<lite::Tensor *> inputs{tensor_x}; | |||
| std::vector<lite::Tensor *> outputs{tensor_out}; | |||
| auto arith_kernel_ptr = std::make_unique<kernel::ToFormatOpenCLKernel>(nullptr, inputs, outputs, nullptr); | |||
| auto arith_kernel = arith_kernel_ptr.get(); | |||
| if (arith_kernel == nullptr) { | |||
| MS_LOG(ERROR) << "arith_kernel create error."; | |||
| return; | |||
| } | |||
| arith_kernel->Init(); | |||
| inputs[0]->MallocData(allocator); | |||
| std::vector<kernel::LiteKernel *> kernels{arith_kernel}; | |||
| auto pGraph_ptr = std::make_unique<kernel::OpenCLSubGraph>(inputs, outputs, kernels, kernels, kernels); | |||
| auto pGraph = pGraph_ptr.get(); | |||
| if (pGraph == nullptr) { | |||
| MS_LOG(ERROR) << "pGraph create error."; | |||
| return; | |||
| } | |||
| pGraph->Init(); | |||
| memcpy(inputs[0]->data_c(), input_data, input_size); | |||
| pGraph->Run(); | |||
| size_t output_size; | |||
| std::string output_path = "./test_data/transpose/transpose_fp32_output.bin"; | |||
| auto correct_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(output_path.c_str(), &output_size)); | |||
| if (correct_data == nullptr) { | |||
| MS_LOG(ERROR) << "correct_data create error."; | |||
| return; | |||
| } | |||
| printf("==================output data=================\n"); | |||
| float *output_data = reinterpret_cast<float *>(tensor_out->data_c()); | |||
| std::cout << std::endl; | |||
| int size_n = h * w * c; | |||
| size_n = size_n > 100 ? 100 : size_n; | |||
| for (int i = 0; i < size_n; i++) { | |||
| std::cout << output_data[i] << " "; | |||
| if ((i + 1) % c == 0) { | |||
| std::cout << std::endl; | |||
| } | |||
| } | |||
| std::cout << std::endl; | |||
| // compare | |||
| ASSERT_EQ(0, CompareOutputData(output_data, correct_data, h * w * c, 0.00001)); | |||
| MS_LOG(INFO) << "Test TransposeFp32 passed"; | |||
| } | |||
| } // namespace mindspore::lite::opencl::test | |||
| @@ -46,7 +46,9 @@ TEST_F(TestOpenCL_Transpose, NHWC2NCHW) { | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(perm); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -62,7 +64,9 @@ TEST_F(TestOpenCL_Transpose, NCHW2NHWC) { | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(perm); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -78,7 +82,9 @@ TEST_F(TestOpenCL_Transpose, NHWC2NWHC) { | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(perm); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -94,7 +100,9 @@ TEST_F(TestOpenCL_Transpose, NWC2CWN) { | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(perm); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| @@ -112,7 +120,9 @@ TEST_F(TestOpenCL_Transpose, NWC2WNC) { | |||
| for (auto fp16_enable : {false, true}) { | |||
| auto *param = CreateParameter(perm); | |||
| TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable); | |||
| TestMain({{input_shape, input_data, VAR, kNumberTypeFloat32}, | |||
| {{static_cast<int>(perm.size())}, {perm.data()}, CONST_TENSOR, kNumberTypeInt32}}, | |||
| {output_shape, output_data}, param, fp16_enable); | |||
| } | |||
| } | |||
| } // namespace mindspore::lite::opencl::test | |||
| @@ -0,0 +1,50 @@ | |||
| TestOpenCL_Transpose.* | |||
| TestOpenCL_StridedSlice.1D | |||
| TestOpenCL_StridedSlice.2D | |||
| TestOpenCL_StridedSlice.3D | |||
| TestOpenCL_StridedSlice.4D | |||
| TestOpenCL_StridedSlice.4D_stride2 | |||
| TestOpenCL_StridedSlice.4D_to_3D | |||
| TestOpenCL_StridedSlice.test1 | |||
| TestOpenCL_Stack.* | |||
| TestOpenCL_Split.input3_axis0 | |||
| TestOpenCL_DepthToSpace.* | |||
| TestOpenCL_SpaceToDepth.* | |||
| TestOpenCL_SpaceToBatch.* | |||
| TestOpenCL_SoftMax.* | |||
| TestOpenCL_Slice.4D | |||
| TestOpenCL_Shape.* | |||
| TestOpenCL_Scale.* | |||
| TestOpenCL_Resize.* | |||
| TestOpenCL_Reshape.* | |||
| TestOpenCL_Reduce.* | |||
| TestOpenCL_Pooling.* | |||
| TestOpenCL_Pad.1D | |||
| TestOpenCL_Pad.2D | |||
| TestOpenCL_Pad.3D | |||
| TestOpenCL_Pad.4D | |||
| TestOpenCL_OneHot.* | |||
| TestOpenCL_MatMul.* | |||
| TestOpenCL_LayerNorm.* | |||
| TestOpenCL_Gather.* | |||
| TestOpenCL_FullConnection.* | |||
| TestOpenCL_Conv2D.test0 | |||
| TestOpenCL_Conv2D.test0_no_bias | |||
| TestOpenCL_Conv2D.test1 | |||
| TestOpenCL_Conv2D.test2 | |||
| TestOpenCL_Conv2D.test3 | |||
| TestOpenCL_Conv2D.test3_batch2 | |||
| TestOpenCL_Concat.* | |||
| TestOpenCL_BatchNorm.* | |||
| TestOpenCL_BatchToSpaceND.* | |||
| TestOpenCL_Arithmetic.ElementwiseAdd | |||
| TestOpenCL_Arithmetic.ScalarMul | |||
| TestOpenCL_Arithmetic.BroadcastSubReLU6 | |||
| TestOpenCL_Arithmetic.BroadcastSub2 | |||
| TestOpenCL_Arithmetic.BroadcastSub3 | |||
| TestOpenCL_Arithmetic.BroadcastFloorMod | |||
| TestOpenCL_Arithmetic.FloorMod | |||
| TestOpenCL_Arithmetic.ElementwiseDiv | |||
| TestOpenCL_ArithmeticSelf.* | |||
| TestOpenCL_ArgMinMax.* | |||
| TestOpenCL_Activation.* | |||