GitOrigin-RevId: 6abfa06ada
tags/v1.5.0
| @@ -246,4 +246,11 @@ tensor = Tensor | |||||
| class Parameter(Tensor): | class Parameter(Tensor): | ||||
| r""" | r""" | ||||
| A kind of Tensor that is to be considered a module parameter. | A kind of Tensor that is to be considered a module parameter. | ||||
| .. note:: | |||||
| Operations happened on Parameter usually return a Tensor instead of Parameter. | |||||
| For example, with a Parameter ``x``, ``x.reshape/to/sum/...`` will result into a Tensor. | |||||
| Any operations between Parameter and Tensor will have Tensor as outputs. | |||||
| """ | """ | ||||
| @@ -397,6 +397,10 @@ public: | |||||
| return Py_TYPE(op) == &m_type; | return Py_TYPE(op) == &m_type; | ||||
| } | } | ||||
| bool same_pytype(PyTypeObject *pt) { | |||||
| return pt == &m_type; | |||||
| } | |||||
| PyObject* finalize() { | PyObject* finalize() { | ||||
| if (!m_finalized) { | if (!m_finalized) { | ||||
| m_finalized = true; | m_finalized = true; | ||||
| @@ -140,6 +140,12 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
| auto* op = args[0]; | auto* op = args[0]; | ||||
| PyTypeObject* pytype = args[1]->ob_type; | PyTypeObject* pytype = args[1]->ob_type; | ||||
| // check if pytype is Parameter(and all other python Tensor's derived class), | |||||
| // if yes, using it's tp_base(python Tensor) | |||||
| if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) { | |||||
| pytype = pytype->tp_base; | |||||
| } | |||||
| ++args; | ++args; | ||||
| --nargs; | --nargs; | ||||
| @@ -13,7 +13,7 @@ import pytest | |||||
| from utils import make_tensor | from utils import make_tensor | ||||
| from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | ||||
| from megengine.tensor import Tensor | |||||
| from megengine.tensor import Parameter, Tensor | |||||
| from megengine.utils.network import Network | from megengine.utils.network import Network | ||||
| @@ -198,3 +198,11 @@ def test_name(): | |||||
| assert x.name == "x" | assert x.name == "x" | ||||
| x = Tensor(0, name="x") | x = Tensor(0, name="x") | ||||
| assert x.name == "x" | assert x.name == "x" | ||||
| def test_tensor_type(): | |||||
| x1 = Parameter(1) | |||||
| x2 = Tensor(2) | |||||
| y1 = x1 + x2 | |||||
| y2 = x2 + x1 | |||||
| assert type(y1) == type(y2) | |||||