GitOrigin-RevId: 52b8a29932
@@ -420,6 +420,7 @@ def warp_affine(
     Here all available options for params are listed,
     however it does not mean that you can use all the combinations.
     On different platforms, different combinations are supported.
+    ``warp_affine`` only supports forward inference. Please refer to ``warp_perspective`` if backward is needed.
     """
     conv_format = _config._get_actual_op_param(format, _config.__conv_format)
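
The new docstring line flags a real asymmetry: ``warp_affine`` ships only a forward kernel, while ``warp_perspective`` also has a backward implementation. A minimal sketch of routing through ``warp_perspective`` when the warp must participate in training, assuming the ``megengine.functional.vision`` API with NCHW input and a batched (N, 2, 3) affine matrix; shapes and values are illustrative, not taken from the patch:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.randn(1, 3, 32, 32).astype("float32"))  # NCHW image
    affine = mge.tensor([[[1.0, 0.0, 2.0],
                          [0.0, 1.0, 1.0]]])                           # (1, 2, 3) affine matrix

    # Inference only: warp_affine has no backward kernel.
    out = F.vision.warp_affine(inp, affine, (32, 32), format="NCHW")

    # Training: extend the matrix to (1, 3, 3) and use warp_perspective,
    # which does support backward.
    bottom = mge.tensor([[[0.0, 0.0, 1.0]]])
    out = F.vision.warp_perspective(inp, F.concat([affine, bottom], axis=1), (32, 32))
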
@@ -104,6 +104,7 @@ class TensorInfo:
         "shape",
         "is_const",
         "bound_data",
+        "bound_data_numpy",
         # resources for execution
         "varnode",
         "data_setter",

@@ -119,12 +120,18 @@ class TensorInfo:
         self.shape_read = None
         self.value_read = None
         self.bound_data = None
+        self.bound_data_numpy = None
         self.data_setter = None
         self.shape_reader = None
         self.value_reader = None
         self.data_reader = None
 
+    def get_numpy(self):
+        if self.bound_data_numpy is None:
+            self.bound_data_numpy = self.bound_data.numpy()
+        return self.bound_data_numpy
+
 
 _io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
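
The new ``get_numpy`` helper lazily caches the host copy of a bound constant, so the ``graph.make_const(...)`` call sites below reuse one device-to-host transfer instead of calling ``bound_data.numpy()`` again at every site. A standalone sketch of the same memoization pattern, using hypothetical names rather than the real ``TensorInfo`` fields:

    class CachedHostValue:
        """Memoize an expensive device-to-host conversion (illustrative only)."""

        def __init__(self, device_tensor):
            self._tensor = device_tensor  # any object exposing .numpy()
            self._cached = None           # filled on first request

        def get_numpy(self):
            # Download once; later calls return the same ndarray.
            if self._cached is None:
                self._cached = self._tensor.numpy()
            return self._cached

The trade-off is extra host memory held as long as the owning object lives, which is presumably acceptable here because the bound data stays constant for the lifetime of the trace.
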
@@ -292,7 +299,7 @@ class trace:
             # Const op is represented by a str
             assert isinstance(op_, str) and op_ == "Const"
-            expected = self._tinfo[ohandles[0]].bound_data.numpy()
+            expected = self._tinfo[ohandles[0]].get_numpy()
             shape = value.shape
             if shape != expected.shape or dtype != expected.dtype:
                 eq = False

@@ -369,6 +376,7 @@ class trace:
         info.dtype = x.dtype
         info.shape = x.shape
         info.bound_data = x
+        info.bound_data_numpy = None
         info.is_const = True
         x._mixin_handle = h
         x._recording = True
@@ -612,9 +620,7 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 info.varnode = graph.make_const(
-                    info.bound_data.numpy(),
-                    info.bound_data.dtype,
-                    info.bound_data.device,
+                    info.get_numpy(), info.bound_data.dtype, info.bound_data.device,
                 )
                 continue

@@ -627,7 +633,7 @@ class trace:
             if info.bound_data:
                 if getattr(info, "is_const", False):
                     info.varnode = graph.make_const(
-                        info.bound_data.numpy(),
+                        info.get_numpy(),
                         info.bound_data.dtype,
                         info.bound_data.device,
                     )
@@ -1174,7 +1180,7 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 h2v[h] = graph.make_const(
-                    info.bound_data.numpy(),
+                    info.get_numpy(),
                     dtype=info.dtype,
                     device=dumped_device(info),
                     name=info.name,

@@ -1187,7 +1193,7 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 h2v[h] = graph.make_const(
-                    info.bound_data.numpy(),
+                    info.get_numpy(),
                     dtype=info.dtype,
                     device=dumped_device(info),
                     name=info.name,
@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
             []() {
                 interpreter_for_py->sync();
                 CompNode::sync_all();
+                CompNode::foreach ([](CompNode cn) {
+                    auto err = cn.check_async_error();
+                    mgb_assert(!err, "%s", err->what());
+                });
                 sync_py_task_q();
             },
             py::call_guard<py::gil_scoped_release>());
@@ -96,6 +96,15 @@ def test_regression_2870():
     (x + x).numpy()
 
 
+@pytest.mark.require_ngpu(1)
+def test_async_error_check():
+    src = mge.tensor([[1.0, 2.0]])
+    index = mge.tensor([3])
+    val = F.indexing_one_hot(src, index)
+    with pytest.raises(RuntimeError):
+        val.numpy()
+
+
 # NOTE: DO NOT REMOVE THIS TEST
 # This is also a compatibility test for
 # mge.core.set_option('async_level', 0).
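
The new test exercises the error path added in the interpreter changes below: a kernel that fails asynchronously (here an out-of-range ``indexing_one_hot``) now surfaces as a ``RuntimeError`` at the first synchronization point instead of being swallowed. A user-level sketch of the same situation, assuming at least one CUDA device as implied by the test's ``require_ngpu(1)`` marker:

    import megengine as mge
    import megengine.functional as F

    src = mge.tensor([[1.0, 2.0]])
    index = mge.tensor([3])               # deliberately out of range
    val = F.indexing_one_hot(src, index)  # kernel is queued asynchronously

    try:
        print(val.numpy())                # host-value wait: async error is checked here
    except RuntimeError as exc:
        print("async kernel failed:", exc)

With ``mge.core.set_option('async_level', 0)``, the compatibility mode mentioned above, the ``put_impl``/``dispatch_kernel`` changes below run the same check right after each op, so the error would be raised at the ``indexing_one_hot`` call itself.
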
@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     if (m_async_level == 0) {
         sync_impl();
         info->desc.comp_node.sync();
+        auto err = info->desc.comp_node.check_async_error();
+        mgb_assert(!err, "%s", err->what());
     }
     return info;
 }

@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
         for (auto&& oup : *outputs) {
             auto info = reinterpret_cast<TensorInfo*>(oup);
             info->ptr->comp_node().sync();
+            auto err = info->ptr->comp_node().check_async_error();
+            mgb_assert(!err, "%s", err->what());
         }
     }
 }
@@ -931,7 +935,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
     bool require_host = prop == TensorProp::HostValue;
     auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
-    if (require_host && !host_available()) {
+    bool wait_host = !host_available();
+    if (require_host && wait_host) {
         // avoid dead lock
         lock.unlock();
         m_buffer.enqueue(GetValue{info});

@@ -944,6 +949,10 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     });
     MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
     m_waitee = nullptr;
+    if (require_host && wait_host) {
+        auto err = info->ptr->comp_node().check_async_error();
+        mgb_assert(!err, "%s", err->what());
+    }
     return info->ptr;
 }