GitOrigin-RevId: 70ddc06eee
@@ -86,16 +86,22 @@ def _broadcast(inp, shape):
 def _reshape(x, shape):
-    shape_tuple = _make_shape_tuple(shape)
     unspec_axis = None
-    # XXX: assume unspec_axis is not changed in trace
-    for i, s in enumerate(shape_tuple):
-        if s < 0:
-            if s != -1:
-                raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
-            if unspec_axis is not None:
-                raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
-            unspec_axis = i
+    try:
+        shape_tuple = _make_shape_tuple(shape)
+    except ValueError:
+        pass
+    else:
+        # XXX: assume unspec_axis is not changed in trace
+        for i, s in enumerate(shape_tuple):
+            if s < 0:
+                if s != -1:
+                    raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
+                if unspec_axis is not None:
+                    raise ValueError(
+                        "multiple -1 in shape: {} & {}".format(unspec_axis, i)
+                    )
+                unspec_axis = i
     shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
     if unspec_axis is None:
         op = builtin.Reshape()
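Note on the hunk above: the `-1` (unspecified-axis) validation itself is unchanged; it is merely skipped when `_make_shape_tuple` raises `ValueError` because the target shape is symbolic under trace. A minimal standalone sketch of the rule, with `find_unspec_axis` as a hypothetical name rather than a MegEngine API:

```python
# Hypothetical sketch of the rule _reshape enforces: at most one -1 in the
# target shape, and no other negative entries.
def find_unspec_axis(shape_tuple):
    unspec_axis = None
    for i, s in enumerate(shape_tuple):
        if s < 0:
            if s != -1:
                raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
            if unspec_axis is not None:
                raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
            unspec_axis = i
    return unspec_axis

assert find_unspec_axis((2, -1, 4)) == 1  # axis 1 gets inferred at runtime
assert find_unspec_axis((2, 3)) is None   # fully specified shape
```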
@@ -18,9 +18,9 @@ from .utils import astensor1d, isscalar, make_shape_tuple
 def remove_ellipsis(tensor, tuple_val):
-    ndim_sum = tensor.ndim
     cur_sum = 0
     pos = -1
+    has_unknown_ndim_bool_index = False
     for i_idx, i in enumerate(tuple_val):
         if i is Ellipsis:
             for j in tuple_val[:i_idx:-1]:
@@ -28,10 +28,28 @@ def remove_ellipsis(tensor, tuple_val):
                 raise IndexError("only one ellipsis is allowed")
             pos = i_idx
         else:
-            cur_sum += i.ndim if hasattr(i, "ndim") else 1
+            try:
+                cur_sum += (
+                    i.ndim
+                    if hasattr(i, "dtype")
+                    and i.dtype == np.bool_
+                    and hasattr(i, "ndim")
+                    else 1
+                )
+            except ValueError:
+                has_unknown_ndim_bool_index = True
     if pos == -1:
         return tuple_val
     else:
+        if has_unknown_ndim_bool_index:
+            raise IndexError(
+                "Does not support bool index with unknown shape when using Ellipsis"
+            )
+        try:
+            ndim_sum = tensor.ndim
+        except ValueError:
+            raise IndexError("Does not support Ellipsis when tensor's ndim is unknown.")
         return (
             tuple_val[:pos]
             + (slice(None, None, None),) * (ndim_sum - cur_sum)
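The axis accounting above mirrors NumPy semantics: a boolean index consumes as many axes as its `ndim`, every other index consumes one, and `Ellipsis` expands to full slices over whatever remains. A hedged NumPy illustration:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
mask = np.ones((2, 3), dtype=np.bool_)  # a 2-D bool index spans two axes

# The mask consumes 2 of the 3 axes, so "..." expands to exactly one
# full slice; both expressions select the same elements.
assert (x[mask, ...] == x[mask, :]).all()
```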
@@ -41,7 +59,11 @@ def remove_ellipsis(tensor, tuple_val):
 # XXX: assume same results during trace
 def check_bool_index(tensor, tuple_val):
-    cur_shape = make_shape_tuple(tensor.shape)
+    try:
+        cur_shape = make_shape_tuple(tensor.shape)
+    except ValueError:
+        return tensor, tuple_val
     new_tuple_val = []
     offset = 0
     tdim = 0
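`check_bool_index` rewrites a boolean mask into coordinate indices, which is only possible when the shapes involved are concrete; hence the early return when `make_shape_tuple` raises on a symbolic shape. The equivalence it relies on, shown with NumPy:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
mask = np.array([True, False])

# A bool index is equivalent to indexing with the mask's nonzero
# coordinates -- a rewrite that needs the mask's shape to be known.
assert (x[mask] == x[np.nonzero(mask)]).all()
```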
@@ -92,20 +114,31 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
     ndim_indexed_scalar = 0
     for i in tuple_val:
         if not i is Ellipsis:
-            ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
+            ndim_indexed += (
+                i.ndim
+                if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim")
+                else 1
+            )
             if isscalar(i):
                 ndim_indexed_scalar += 1
-    if ndim_indexed > inp.ndim:
-        raise IndexError(
-            "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
-                inp.ndim, ndim_indexed
-            )
-        )
+    ret_scalar = False
+    try:
+        ret_scalar = ndim_indexed_scalar == inp.ndim
+    except ValueError:
+        # inp.ndim is unknown
+        pass
+    else:
+        if ndim_indexed > inp.ndim:
+            raise IndexError(
+                "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
+                    inp.ndim, len(tuple_val)
+                )
+            )
     tuple_val = remove_ellipsis(inp, tuple_val)
     use_subtensor = True
-    inp, tuple_val = check_bool_index(inp, tuple_val)
+    if inp.shape is not None:
+        inp, tuple_val = check_bool_index(inp, tuple_val)
     new_axes = []
     tensors = []
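With a statically known `ndim`, the over-indexing check fires eagerly, matching NumPy's behaviour; with an unknown `ndim` it is deferred to graph execution. For reference:

```python
import numpy as np

x = np.zeros((2, 3))
try:
    x[0, 1, 2]  # three indices into a 2-D array
except IndexError as e:
    print(e)  # too many indices for array ...
```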
@@ -186,7 +219,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
             items.append(item)
     if new_axes:
         raise IndexError("newaxis is not allowed here")
-    return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim
+    return inp, tensors, items, use_subtensor, ret_scalar
 
 
 def try_condtake(tensor, index):
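`ret_scalar` replaces the inline `ndim_indexed_scalar == inp.ndim` comparison so the scalar-result decision survives the unknown-`ndim` path (defaulting to a non-scalar result). The underlying rule, illustrated with NumPy:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
assert x[0, 1].ndim == 0  # every axis hit by a scalar index -> 0-d result
assert x[0].ndim == 1     # fewer scalar indices than ndim -> still a tensor
```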
@@ -249,16 +282,21 @@ def setitem(tensor, index, value):
         op = builtin.IndexingMultiAxisVec(items=items)
         (tmp_result,) = apply(op, tensor, *tensors)
-        for i in range(min(len(value.shape), len(tmp_result.shape))):
-            if (value.shape[-i - 1] != 1) & (
-                value.shape[-i - 1] != tmp_result.shape[-i - 1]
-            ):
-                raise ValueError(
-                    "cannot copy tensor with shape {} to subtensor with shape {}".format(
-                        value.shape, tmp_result.shape
-                    )
-                )
+        try:
+            value_shape = value._tuple_shape
+            tmp_result_shape = tmp_result._tuple_shape
+        except ValueError:
+            pass
+        else:
+            for i in range(min(len(value_shape), len(tmp_result_shape))):
+                if (value_shape[-i - 1] != 1) & (
+                    value_shape[-i - 1] != tmp_result_shape[-i - 1]
+                ):
+                    raise ValueError(
+                        "cannot copy tensor with shape {} to subtensor with shape {}".format(
+                            value_shape, tmp_result_shape
+                        )
+                    )
         value = value._broadcast(tmp_result.shape)
     if use_subtensor:
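The compatibility check walks trailing dimensions and requires each pair to match or the value's dimension to be 1, i.e. ordinary broadcasting onto the selected subtensor; it is now skipped when either `_tuple_shape` is unavailable at trace time. The same rule in NumPy terms:

```python
import numpy as np

dst = np.zeros((2, 3, 4))
dst[:] = np.ones((3, 1))      # ok: trailing dims (3, 1) broadcast to (3, 4)
try:
    dst[:] = np.ones((3, 2))  # trailing dim 2 is neither 4 nor 1
except ValueError as e:
    print(e)
```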
@@ -137,6 +137,13 @@ def astensor1d(x, *reference, dtype=None, device=None):
         ndim = x.ndim
     except AttributeError:
         pass
+    except ValueError:
+        if dtype is not None and dtype != x.dtype:
+            x = astype(x, dtype)
+        if device is not None:
+            cn = as_device(device).to_c()
+            (x,) = apply(builtin.Copy(comp_node=cn), x)
+        return x
     else:
         if ndim != 0 and ndim != 1:
             raise ValueError("ndim != 1 or 0, get : %d" % ndim)
@@ -148,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
         raise TypeError
     if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
-        x = concatenate(x, device=device)
+        x = concatenate(x, device=device) if len(x) > 1 else x[0]
         if dtype is not None:
             x = astype(x, dtype)
         return x
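Joining a single element is an identity, so the new fast path unwraps it rather than emitting a concat into the traced graph; NumPy shows the equivalence:

```python
import numpy as np

a = np.array([1, 2, 3])
assert (np.concatenate([a]) == a).all()  # single-input concat is a no-op
```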
@@ -849,8 +849,15 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
         return list(map(int, axis))
     axis = get_axes()
-    ndim = inp.ndim + len(axis)
-    axis = sorted(i + ndim if i < 0 else i for i in axis)
+    try:
+        ndim = inp.ndim + len(axis)
+        axis = sorted(i + ndim if i < 0 else i for i in axis)
+    except ValueError:
+        if any([ind < 0 for ind in axis]):
+            raise IndexError(
+                "Does not support negative index when tensor's ndim is unknown"
+            )
+        axis = sorted(axis)
     assert axis, "axis could not be empty"
     if inp._isscalar():
         assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
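A negative axis can only be normalized against a known rank (`i + ndim`), which is why the unknown-`ndim` branch rejects negative entries outright. NumPy illustration of the normalization:

```python
import numpy as np

x = np.zeros((2, 3))
# axis=-1 on a 2-D input means axis 2 of the 3-D result -- resolving it
# requires knowing ndim, hence the IndexError above when ndim is unknown.
assert np.expand_dims(x, -1).shape == (2, 3, 1)
assert np.expand_dims(x, 2).shape == (2, 3, 1)
```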
@@ -384,6 +384,11 @@ PyObject* TensorWrapper::shape() {
     TensorShape shape;
     if (m_tensor->m_var) {  // get shape from m_var
         auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
+        auto&& type = mgr.get_infer_type(m_tensor->m_var);
+        using InferType = cg::static_infer::InferType;
+        if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
+            Py_RETURN_NONE;
+        }
         auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
         if (!tshp) {
             Py_RETURN_NONE;
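This makes `Tensor.shape` return `None` whenever static inference cannot produce a CONST or RT_STATIC shape, which is the contract the Python-side guards above (e.g. `if inp.shape is not None:`) rely on. A hedged sketch of a caller honoring that contract; `safe_ndim` and `FakeSymbolic` are illustrative names only:

```python
# Illustrative only: shape may now be None for symbolic tensors whose
# shape is not statically inferable under trace.
def safe_ndim(t):
    shape = t.shape
    return None if shape is None else len(shape)

class FakeSymbolic:
    shape = None  # mimics a non-inferable symbolic tensor

assert safe_ndim(FakeSymbolic()) is None
```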
@@ -878,6 +883,24 @@ void init_tensor(py::module m) {
                              ->static_infer_manager();
                      return mgr.infer_shape_fallible(v->m_node);
                  })
+            .def("numpy",
+                 [](PySymbolVar* v) {
+                     auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
+                     auto&& type = mgr.get_infer_type(v->m_node);
+                     using InferType = cg::static_infer::InferType;
+                     if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
+                         throw py::value_error("value invalid!");
+                     }
+                     auto* val = mgr.infer_value_fallible(v->m_node);
+                     if (!val) {
+                         throw py::value_error("value invalid!");
+                     }
+                     auto np_val = py::cast(*val).attr("numpy")();
+                     if (v->is_scalar) {
+                         return py::object(py::array(np_val).squeeze());
+                     }
+                     return np_val;
+                 })
             .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
             .def("_setscalar",
                  [](PySymbolVar* v) { return v->is_scalar = true; })