Browse Source

refactor(mge): trace exception in compiled info

GitOrigin-RevId: 508f5463b9
tags/v1.2.0
Megvii Engine Team 5 years ago
parent
commit
d4ada69d3b
4 changed files with 30 additions and 6 deletions
  1. +7
    -4
      imperative/python/megengine/jit/tracing.py
  2. +13
    -1
      imperative/python/src/tensor.cpp
  3. +10
    -0
      imperative/python/src/trace.h
  4. +0
    -1
      imperative/python/test/unit/test_tracing.py

+ 7
- 4
imperative/python/megengine/jit/tracing.py View File

@@ -414,7 +414,7 @@ class trace:
for x in escaped_tensors:
try:
assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
except TraceMismatchError:
except RuntimeError:
# TraceMismatchError thrown in do_exit
pass
self._graph.wait()
@@ -954,7 +954,8 @@ class CompiledTensorProxy:
elif self.__info.data_read:
self.__shape = self._dev_tensor().shape
else:
raise TraceMismatchError("shape of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
return self.__shape

def numpy(self):
@@ -964,7 +965,8 @@ class CompiledTensorProxy:
elif self.__info.data_read:
self.__value = self._dev_tensor().numpy()
else:
raise TraceMismatchError("value of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
if self._isscalar:
self.__value = self.__value.squeeze()
return self.__value
@@ -972,7 +974,8 @@ class CompiledTensorProxy:
def _dev_tensor(self):
if self.__data is None:
if not self.__info.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace")
# c++ will throw TraceReadError
return None
self.__data = self.__info.data_reader.get_value()
return self.__data



+ 13
- 1
imperative/python/src/tensor.cpp View File

@@ -316,7 +316,11 @@ PyObject* TensorWrapper::shape() {
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
return PyTuple_New(0);
}
return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
PyObject *shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
if (shp == Py_None) {
throw TraceReadError("shape of this tensor is not read in trace");
}
return shp;
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
@@ -367,6 +371,9 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() {
if (m_tensor->m_trace_info.compiled_info != nullptr) {
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
if (np_val == Py_None) {
throw TraceReadError("value of this tensor is not read in trace");
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
}
@@ -445,9 +452,14 @@ PyObject* TensorWrapper::detach() {
PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.compiled_info != nullptr) {
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
if (dev_tensor == Py_None) {
throw TraceReadError("raw data of this tensor is not read in trace");
}
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
m_tensor->m_handle = std::move(SharedHandle(sh));

return dev_tensor;
}
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());


+ 10
- 0
imperative/python/src/trace.h View File

@@ -10,9 +10,19 @@
*/

#include "./tensor.h"
#include <stdexcept>

namespace mgb::imperative::python {

class TraceReadError : public std::exception {
public:
explicit TraceReadError(const char * m) : message{m} {}
const char * what() const noexcept override {return message.c_str();}
private:
std::string message = "";
};


apply_result_t apply_trace(ApplyContext& ctx);

} // namespace mgb::imperative::python

+ 0
- 1
imperative/python/test/unit/test_tracing.py View File

@@ -311,7 +311,6 @@ def test_trace_warp_perspective():
f(x, M)


@pytest.mark.skip(reason="skip")
def test_raise_on_trace():
step_count = 0
catch_count = 0


Loading…
Cancel
Save