| @@ -242,7 +242,13 @@ int LITE_destroy_network(LiteNetwork network) { | |||
| LITE_CAPI_BEGIN(); | |||
| LITE_ASSERT(network, "The network pass to LITE api is null"); | |||
| LITE_LOCK_GUARD(mtx_network); | |||
| get_gloabl_network_holder().erase(network); | |||
| auto& global_holder = get_gloabl_network_holder(); | |||
| if (global_holder.find(network) != global_holder.end()) { | |||
| global_holder.erase(network); | |||
| } else { | |||
| //! means the network has been destoryed | |||
| return -1; | |||
| } | |||
| LITE_CAPI_END(); | |||
| } | |||
| @@ -60,8 +60,10 @@ int LITE_make_tensor(const LiteTensorDesc tensor_describe, LiteTensor* tensor) { | |||
| auto lite_tensor = std::make_shared<lite::Tensor>( | |||
| tensor_describe.device_id, tensor_describe.device_type, layout, | |||
| tensor_describe.is_pinned_host); | |||
| LITE_LOCK_GUARD(mtx_tensor); | |||
| get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; | |||
| { | |||
| LITE_LOCK_GUARD(mtx_tensor); | |||
| get_global_tensor_holder()[lite_tensor.get()] = lite_tensor; | |||
| } | |||
| *tensor = lite_tensor.get(); | |||
| LITE_CAPI_END(); | |||
| } | |||
| @@ -70,7 +72,13 @@ int LITE_destroy_tensor(LiteTensor tensor) { | |||
| LITE_CAPI_BEGIN(); | |||
| LITE_ASSERT(tensor, "The tensor pass to LITE c_api is null"); | |||
| LITE_LOCK_GUARD(mtx_tensor); | |||
| get_global_tensor_holder().erase(tensor); | |||
| auto& global_holder = get_global_tensor_holder(); | |||
| if (global_holder.find(tensor) != global_holder.end()) { | |||
| global_holder.erase(tensor); | |||
| } else { | |||
| //! return -1, means the tensor has been destroyed. | |||
| return -1; | |||
| } | |||
| LITE_CAPI_END(); | |||
| } | |||
| @@ -126,8 +134,10 @@ int LITE_tensor_slice( | |||
| } | |||
| } | |||
| auto ret_tensor = static_cast<lite::Tensor*>(tensor)->slice(starts, ends, steps); | |||
| LITE_LOCK_GUARD(mtx_tensor); | |||
| get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; | |||
| { | |||
| LITE_LOCK_GUARD(mtx_tensor); | |||
| get_global_tensor_holder()[ret_tensor.get()] = ret_tensor; | |||
| } | |||
| *slice_tensor = ret_tensor.get(); | |||
| LITE_CAPI_END(); | |||
| } | |||
| @@ -226,12 +236,16 @@ int LITE_tensor_concat( | |||
| LiteTensor* tensors, int nr_tensor, int dim, LiteDeviceType dst_device, | |||
| int device_id, LiteTensor* result_tensor) { | |||
| LITE_CAPI_BEGIN(); | |||
| LITE_ASSERT(result_tensor, "The tensor pass to LITE c_api is null"); | |||
| std::vector<lite::Tensor> v_tensors; | |||
| for (int i = 0; i < nr_tensor; i++) { | |||
| v_tensors.push_back(*static_cast<lite::Tensor*>(tensors[i])); | |||
| } | |||
| auto tensor = lite::TensorUtils::concat(v_tensors, dim, dst_device, device_id); | |||
| get_global_tensor_holder()[tensor.get()] = tensor; | |||
| { | |||
| LITE_LOCK_GUARD(mtx_tensor); | |||
| get_global_tensor_holder()[tensor.get()] = tensor; | |||
| } | |||
| *result_tensor = tensor.get(); | |||
| LITE_CAPI_END() | |||
| } | |||
| @@ -476,7 +476,7 @@ def start_finish_callback(func): | |||
| def wrapper(c_ios, c_tensors, size): | |||
| ios = {} | |||
| for i in range(size): | |||
| tensor = LiteTensor() | |||
| tensor = LiteTensor(physic_construct=False) | |||
| tensor._tensor = c_void_p(c_tensors[i]) | |||
| tensor.update() | |||
| io = c_ios[i] | |||
| @@ -729,7 +729,7 @@ class LiteNetwork(object): | |||
| c_name = c_char_p(name.encode("utf-8")) | |||
| else: | |||
| c_name = c_char_p(name) | |||
| tensor = LiteTensor() | |||
| tensor = LiteTensor(physic_construct=False) | |||
| self._api.LITE_get_io_tensor( | |||
| self._network, c_name, phase, byref(tensor._tensor) | |||
| ) | |||
| @@ -233,6 +233,7 @@ class LiteTensor(object): | |||
| is_pinned_host=False, | |||
| shapes=None, | |||
| dtype=None, | |||
| physic_construct=True, | |||
| ): | |||
| self._tensor = _Ctensor() | |||
| self._layout = LiteLayout() | |||
| @@ -250,8 +251,10 @@ class LiteTensor(object): | |||
| tensor_desc.device_type = device_type | |||
| tensor_desc.device_id = device_id | |||
| tensor_desc.is_pinned_host = is_pinned_host | |||
| self._api.LITE_make_tensor(tensor_desc, byref(self._tensor)) | |||
| self.update() | |||
| if physic_construct: | |||
| self._api.LITE_make_tensor(tensor_desc, byref(self._tensor)) | |||
| self.update() | |||
| def __del__(self): | |||
| self._api.LITE_destroy_tensor(self._tensor) | |||
| @@ -399,7 +402,7 @@ class LiteTensor(object): | |||
| c_start = (c_size_t * length)(*start) | |||
| c_end = (c_size_t * length)(*end) | |||
| c_step = (c_size_t * length)(*step) | |||
| slice_tensor = LiteTensor() | |||
| slice_tensor = LiteTensor(physic_construct=False) | |||
| self._api.LITE_tensor_slice( | |||
| self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor), | |||
| ) | |||
| @@ -560,7 +563,7 @@ def LiteTensorConcat( | |||
| length = len(tensors) | |||
| c_tensors = [t._tensor for t in tensors] | |||
| c_tensors = (_Ctensor * length)(*c_tensors) | |||
| result_tensor = LiteTensor() | |||
| result_tensor = LiteTensor(physic_construct=False) | |||
| api.LITE_tensor_concat( | |||
| cast(byref(c_tensors), POINTER(c_void_p)), | |||
| length, | |||
| @@ -1022,6 +1022,20 @@ TEST(TestCapiNetWork, TestShareWeights) { | |||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network2)); | |||
| } | |||
| TEST(TestCapiNetWork, GlobalHolder) { | |||
| std::string model_path = "./shufflenet.mge"; | |||
| LiteNetwork c_network; | |||
| LITE_CAPI_CHECK( | |||
| LITE_make_network(&c_network, *default_config(), *default_network_io())); | |||
| auto destroy_network = c_network; | |||
| LITE_CAPI_CHECK( | |||
| LITE_make_network(&c_network, *default_config(), *default_network_io())); | |||
| //! make sure destroy_network is destroyed by LITE_make_network | |||
| LITE_destroy_network(destroy_network); | |||
| ASSERT_EQ(LITE_destroy_network(destroy_network), -1); | |||
| LITE_CAPI_CHECK(LITE_destroy_network(c_network)); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -251,6 +251,7 @@ TEST(TestCapiTensor, Slice) { | |||
| } | |||
| } | |||
| LITE_destroy_tensor(tensor); | |||
| LITE_destroy_tensor(slice_tensor); | |||
| }; | |||
| check(1, 8, 1, true); | |||
| check(1, 8, 1, false); | |||
| @@ -316,6 +317,21 @@ TEST(TestCapiTensor, ThreadLocalError) { | |||
| thread2.join(); | |||
| } | |||
| TEST(TestCapiTensor, GlobalHolder) { | |||
| LiteTensor c_tensor0; | |||
| LiteTensorDesc description = default_desc; | |||
| description.layout = LiteLayout{{20, 20}, 2, LiteDataType::LITE_FLOAT}; | |||
| LITE_make_tensor(description, &c_tensor0); | |||
| auto destroy_tensor = c_tensor0; | |||
| LITE_make_tensor(description, &c_tensor0); | |||
| //! make sure destroy_tensor is destroyed by LITE_make_tensor | |||
| LITE_destroy_tensor(destroy_tensor); | |||
| ASSERT_EQ(LITE_destroy_tensor(destroy_tensor), -1); | |||
| LITE_destroy_tensor(c_tensor0); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||