Merge pull request !4001 from chenzhongming/mastertags/v0.7.0-beta
| @@ -289,7 +289,11 @@ if (SUPPORT_GPU) | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/matmul_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/depthwise_conv2d_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc | ${TEST_DIR}/ut/src/runtime/kernel/opencl/concat_tests.cc | ||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_cl_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/softmax_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/arithmetic_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/max_pooling_tests.cc | |||||
| ${TEST_DIR}/ut/src/runtime/kernel/opencl/utils_tests.cc | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| @@ -0,0 +1,176 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/arithmetic.h" | |||||
| namespace mindspore { | |||||
// CPU reference for the scalar case: writes a[i] + b into c[i] for the first `size` elements.
// A non-positive `size` leaves `c` untouched.
void BoardcaseAdd(const float *a, const float b, float *c, const int size) {
  const float *src = a;
  float *dst = c;
  for (int remaining = size; remaining > 0; --remaining) {
    *dst = *src + b;
    ++src;
    ++dst;
  }
}
// CPU reference for the element-wise case: writes a[i] + b[i] into c[i] for the first `size` elements.
// A non-positive `size` leaves `c` untouched.
void ElementAdd(const float *a, const float *b, float *c, const int size) {
  const float *lhs = a;
  const float *rhs = b;
  float *dst = c;
  for (int remaining = size; remaining > 0; --remaining) {
    *dst = *lhs + *rhs;
    ++lhs;
    ++rhs;
    ++dst;
  }
}
| bool DataCompare(const float *a, const float *b, const int size, const float accuracy = 1e-4) { | |||||
| for (int i = 0; i < size; i++) { | |||||
| auto diff = fabs(a[i] - b[i]); | |||||
| if (diff > accuracy) { | |||||
| MS_LOG(ERROR) << "compare failed at " << i << " exp " << a[i] << " bug got " << b[i]; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
// Fills `size` floats at `data` with pseudo-random whole numbers taken from rand_r() % 100.
// The seed is function-static (initially 123), so the sequence continues across calls.
void InitData(void *data, const int size) {
  static unsigned int seed = 123;
  float *cursor = reinterpret_cast<float *>(data);
  for (int remaining = size; remaining > 0; --remaining, ++cursor) {
    const int sample = static_cast<int>(rand_r(&seed));
    *cursor = sample % 100;
  }
}
// Prints `prefix` followed by the first `size` floats at `data`, each followed by a comma,
// then ends the line (and flushes) on std::cout.
void LogData(void *data, const int size, const std::string prefix) {
  const float *values = reinterpret_cast<float *>(data);
  std::cout << prefix;
  int idx = 0;
  while (idx < size) {
    std::cout << values[idx] << ",";
    ++idx;
  }
  std::cout << std::endl;
}
| void TestCase(const std::vector<int> &shape_a, const std::vector<int> &shape_b) { | |||||
| std::cout << "TestCase" << std::endl; | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| bool is_bias_add = shape_b.empty(); | |||||
| auto tensorType = schema::NodeType_ValueNode; | |||||
| std::cout << "TestCase tensor" << std::endl; | |||||
| lite::tensor::Tensor *tensor_a = | |||||
| new lite::tensor::Tensor(kNumberTypeFloat32, shape_a, schema::Format_NHWC4, tensorType); | |||||
| lite::tensor::Tensor *tensor_b = | |||||
| new lite::tensor::Tensor(kNumberTypeFloat32, shape_b, schema::Format_NHWC4, tensorType); | |||||
| lite::tensor::Tensor *tensor_c = | |||||
| new lite::tensor::Tensor(kNumberTypeFloat32, shape_a, schema::Format_NHWC4, tensorType); | |||||
| int64_t element_num = tensor_a->ElementsC4Num(); | |||||
| int64_t element_num_b = is_bias_add ? 1 : tensor_b->ElementsC4Num(); | |||||
| std::cout << "TestCase new data" << std::endl; | |||||
| float *data_a = new float[element_num]; | |||||
| float *data_b = new float[element_num_b]; | |||||
| float *data_c_cpu = new float[element_num]; | |||||
| float *data_c_ocl = new float[element_num]; | |||||
| InitData(data_a, element_num); | |||||
| InitData(data_b, element_num_b); | |||||
| memset(data_c_ocl, 0, sizeof(float) * element_num); | |||||
| std::cout << "TestCase run cpu" << std::endl; | |||||
| if (is_bias_add) { | |||||
| BoardcaseAdd(data_a, static_cast<float *>(data_b)[0], data_c_cpu, element_num); | |||||
| } else { | |||||
| ElementAdd(data_a, data_b, data_c_cpu, element_num); | |||||
| } | |||||
| std::cout << "TestCase set data" << std::endl; | |||||
| std::vector<lite::tensor::Tensor *> inputs = {tensor_a}; | |||||
| if (!is_bias_add) { | |||||
| inputs.push_back(tensor_b); | |||||
| } else { | |||||
| tensor_b->MallocData(); | |||||
| memcpy(tensor_b->Data(), data_b, sizeof(float)); | |||||
| } | |||||
| std::vector<lite::tensor::Tensor *> outputs = {tensor_c}; | |||||
| ArithmeticParameter *param = new ArithmeticParameter(); | |||||
| param->ndim_ = 4; | |||||
| param->op_parameter_.type_ = PrimitiveType_Add; | |||||
| std::vector<lite::tensor::Tensor *> arithmetic_inputs = {tensor_a, tensor_b}; | |||||
| lite::Context ctx; | |||||
| auto *arith_kernel = | |||||
| new kernel::ArithmeticOpenCLKernel(reinterpret_cast<OpParameter *>(param), arithmetic_inputs, outputs, &ctx); | |||||
| arith_kernel->Init(); | |||||
| std::vector<kernel::LiteKernel *> kernels{arith_kernel}; | |||||
| auto *kernel = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| std::cout << "TestCase Init" << std::endl; | |||||
| kernel->Init(); | |||||
| memcpy(inputs[0]->Data(), data_a, sizeof(float) * element_num); | |||||
| if (!is_bias_add) { | |||||
| memcpy(inputs[1]->Data(), data_b, sizeof(float) * element_num_b); | |||||
| } | |||||
| std::cout << "TestCase Run" << std::endl; | |||||
| kernel->Run(); | |||||
| memcpy(data_c_ocl, outputs[0]->Data(), sizeof(float) * element_num); | |||||
| // ocl_runtime->SyncCommandQueue(); | |||||
| LogData(data_a, 10, "Data A : "); | |||||
| LogData(data_b, tensor_b->shape().empty() ? 1 : 10, "Data B : "); | |||||
| LogData(data_c_cpu, 10, "Expect compute : "); | |||||
| LogData(outputs[0]->Data(), 10, "OpenCL compute : "); | |||||
| bool cmp = DataCompare(data_c_cpu, data_c_ocl, element_num); | |||||
| MS_LOG(INFO) << "Compare " << (cmp ? "success!" : "failed!"); | |||||
| std::cout << "TestCase End" << std::endl; | |||||
| // free | |||||
| delete[] data_a; | |||||
| delete[] data_b; | |||||
| delete[] data_c_cpu; | |||||
| delete[] data_c_ocl; | |||||
| delete kernel; | |||||
| delete arith_kernel; | |||||
| for (auto tensor : inputs) { | |||||
| delete tensor; | |||||
| } | |||||
| for (auto tensor : outputs) { | |||||
| delete tensor; | |||||
| } | |||||
| lite::opencl::OpenCLRuntime::DeleteInstance(); | |||||
| } | |||||
| class TestArithmeticOpenCL : public mindspore::Common { | |||||
| public: | |||||
| TestArithmeticOpenCL() {} | |||||
| }; | |||||
| TEST_F(TestArithmeticOpenCL, AddElementwiseTest) { | |||||
| const std::vector<int> &shape_a = {1, 32, 32, 4}; | |||||
| const std::vector<int> &shape_b = {1, 32, 32, 4}; | |||||
| TestCase(shape_a, shape_b); | |||||
| } | |||||
| // TEST_F(TestOpenCLKernel, AddBoardcaseTest) { | |||||
| // const std::vector<int> &shape_a = {1, 4, 128, 128}; | |||||
| // const std::vector<int> &shape_b = {}; | |||||
| // TestCase(shape_a, shape_b); | |||||
| //} | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,124 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h" | |||||
| namespace mindspore { | |||||
| class TestAvgPoolingOpenCL : public mindspore::Common {}; | |||||
| void InitAvgPoolingParam(PoolingParameter *param) { | |||||
| param->input_batch_ = 1; | |||||
| param->input_h_ = 2; | |||||
| param->input_w_ = 2; | |||||
| param->input_channel_ = 4; | |||||
| param->output_batch_ = 1; | |||||
| param->output_h_ = 1; | |||||
| param->output_w_ = 1; | |||||
| param->output_channel_ = 4; | |||||
| param->window_h_ = 2; | |||||
| param->window_w_ = 2; | |||||
| param->stride_h_ = 2; | |||||
| param->stride_w_ = 2; | |||||
| param->pad_u_ = 0; | |||||
| param->pad_d_ = 0; | |||||
| param->pad_l_ = 0; | |||||
| param->pad_r_ = 0; | |||||
| param->max_pooling_ = false; | |||||
| param->avg_pooling_ = true; | |||||
| } | |||||
| TEST_F(TestAvgPoolingOpenCL, AvgPoolFp32) { | |||||
| MS_LOG(INFO) << "start TEST_F TestPoolingOpenCL"; | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| ocl_runtime->Init(); | |||||
| MS_LOG(INFO) << "create PoolingParameter"; | |||||
| auto param = new PoolingParameter(); | |||||
| InitAvgPoolingParam(param); | |||||
| MS_LOG(INFO) << "create Tensors"; | |||||
| std::vector<int> shape_in = { | |||||
| param->input_batch_, | |||||
| param->input_h_, | |||||
| param->input_w_, | |||||
| param->input_channel_, | |||||
| }; | |||||
| std::vector<int> shape_out = { | |||||
| param->output_batch_, | |||||
| param->output_h_, | |||||
| param->output_w_, | |||||
| param->output_channel_, | |||||
| }; | |||||
| auto data_type = kNumberTypeFloat32; | |||||
| auto tensorType = schema::NodeType_ValueNode; | |||||
| lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NHWC, tensorType); | |||||
| lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NHWC, tensorType); | |||||
| std::vector<lite::tensor::Tensor *> inputs{tensor_in}; | |||||
| std::vector<lite::tensor::Tensor *> outputs{tensor_out}; | |||||
| MS_LOG(INFO) << "create OpenCL Kernel"; | |||||
| auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| pooling_kernel->Init(); | |||||
| std::vector<kernel::LiteKernel *> kernels{pooling_kernel}; | |||||
| MS_LOG(INFO) << "create SubGraphOpenCLKernel"; | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | |||||
| MS_LOG(INFO) << "initialize data"; | |||||
| std::vector<lite::tensor::Tensor *> tensor_map = {tensor_in}; | |||||
| for (auto &tensor_file : tensor_map) { | |||||
| auto tensor = tensor_file; | |||||
| size_t size = tensor->Size(); | |||||
| const float data[16] = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, | |||||
| 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; | |||||
| memcpy(tensor->Data(), data, size); | |||||
| } | |||||
| MS_LOG(INFO) << "pGraph->Run()"; | |||||
| pGraph->Run(); | |||||
| MS_LOG(INFO) << "==================output data================="; | |||||
| float *output_data = reinterpret_cast<float *>(tensor_out->Data()); | |||||
| printf("output:"); | |||||
| for (int i = 0; i < 4; i++) { | |||||
| printf("%.3f ", output_data[i]); | |||||
| } | |||||
| printf("\n"); | |||||
| size_t output_size = tensor_out->Size(); | |||||
| float expect[4] = {2.0f, 3.0f, 4.0f, 5.0f}; | |||||
| for (int i = 0; i < tensor_out->ElementsNum(); ++i) | |||||
| if (std::fabs(output_data[i] - expect[i]) > 1e-5) { | |||||
| printf("idx[%d] except=%.3f output=%.3f, ", i, expect[i], output_data[i]); | |||||
| } | |||||
| printf("test all close OK!\n"); | |||||
| lite::CompareOutputData(output_data, expect, 4); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <memory> | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/common_test.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.h" | |||||
| #include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h" | |||||
| namespace mindspore { | |||||
| class TestMaxPoolingOpenCL : public mindspore::Common {}; | |||||
| void InitParameter(PoolingParameter *param) { | |||||
| param->window_h_ = 2; | |||||
| param->window_w_ = 2; | |||||
| param->stride_h_ = 2; | |||||
| param->stride_w_ = 2; | |||||
| param->pad_u_ = 0; | |||||
| param->pad_d_ = 0; | |||||
| param->pad_l_ = 0; | |||||
| param->pad_r_ = 0; | |||||
| param->avg_pooling_ = false; | |||||
| param->max_pooling_ = true; | |||||
| } | |||||
| TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) { | |||||
| MS_LOG(INFO) << "ocl runtime"; | |||||
| auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); | |||||
| ocl_runtime->Init(); | |||||
| MS_LOG(INFO) << "PoolingParameter"; | |||||
| auto param = new PoolingParameter; | |||||
| InitParameter(param); | |||||
| // define tensor | |||||
| MS_LOG(INFO) << "define tensor"; | |||||
| std::vector<int> input_shape = {1, 16, 256, 192}; | |||||
| std::vector<int> output_shape = {1, 8, 128, 192}; | |||||
| auto data_type = kNumberTypeFloat32; | |||||
| auto tensorType = schema::NodeType_ValueNode; | |||||
| auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensorType); | |||||
| auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensorType); | |||||
| std::vector<lite::tensor::Tensor *> inputs{input_tensor}; | |||||
| std::vector<lite::tensor::Tensor *> outputs{output_tensor}; | |||||
| // run | |||||
| auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); | |||||
| pooling_kernel->Init(); | |||||
| std::vector<kernel::LiteKernel *> kernels{pooling_kernel}; | |||||
| auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); | |||||
| pGraph->Init(); | |||||
| // load data | |||||
| MS_LOG(INFO) << "load data"; | |||||
| std::string input_file = "maxpool_in.bin"; | |||||
| std::string expect_file = "maxpool_out.bin"; | |||||
| LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); | |||||
| auto *input_data = reinterpret_cast<float *>(input_tensor->Data()); | |||||
| printf("input[0:10]:"); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| printf("[%d]:%.3f ", i, input_data[i]); | |||||
| } | |||||
| printf("\n"); | |||||
| pGraph->Run(); | |||||
| MS_LOG(INFO) << "compare result"; | |||||
| CompareOutput(output_tensor, expect_file); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,63 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
#include <cstdio>
#include <cstring>
#include <fstream>
#include <string>
#include "utils/log_adapter.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"
| namespace mindspore { | |||||
// Fills the `dst_size`-byte buffer at `dst` with test input.
//
// dst:       destination buffer; nothing is done when it is null.
// dst_size:  number of bytes to fill; nothing is done when it is 0.
// file_path: binary file to read from. When empty, the buffer is zero-filled.
//
// The buffer is zeroed first, so a missing or short file leaves the unread tail as zeros
// rather than uninitialised memory; both conditions are reported on stderr.
//
// The previous implementation used `dst_size` as the memset fill value and, for a non-empty
// path, copied from the integer `dst_size` reinterpreted as an address without ever opening
// the file.
void LoadTestData(void *dst, size_t dst_size, const std::string &file_path) {
  if (dst == nullptr || dst_size == 0) {
    return;
  }
  std::memset(dst, 0, dst_size);
  if (file_path.empty()) {
    return;
  }

  std::ifstream ifs(file_path, std::ios::in | std::ios::binary);
  if (!ifs.is_open()) {
    fprintf(stderr, "LoadTestData: open file %s failed\n", file_path.c_str());
    return;
  }
  ifs.read(static_cast<char *>(dst), static_cast<std::streamsize>(dst_size));
  const size_t read_size = static_cast<size_t>(ifs.gcount());
  if (read_size != dst_size) {
    fprintf(stderr, "LoadTestData: file %s provided %zu bytes, expected %zu\n", file_path.c_str(), read_size,
            dst_size);
  }
}
| void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path) { | |||||
| float *output_data = reinterpret_cast<float *>(output_tensor->Data()); | |||||
| size_t output_size = output_tensor->Size(); | |||||
| float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | |||||
| printf("output[0:10]:"); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| printf("[%d]:%.3f ", i, output_data[i]); | |||||
| } | |||||
| printf("\n"); | |||||
| printf("expect[0:10]:"); | |||||
| for (int i = 0; i < 10; i++) { | |||||
| printf("[%d]:%.3f ", i, expect_data[i]); | |||||
| } | |||||
| printf("\n"); | |||||
| constexpr float atol = 1e-5; | |||||
| for (int i = 0; i < output_tensor->ElementsNum(); ++i) { | |||||
| if (std::fabs(output_data[i] - expect_data[i]) > atol) { | |||||
| printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]); | |||||
| printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]); | |||||
| printf("error at idx[%d] expect=%.3f output=%.3f \n", i, expect_data[i], output_data[i]); | |||||
| return; | |||||
| } | |||||
| } | |||||
| printf("compare success!\n"); | |||||
| printf("compare success!\n"); | |||||
| printf("compare success!\n\n\n"); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include <iostream> | |||||
| #include "tests/ut/cpp/common/common_test.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "mindspore/lite/src/common/file_utils.h" | |||||
| #include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" | |||||
| #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" | |||||
| #ifndef TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_ | |||||
| #define TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_ | |||||
// Helpers shared by the OpenCL kernel unit tests; both are defined in utils_tests.cc.
| namespace mindspore { | |||||
// Prepares `dst_size` bytes at `dst` as test input. `file_path` names the data file to use;
// an empty path selects the implementation's no-file branch.
| void LoadTestData(void *dst, size_t dst_size, const std::string &file_path); | |||||
// Compares the float contents of `output_tensor` with expected values read from `file_path`
// (absolute tolerance 1e-5). Prints a preview of both buffers and either the first mismatch
// or a success message; the outcome is printed, not returned.
| void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_path); | |||||
| } // namespace mindspore | |||||
| #endif // TESTS_UT_OPENCL_KERNEL_TESTS_UTILS_H_ | |||||
| @@ -446,7 +446,7 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' | |||||
| Ascend model. | Ascend model. | ||||
| - BINARY: Binary format for model. An intermidiate representation format for models. | - BINARY: Binary format for model. An intermidiate representation format for models. | ||||
| """ | """ | ||||
| supported_device = ["Ascend"] | |||||
| supported_device = ["Ascend", "GPU"] | |||||
| supported_formats = ['GEIR', 'BINARY'] | supported_formats = ['GEIR', 'BINARY'] | ||||
| mean = validator.check_type("mean", mean, (int, float)) | mean = validator.check_type("mean", mean, (int, float)) | ||||