Browse Source

refactor(mge): trace exception in compiled info

GitOrigin-RevId: 508f5463b9
tags/v1.2.0
Megvii Engine Team 5 years ago
parent
commit
d4ada69d3b
4 changed files with 30 additions and 6 deletions
  1. +7
    -4
      imperative/python/megengine/jit/tracing.py
  2. +13
    -1
      imperative/python/src/tensor.cpp
  3. +10
    -0
      imperative/python/src/trace.h
  4. +0
    -1
      imperative/python/test/unit/test_tracing.py

+ 7
- 4
imperative/python/megengine/jit/tracing.py View File

@@ -414,7 +414,7 @@ class trace:
for x in escaped_tensors: for x in escaped_tensors:
try: try:
assign_raw_tensor(x(), RawTensor(x()._dev_tensor())) assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
except TraceMismatchError:
except RuntimeError:
# TraceMismatchError thrown in do_exit # TraceMismatchError thrown in do_exit
pass pass
self._graph.wait() self._graph.wait()
@@ -954,7 +954,8 @@ class CompiledTensorProxy:
elif self.__info.data_read: elif self.__info.data_read:
self.__shape = self._dev_tensor().shape self.__shape = self._dev_tensor().shape
else: else:
raise TraceMismatchError("shape of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
return self.__shape return self.__shape


def numpy(self): def numpy(self):
@@ -964,7 +965,8 @@ class CompiledTensorProxy:
elif self.__info.data_read: elif self.__info.data_read:
self.__value = self._dev_tensor().numpy() self.__value = self._dev_tensor().numpy()
else: else:
raise TraceMismatchError("value of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
if self._isscalar: if self._isscalar:
self.__value = self.__value.squeeze() self.__value = self.__value.squeeze()
return self.__value return self.__value
@@ -972,7 +974,8 @@ class CompiledTensorProxy:
def _dev_tensor(self): def _dev_tensor(self):
if self.__data is None: if self.__data is None:
if not self.__info.data_read: if not self.__info.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
self.__data = self.__info.data_reader.get_value() self.__data = self.__info.data_reader.get_value()
return self.__data return self.__data




+ 13
- 1
imperative/python/src/tensor.cpp View File

@@ -316,7 +316,11 @@ PyObject* TensorWrapper::shape() {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0); return PyTuple_New(0);
} }
return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
PyObject *shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
if (shp == Py_None) {
throw TraceReadError("shape of this tensor is not read in trace");
}
return shp;
} }
if (m_tensor->m_trace_info.recording && !skip_tracing) { if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr()); PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
@@ -367,6 +371,9 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() { PyObject* TensorWrapper::numpy() {
if (m_tensor->m_trace_info.compiled_info != nullptr) { if (m_tensor->m_trace_info.compiled_info != nullptr) {
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr); PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
if (np_val == Py_None) {
throw TraceReadError("value of this tensor is not read in trace");
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val)); np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
} }
@@ -445,9 +452,14 @@ PyObject* TensorWrapper::detach() {
PyObject* TensorWrapper::_dev_tensor(){ PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.compiled_info != nullptr) { if (m_tensor->m_trace_info.compiled_info != nullptr) {
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr); auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
if (dev_tensor == Py_None) {
throw TraceReadError("raw data of this tensor is not read in trace");
}
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh)); m_tensor->m_handle = std::move(SharedHandle(sh));

return dev_tensor;
} }
if (m_tensor->m_trace_info.recording && !skip_tracing) { if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr()); PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());


+ 10
- 0
imperative/python/src/trace.h View File

@@ -10,9 +10,19 @@
*/ */


#include "./tensor.h" #include "./tensor.h"
#include <stdexcept>


namespace mgb::imperative::python { namespace mgb::imperative::python {


class TraceReadError : public std::exception {
public:
explicit TraceReadError(const char * m) : message{m} {}
const char * what() const noexcept override {return message.c_str();}
private:
std::string message = "";
};


apply_result_t apply_trace(ApplyContext& ctx); apply_result_t apply_trace(ApplyContext& ctx);


} // namespace mgb::imperative::python } // namespace mgb::imperative::python

+ 0
- 1
imperative/python/test/unit/test_tracing.py View File

@@ -311,7 +311,6 @@ def test_trace_warp_perspective():
f(x, M) f(x, M)




@pytest.mark.skip(reason="skip")
def test_raise_on_trace(): def test_raise_on_trace():
step_count = 0 step_count = 0
catch_count = 0 catch_count = 0


Loading…
Cancel
Save