run model OpenCL + ExternCOprRunner, for example
graph: part_a(OpenCL) --> part_b(ExternCOprRunner) --> part_c(OpenCL)
GitOrigin-RevId: f754b559a2
tags/v1.11.0
| @@ -381,7 +381,6 @@ void ExternCOprRunner::check_param() { | |||||
| void ExternCOprRunner::scn_do_execute() { | void ExternCOprRunner::scn_do_execute() { | ||||
| SmallVector<MGBTensor> c_inp(input().size()), c_out(output().size()); | SmallVector<MGBTensor> c_inp(input().size()), c_out(output().size()); | ||||
| SmallVector<HostTensorND> cpu_inp, cpu_out; | |||||
| check_param(); | check_param(); | ||||
| bool need_copy = false; | bool need_copy = false; | ||||
| @@ -16,6 +16,9 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||||
| //! store dynamic store param | //! store dynamic store param | ||||
| std::shared_ptr<ExternCOprParam> m_param; | std::shared_ptr<ExternCOprParam> m_param; | ||||
| //! HostTensorND holder for scn_do_execute | |||||
| SmallVector<HostTensorND> cpu_inp, cpu_out; | |||||
| void get_output_var_shape( | void get_output_var_shape( | ||||
| const TensorShapeArray& inp_shape, | const TensorShapeArray& inp_shape, | ||||
| TensorShapeArray& out_shape) const override; | TensorShapeArray& out_shape) const override; | ||||
| @@ -445,6 +445,22 @@ TEST(TestExternCOpr, GPUCompute) { | |||||
| run_compute_test(CompNode::load("gpux"), MGB_DTYPE_FLOAT32); | run_compute_test(CompNode::load("gpux"), MGB_DTYPE_FLOAT32); | ||||
| } | } | ||||
| #if MGB_OPENCL | |||||
| #include "megcore_opencl.h" | |||||
| #define REQUIRE_OPENCL() \ | |||||
| do { \ | |||||
| if (!CompNode::get_device_count(CompNode::DeviceType::OPENCL)) { \ | |||||
| return; \ | |||||
| } \ | |||||
| } while (0) | |||||
| TEST(TestExternCOpr, OPENCLCompute) { | |||||
| REQUIRE_OPENCL(); | |||||
| run_compute_test(CompNode::load("openclx"), MGB_DTYPE_FLOAT32); | |||||
| } | |||||
| #endif | |||||
| TEST(TestExternCOpr, CPUComputeMultiDtype) { | TEST(TestExternCOpr, CPUComputeMultiDtype) { | ||||
| run_compute_test(CompNode::load("cpux"), MGB_DTYPE_INT32); | run_compute_test(CompNode::load("cpux"), MGB_DTYPE_INT32); | ||||
| #if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||