|
|
|
@@ -316,7 +316,11 @@ PyObject* TensorWrapper::shape() { |
|
|
|
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { |
|
|
|
return PyTuple_New(0); |
|
|
|
} |
|
|
|
return PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape"); |
|
|
|
PyObject *shp = PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape"); |
|
|
|
if (shp == Py_None) { |
|
|
|
throw TraceReadError("shape of this tensor is not read in trace"); |
|
|
|
} |
|
|
|
return shp; |
|
|
|
} |
|
|
|
if (m_tensor->m_trace_info.recording && !skip_tracing) { |
|
|
|
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr()); |
|
|
|
@@ -367,6 +371,9 @@ PyObject* TensorWrapper::device() { |
|
|
|
PyObject* TensorWrapper::numpy() { |
|
|
|
if (m_tensor->m_trace_info.compiled_info != nullptr) { |
|
|
|
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr); |
|
|
|
if (np_val == Py_None) { |
|
|
|
throw TraceReadError("value of this tensor is not read in trace"); |
|
|
|
} |
|
|
|
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { |
|
|
|
np_val = PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val)); |
|
|
|
} |
|
|
|
@@ -445,9 +452,14 @@ PyObject* TensorWrapper::detach() { |
|
|
|
PyObject* TensorWrapper::_dev_tensor(){ |
|
|
|
if (m_tensor->m_trace_info.compiled_info != nullptr) { |
|
|
|
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr); |
|
|
|
if (dev_tensor == Py_None) { |
|
|
|
throw TraceReadError("raw data of this tensor is not read in trace"); |
|
|
|
} |
|
|
|
auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor); |
|
|
|
auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>()); |
|
|
|
m_tensor->m_handle = std::move(SharedHandle(sh)); |
|
|
|
|
|
|
|
return dev_tensor; |
|
|
|
} |
|
|
|
if (m_tensor->m_trace_info.recording && !skip_tracing) { |
|
|
|
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr()); |
|
|
|
|