GitOrigin-RevId: a54237488f
tags/v1.7.0
| @@ -67,7 +67,8 @@ bool config_user_allocator(const Args& args); | |||||
| bool register_cryption_method(const Args& args); | bool register_cryption_method(const Args& args); | ||||
| bool update_cryption_key(const Args& args); | bool update_cryption_key(const Args& args); | ||||
| bool async_forward(const Args& args); | bool async_forward(const Args& args); | ||||
| bool set_input_callback(const Args& arg); | |||||
| bool set_output_callback(const Args& arg); | |||||
| #if LITE_WITH_CUDA | #if LITE_WITH_CUDA | ||||
| bool device_input(const Args& args); | bool device_input(const Args& args); | ||||
| bool device_input_output(const Args& args); | bool device_input_output(const Args& args); | ||||
| @@ -160,6 +160,8 @@ REGIST_EXAMPLE("reset_input", reset_input); | |||||
| REGIST_EXAMPLE("reset_input_output", reset_input_output); | REGIST_EXAMPLE("reset_input_output", reset_input_output); | ||||
| REGIST_EXAMPLE("config_user_allocator", config_user_allocator); | REGIST_EXAMPLE("config_user_allocator", config_user_allocator); | ||||
| REGIST_EXAMPLE("async_forward", async_forward); | REGIST_EXAMPLE("async_forward", async_forward); | ||||
| REGIST_EXAMPLE("set_input_callback", set_input_callback); | |||||
| REGIST_EXAMPLE("set_output_callback", set_output_callback); | |||||
| REGIST_EXAMPLE("basic_c_interface", basic_c_interface); | REGIST_EXAMPLE("basic_c_interface", basic_c_interface); | ||||
| REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface); | REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface); | ||||
| @@ -365,6 +365,142 @@ bool lite::example::async_forward(const Args& args) { | |||||
| printf("max=%e, sum=%e\n", max, sum); | printf("max=%e, sum=%e\n", max, sum); | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool lite::example::set_input_callback(const Args& args) { | |||||
| std::string network_path = args.model_path; | |||||
| std::string input_path = args.input_path; | |||||
| Config config; | |||||
| config.options.var_sanity_check_first_run = false; | |||||
| //! create and load the network | |||||
| std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
| network->load_model(network_path); | |||||
| //! set input data to input tensor | |||||
| std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0); | |||||
| //! copy or forward data to network | |||||
| size_t length = input_tensor->get_tensor_total_size_in_byte(); | |||||
| void* dst_ptr = input_tensor->get_memory_ptr(); | |||||
| auto src_tensor = parse_npy(input_path); | |||||
| void* src = src_tensor->get_memory_ptr(); | |||||
| memcpy(dst_ptr, src, length); | |||||
| //! set input callback | |||||
| volatile bool finished = false; | |||||
| network->set_start_callback( | |||||
| [&finished](const std::unordered_map< | |||||
| std::string, std::pair<IO, std::shared_ptr<Tensor>>>& inputs) { | |||||
| #if !__DEPLOY_ON_XP_SP2__ | |||||
| std::cout << "worker thread_id:" << std::this_thread::get_id() | |||||
| << std::endl; | |||||
| #endif | |||||
| for (auto&& item : inputs) { | |||||
| std::cout << "input name: " << item.first | |||||
| << "input dim: " << item.second.second->get_layout().ndim | |||||
| << std::endl; | |||||
| } | |||||
| finished = true; | |||||
| }); | |||||
| #if !__DEPLOY_ON_XP_SP2__ | |||||
| std::cout << "out thread_id:" << std::this_thread::get_id() << std::endl; | |||||
| #endif | |||||
| //! forward | |||||
| network->forward(); | |||||
| size_t count = 0; | |||||
| while (finished == false) { | |||||
| count++; | |||||
| } | |||||
| printf("Forward finish, count is %zu\n", count); | |||||
| //! get the output data or read tensor set in network_in | |||||
| std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0); | |||||
| void* out_data = output_tensor->get_memory_ptr(); | |||||
| size_t out_length = output_tensor->get_tensor_total_size_in_byte() / | |||||
| output_tensor->get_layout().get_elem_size(); | |||||
| printf("length=%zu\n", length); | |||||
| float max = -1.0f; | |||||
| float sum = 0.0f; | |||||
| for (size_t i = 0; i < out_length; i++) { | |||||
| float data = static_cast<float*>(out_data)[i]; | |||||
| sum += data; | |||||
| if (max < data) | |||||
| max = data; | |||||
| } | |||||
| printf("max=%e, sum=%e\n", max, sum); | |||||
| return true; | |||||
| } | |||||
| bool lite::example::set_output_callback(const Args& args) { | |||||
| std::string network_path = args.model_path; | |||||
| std::string input_path = args.input_path; | |||||
| Config config; | |||||
| config.options.var_sanity_check_first_run = false; | |||||
| //! create and load the network | |||||
| std::shared_ptr<Network> network = std::make_shared<Network>(config); | |||||
| network->load_model(network_path); | |||||
| //! set input data to input tensor | |||||
| std::shared_ptr<Tensor> input_tensor = network->get_output_tensor(0); | |||||
| //! copy or forward data to network | |||||
| size_t length = input_tensor->get_tensor_total_size_in_byte(); | |||||
| void* dst_ptr = input_tensor->get_memory_ptr(); | |||||
| auto src_tensor = parse_npy(input_path); | |||||
| void* src = src_tensor->get_memory_ptr(); | |||||
| memcpy(dst_ptr, src, length); | |||||
| //! set output callback | |||||
| volatile bool finished = false; | |||||
| network->set_finish_callback( | |||||
| [&finished](const std::unordered_map< | |||||
| std::string, std::pair<IO, std::shared_ptr<Tensor>>>& outputs) { | |||||
| #if !__DEPLOY_ON_XP_SP2__ | |||||
| std::cout << "worker thread_id:" << std::this_thread::get_id() | |||||
| << std::endl; | |||||
| #endif | |||||
| for (auto&& item : outputs) { | |||||
| std::cout << "output name: " << item.first | |||||
| << "output dim: " << item.second.second->get_layout().ndim | |||||
| << std::endl; | |||||
| } | |||||
| finished = true; | |||||
| }); | |||||
| #if !__DEPLOY_ON_XP_SP2__ | |||||
| std::cout << "out thread_id:" << std::this_thread::get_id() << std::endl; | |||||
| #endif | |||||
| //! forward | |||||
| network->forward(); | |||||
| network->wait(); | |||||
| size_t count = 0; | |||||
| while (finished == false) { | |||||
| count++; | |||||
| } | |||||
| printf("Forward finish, count is %zu\n", count); | |||||
| //! get the output data or read tensor set in network_in | |||||
| std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0); | |||||
| void* out_data = output_tensor->get_memory_ptr(); | |||||
| size_t out_length = output_tensor->get_tensor_total_size_in_byte() / | |||||
| output_tensor->get_layout().get_elem_size(); | |||||
| printf("length=%zu\n", length); | |||||
| float max = -1.0f; | |||||
| float sum = 0.0f; | |||||
| for (size_t i = 0; i < out_length; i++) { | |||||
| float data = static_cast<float*>(out_data)[i]; | |||||
| sum += data; | |||||
| if (max < data) | |||||
| max = data; | |||||
| } | |||||
| printf("max=%e, sum=%e\n", max, sum); | |||||
| return true; | |||||
| } | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -184,6 +184,8 @@ typedef int (*LiteThreadAffinityCallback)(int thread_id); | |||||
| typedef int (*LiteAsyncCallback)(); | typedef int (*LiteAsyncCallback)(); | ||||
| typedef int (*LiteAsyncCallbackWithData)(void* user_data); | |||||
| /*! | /*! | ||||
| * \brief the start/finish callback function | * \brief the start/finish callback function | ||||
| * \param unordered_map map from the io tensor name to the pair of which is the | * \param unordered_map map from the io tensor name to the pair of which is the | ||||
| @@ -193,9 +195,17 @@ typedef int (*LiteAsyncCallback)(); | |||||
| typedef int (*LiteStartCallback)( | typedef int (*LiteStartCallback)( | ||||
| const LiteIO* inputs, const LiteTensor* input_tensors, size_t size); | const LiteIO* inputs, const LiteTensor* input_tensors, size_t size); | ||||
| typedef int (*LiteStartCallbackWithData)( | |||||
| const LiteIO* inputs, const LiteTensor* input_tensors, size_t size, | |||||
| void* user_data); | |||||
| typedef int (*LiteFinishCallback)( | typedef int (*LiteFinishCallback)( | ||||
| const LiteIO* outputs, const LiteTensor* output_tensors, size_t size); | const LiteIO* outputs, const LiteTensor* output_tensors, size_t size); | ||||
| typedef int (*LiteFinishCallbackWithData)( | |||||
| const LiteIO* outputs, const LiteTensor* output_tensors, size_t size, | |||||
| void* user_data); | |||||
| /*! | /*! | ||||
| * \brief The network is construct form a model, implement model load, init, | * \brief The network is construct form a model, implement model load, init, | ||||
| * forward, and display some model information | * forward, and display some model information | ||||
| @@ -442,6 +452,19 @@ LITE_API int LITE_set_network_algo_workspace_limit( | |||||
| LITE_API int LITE_set_async_callback( | LITE_API int LITE_set_async_callback( | ||||
| LiteNetwork network, const LiteAsyncCallback async_callback); | LiteNetwork network, const LiteAsyncCallback async_callback); | ||||
| /** | |||||
| * \brief set the network forward in async mode and set the async callback | |||||
| * function | |||||
| * \param[in] network The loaded model | |||||
| * \param[in] async_callback when network finish forwarding, the callback | |||||
| * will be called | |||||
| * \param[in] user_data user defined data for something user want to deploy | |||||
| * at forward finish stage | |||||
| */ | |||||
| LITE_API int LITE_set_async_callback_with_userdata( | |||||
| LiteNetwork network, const LiteAsyncCallbackWithData async_callback, | |||||
| void* user_data); | |||||
| /** | /** | ||||
| * \brief set the start forward callback function, which will be execute beform | * \brief set the start forward callback function, which will be execute beform | ||||
| * forward, this can be used to check network input or dump model inputs | * forward, this can be used to check network input or dump model inputs | ||||
| @@ -453,6 +476,20 @@ LITE_API int LITE_set_async_callback( | |||||
| LITE_API int LITE_set_start_callback( | LITE_API int LITE_set_start_callback( | ||||
| LiteNetwork network, const LiteStartCallback start_callback); | LiteNetwork network, const LiteStartCallback start_callback); | ||||
| /** | |||||
| * \brief set the start forward callback function, which will be execute beform | |||||
| * forward, this can be used to check network input or dump model inputs | |||||
| * for debug | |||||
| * \param[in] network The loaded model | |||||
| * \param[in] start_callback when network start forwarding, the callbak | |||||
| * will be called | |||||
| * \param[in] user_data user defined data for something user want to deploy | |||||
| * at forward start stage | |||||
| */ | |||||
| LITE_API int LITE_set_start_callback_with_userdata( | |||||
| LiteNetwork network, const LiteStartCallbackWithData start_callback, | |||||
| void* user_data); | |||||
| /** | /** | ||||
| * \brief set the finish forward callback function, which will be execute after | * \brief set the finish forward callback function, which will be execute after | ||||
| * forward, this can be used to dump model outputs for debug | * forward, this can be used to dump model outputs for debug | ||||
| @@ -463,6 +500,19 @@ LITE_API int LITE_set_start_callback( | |||||
| LITE_API int LITE_set_finish_callback( | LITE_API int LITE_set_finish_callback( | ||||
| LiteNetwork network, const LiteFinishCallback finish_callback); | LiteNetwork network, const LiteFinishCallback finish_callback); | ||||
| /** | |||||
| * \brief set the finish forward callback function, which will be execute after | |||||
| * forward, this can be used to dump model outputs for debug | |||||
| * \param[in] network The loaded model | |||||
| * \param[in] finish_callback when network finish forwarding, the callbak | |||||
| * will be called | |||||
| * \param[in] user_data user defined data for something user want to deploy | |||||
| * at finish stage | |||||
| */ | |||||
| LITE_API int LITE_set_finish_callback_with_userdata( | |||||
| LiteNetwork network, const LiteFinishCallbackWithData finish_callback, | |||||
| void* user_data); | |||||
| /** | /** | ||||
| * \brief set threads affinity callback | * \brief set threads affinity callback | ||||
| * \param[in] network The loaded model | * \param[in] network The loaded model | ||||
| @@ -355,6 +355,22 @@ int LITE_set_async_callback( | |||||
| LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
| } | } | ||||
| int LITE_set_async_callback_with_userdata( | |||||
| LiteNetwork network, LiteAsyncCallbackWithData async_callback, | |||||
| void* user_data) { | |||||
| LITE_CAPI_BEGIN(); | |||||
| LITE_ASSERT(network, "The network pass to LITE api is null"); | |||||
| LITE_ASSERT(async_callback, "The ptr pass to LITE api is null"); | |||||
| auto lite_async_callback = [async_callback, user_data]() -> void { | |||||
| async_callback(user_data); | |||||
| }; | |||||
| static_cast<lite::Network*>(network)->set_async_callback( | |||||
| std::move(lite_async_callback)); | |||||
| LITE_CAPI_END(); | |||||
| } | |||||
| int LITE_set_start_callback( | int LITE_set_start_callback( | ||||
| LiteNetwork network, const LiteStartCallback start_callback) { | LiteNetwork network, const LiteStartCallback start_callback) { | ||||
| LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
| @@ -381,6 +397,34 @@ int LITE_set_start_callback( | |||||
| LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
| } | } | ||||
| int LITE_set_start_callback_with_userdata( | |||||
| LiteNetwork network, const LiteStartCallbackWithData start_callback, | |||||
| void* user_data) { | |||||
| LITE_CAPI_BEGIN(); | |||||
| LITE_ASSERT(network, "The network pass to LITE api is null"); | |||||
| auto lite_start_callback = | |||||
| [start_callback, | |||||
| user_data](const std::unordered_map< | |||||
| std::string, | |||||
| std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>& inputs_map) | |||||
| -> void { | |||||
| std::vector<LiteIO> ios; | |||||
| std::vector<LiteTensor> io_tensors; | |||||
| size_t nr_io = 0; | |||||
| for (const auto& io : inputs_map) { | |||||
| nr_io++; | |||||
| auto&& lite_io = io.second.first; | |||||
| ios.push_back( | |||||
| {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type, | |||||
| convert_to_clayout(lite_io.config_layout)}); | |||||
| io_tensors.push_back(io.second.second.get()); | |||||
| } | |||||
| start_callback(ios.data(), io_tensors.data(), nr_io, user_data); | |||||
| }; | |||||
| static_cast<lite::Network*>(network)->set_start_callback(lite_start_callback); | |||||
| LITE_CAPI_END(); | |||||
| } | |||||
| int LITE_set_finish_callback( | int LITE_set_finish_callback( | ||||
| LiteNetwork network, const LiteFinishCallback finish_callback) { | LiteNetwork network, const LiteFinishCallback finish_callback) { | ||||
| LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
| @@ -407,6 +451,34 @@ int LITE_set_finish_callback( | |||||
| LITE_CAPI_END(); | LITE_CAPI_END(); | ||||
| } | } | ||||
| int LITE_set_finish_callback_with_userdata( | |||||
| LiteNetwork network, const LiteFinishCallbackWithData finish_callback, | |||||
| void* user_data) { | |||||
| LITE_CAPI_BEGIN(); | |||||
| LITE_ASSERT(network, "The network pass to LITE api is null"); | |||||
| auto lite_finish_callback = | |||||
| [finish_callback, | |||||
| user_data](const std::unordered_map< | |||||
| std::string, | |||||
| std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>& | |||||
| outputs_map) -> void { | |||||
| std::vector<LiteIO> ios; | |||||
| std::vector<LiteTensor> io_tensors; | |||||
| size_t nr_io = 0; | |||||
| for (const auto& io : outputs_map) { | |||||
| nr_io++; | |||||
| auto&& lite_io = io.second.first; | |||||
| ios.push_back( | |||||
| {lite_io.name.c_str(), lite_io.is_host, lite_io.io_type, | |||||
| convert_to_clayout(lite_io.config_layout)}); | |||||
| io_tensors.push_back(io.second.second.get()); | |||||
| } | |||||
| finish_callback(ios.data(), io_tensors.data(), nr_io, user_data); | |||||
| }; | |||||
| static_cast<lite::Network*>(network)->set_finish_callback(lite_finish_callback); | |||||
| LITE_CAPI_END(); | |||||
| } | |||||
| int LITE_enable_profile_performance( | int LITE_enable_profile_performance( | ||||
| LiteNetwork network, const char* profile_json_file_path) { | LiteNetwork network, const char* profile_json_file_path) { | ||||
| LITE_CAPI_BEGIN(); | LITE_CAPI_BEGIN(); | ||||
| @@ -74,11 +74,21 @@ int multi_thread_affinity(int id) { | |||||
| }; | }; | ||||
| volatile bool finished = false; | volatile bool finished = false; | ||||
| int finish_callback() { | |||||
| int async_callback() { | |||||
| finished = true; | finished = true; | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| volatile bool finished_with_data = false; | |||||
| int async_callback_with_data(void* user_data) { | |||||
| if (user_data != NULL) { | |||||
| std::cout << "async_callback user_data addr=" << std::hex << user_data | |||||
| << std::endl; | |||||
| } | |||||
| finished_with_data = true; | |||||
| return 0; | |||||
| } | |||||
| volatile bool start_checked = false; | volatile bool start_checked = false; | ||||
| int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t size) { | int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t size) { | ||||
| start_checked = true; | start_checked = true; | ||||
| @@ -96,6 +106,29 @@ int start_callback(const LiteIO* inputs, const LiteTensor* input_tensors, size_t | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| volatile bool start_checked_with_data = false; | |||||
| int start_callback_with_data( | |||||
| const LiteIO* inputs, const LiteTensor* input_tensors, size_t size, | |||||
| void* user_data) { | |||||
| start_checked_with_data = true; | |||||
| auto check_func = [&]() { | |||||
| if (user_data != NULL) { | |||||
| std::cout << "start_callback user_data addr=" << std::hex << user_data | |||||
| << std::endl; | |||||
| } | |||||
| ASSERT_EQ(size, 1); | |||||
| ASSERT_EQ(std::string(inputs->name), "data"); | |||||
| LiteLayout layout; | |||||
| LITE_get_tensor_layout(*input_tensors, &layout); | |||||
| ASSERT_EQ(layout.ndim, 4); | |||||
| ASSERT_EQ(layout.shapes[1], 3); | |||||
| ASSERT_EQ(layout.shapes[2], 224); | |||||
| ASSERT_EQ(layout.shapes[3], 224); | |||||
| }; | |||||
| check_func(); | |||||
| return 0; | |||||
| } | |||||
| volatile bool finish_checked = false; | volatile bool finish_checked = false; | ||||
| int finish_callback( | int finish_callback( | ||||
| const LiteIO* outputs, const LiteTensor* output_tensors, size_t size) { | const LiteIO* outputs, const LiteTensor* output_tensors, size_t size) { | ||||
| @@ -113,6 +146,28 @@ int finish_callback( | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| volatile bool finish_checked_with_data = false; | |||||
| int finish_callback_with_data( | |||||
| const LiteIO* outputs, const LiteTensor* output_tensors, size_t size, | |||||
| void* user_data) { | |||||
| finish_checked_with_data = true; | |||||
| auto check_func = [&]() { | |||||
| if (user_data != NULL) { | |||||
| std::cout << "finish_callback user_data addr=" << std::hex << user_data | |||||
| << std::endl; | |||||
| } | |||||
| ASSERT_EQ(size, 1); | |||||
| ASSERT_EQ( | |||||
| std::string(outputs->name), | |||||
| "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"); | |||||
| LiteLayout layout; | |||||
| LITE_get_tensor_layout(*output_tensors, &layout); | |||||
| ASSERT_EQ(layout.shapes[1], 1000); | |||||
| }; | |||||
| check_func(); | |||||
| return 0; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| #define LITE_CAPI_CHECK(_expr) \ | #define LITE_CAPI_CHECK(_expr) \ | ||||
| @@ -671,6 +726,21 @@ TEST(TestCapiNetWork, StartCallBack) { | |||||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | ||||
| } | } | ||||
| TEST(TestCapiNetWork, StartCallBackWithData) { | |||||
| ForwardMgb; | |||||
| MakeNetwork; | |||||
| LoadNetwork; | |||||
| size_t user_data = 1; | |||||
| LITE_CAPI_CHECK(LITE_set_start_callback_with_userdata( | |||||
| c_network, start_callback_with_data, &user_data)); | |||||
| SetInput; | |||||
| ForwardNetwork; | |||||
| GetOutput; | |||||
| CompareResult; | |||||
| ASSERT_TRUE(start_checked_with_data); | |||||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | |||||
| } | |||||
| TEST(TestCapiNetWork, FinishCallBack) { | TEST(TestCapiNetWork, FinishCallBack) { | ||||
| ForwardMgb; | ForwardMgb; | ||||
| MakeNetwork; | MakeNetwork; | ||||
| @@ -684,6 +754,21 @@ TEST(TestCapiNetWork, FinishCallBack) { | |||||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | ||||
| } | } | ||||
| TEST(TestCapiNetWork, FinishCallBackWtihData) { | |||||
| ForwardMgb; | |||||
| MakeNetwork; | |||||
| LoadNetwork; | |||||
| size_t user_data = 1; | |||||
| LITE_CAPI_CHECK(LITE_set_finish_callback_with_userdata( | |||||
| c_network, finish_callback_with_data, &user_data)); | |||||
| SetInput; | |||||
| ForwardNetwork; | |||||
| GetOutput; | |||||
| CompareResult; | |||||
| ASSERT_TRUE(finish_checked_with_data); | |||||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | |||||
| } | |||||
| TEST(TestCapiNetWork, BasicCryptAes) { | TEST(TestCapiNetWork, BasicCryptAes) { | ||||
| ForwardMgb; | ForwardMgb; | ||||
| @@ -723,7 +808,7 @@ TEST(TestCapiNetWork, AsyncExec) { | |||||
| LiteConfig c_config = *default_config(); | LiteConfig c_config = *default_config(); | ||||
| c_config.options.var_sanity_check_first_run = false; | c_config.options.var_sanity_check_first_run = false; | ||||
| LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io())); | LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io())); | ||||
| LITE_CAPI_CHECK(LITE_set_async_callback(c_network, finish_callback)); | |||||
| LITE_CAPI_CHECK(LITE_set_async_callback(c_network, async_callback)); | |||||
| LoadNetwork; | LoadNetwork; | ||||
| SetInput; | SetInput; | ||||
| @@ -740,6 +825,32 @@ TEST(TestCapiNetWork, AsyncExec) { | |||||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | ||||
| } | } | ||||
| TEST(TestCapiNetWork, AsyncExecWithData) { | |||||
| finished = false; | |||||
| ForwardMgb; | |||||
| LiteNetwork c_network; | |||||
| LiteConfig c_config = *default_config(); | |||||
| c_config.options.var_sanity_check_first_run = false; | |||||
| LITE_CAPI_CHECK(LITE_make_network(&c_network, c_config, *default_network_io())); | |||||
| size_t user_data = 1; | |||||
| LITE_CAPI_CHECK(LITE_set_async_callback_with_userdata( | |||||
| c_network, async_callback_with_data, &user_data)); | |||||
| LoadNetwork; | |||||
| SetInput; | |||||
| LITE_forward(c_network); | |||||
| size_t count = 0; | |||||
| while (finished_with_data == false) { | |||||
| count++; | |||||
| } | |||||
| ASSERT_GT(count, 0); | |||||
| finished_with_data = false; | |||||
| GetOutput; | |||||
| CompareResult; | |||||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | |||||
| } | |||||
| TEST(TestCapiNetWork, OutputShapeOnly) { | TEST(TestCapiNetWork, OutputShapeOnly) { | ||||
| ForwardMgb; | ForwardMgb; | ||||
| LiteNetwork c_network; | LiteNetwork c_network; | ||||