GitOrigin-RevId: 70ddc06eee
tags/v1.5.0
| @@ -86,16 +86,22 @@ def _broadcast(inp, shape): | |||
| def _reshape(x, shape): | |||
| shape_tuple = _make_shape_tuple(shape) | |||
| unspec_axis = None | |||
| # XXX: assume unspec_axis is not changed in trace | |||
| for i, s in enumerate(shape_tuple): | |||
| if s < 0: | |||
| if s != -1: | |||
| raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||
| if unspec_axis is not None: | |||
| raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | |||
| unspec_axis = i | |||
| try: | |||
| shape_tuple = _make_shape_tuple(shape) | |||
| except ValueError: | |||
| pass | |||
| else: | |||
| # XXX: assume unspec_axis is not changed in trace | |||
| for i, s in enumerate(shape_tuple): | |||
| if s < 0: | |||
| if s != -1: | |||
| raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||
| if unspec_axis is not None: | |||
| raise ValueError( | |||
| "multiple -1 in shape: {} & {}".format(unspec_axis, i) | |||
| ) | |||
| unspec_axis = i | |||
| shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | |||
| if unspec_axis is None: | |||
| op = builtin.Reshape() | |||
| @@ -18,9 +18,9 @@ from .utils import astensor1d, isscalar, make_shape_tuple | |||
| def remove_ellipsis(tensor, tuple_val): | |||
| ndim_sum = tensor.ndim | |||
| cur_sum = 0 | |||
| pos = -1 | |||
| has_unkown_ndim_bool_index = False | |||
| for i_idx, i in enumerate(tuple_val): | |||
| if i is Ellipsis: | |||
| for j in tuple_val[:i_idx:-1]: | |||
| @@ -28,10 +28,28 @@ def remove_ellipsis(tensor, tuple_val): | |||
| raise IndexError("only one ellipsis is allowed") | |||
| pos = i_idx | |||
| else: | |||
| cur_sum += i.ndim if hasattr(i, "ndim") else 1 | |||
| try: | |||
| cur_sum += ( | |||
| i.ndim | |||
| if hasattr(i, "dtype") | |||
| and i.dtype == np.bool_ | |||
| and hasattr(i, "ndim") | |||
| else 1 | |||
| ) | |||
| except ValueError: | |||
| has_unkown_ndim_bool_index = True | |||
| if pos == -1: | |||
| return tuple_val | |||
| else: | |||
| if has_unkown_ndim_bool_index: | |||
| raise IndexError( | |||
| "Does not support bool index with unknown shape when using Ellipsis" | |||
| ) | |||
| try: | |||
| ndim_sum = tensor.ndim | |||
| except ValueError: | |||
| raise IndexError("Does not support Ellipsis when tensor's ndim is unknown.") | |||
| return ( | |||
| tuple_val[:pos] | |||
| + (slice(None, None, None),) * (ndim_sum - cur_sum) | |||
| @@ -41,7 +59,11 @@ def remove_ellipsis(tensor, tuple_val): | |||
| # XXX: assume same results during trace | |||
| def check_bool_index(tensor, tuple_val): | |||
| cur_shape = make_shape_tuple(tensor.shape) | |||
| try: | |||
| cur_shape = make_shape_tuple(tensor.shape) | |||
| except ValueError: | |||
| return tensor, tuple_val | |||
| new_tuple_val = [] | |||
| offset = 0 | |||
| tdim = 0 | |||
| @@ -92,20 +114,31 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
| ndim_indexed_scalar = 0 | |||
| for i in tuple_val: | |||
| if not i is Ellipsis: | |||
| ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim | |||
| ndim_indexed += ( | |||
| i.ndim | |||
| if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") | |||
| else 1 | |||
| ) | |||
| if isscalar(i): | |||
| ndim_indexed_scalar += 1 | |||
| if ndim_indexed > inp.ndim: | |||
| raise IndexError( | |||
| "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||
| inp.ndim, ndim_indexed | |||
| ret_scalar = False | |||
| try: | |||
| ret_scalar = ndim_indexed_scalar == inp.ndim | |||
| except ValueError: | |||
| # inp.ndim is unknown | |||
| pass | |||
| else: | |||
| if ndim_indexed > inp.ndim: | |||
| raise IndexError( | |||
| "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( | |||
| inp.ndim, len(tuple_val) | |||
| ) | |||
| ) | |||
| ) | |||
| tuple_val = remove_ellipsis(inp, tuple_val) | |||
| use_subtensor = True | |||
| inp, tuple_val = check_bool_index(inp, tuple_val) | |||
| if inp.shape is not None: | |||
| inp, tuple_val = check_bool_index(inp, tuple_val) | |||
| new_axes = [] | |||
| tensors = [] | |||
| @@ -186,7 +219,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
| items.append(item) | |||
| if new_axes: | |||
| raise IndexError("newaxis is not allowed here") | |||
| return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim | |||
| return inp, tensors, items, use_subtensor, ret_scalar | |||
| def try_condtake(tensor, index): | |||
| @@ -249,16 +282,21 @@ def setitem(tensor, index, value): | |||
| op = builtin.IndexingMultiAxisVec(items=items) | |||
| (tmp_result,) = apply(op, tensor, *tensors) | |||
| for i in range(min(len(value.shape), len(tmp_result.shape))): | |||
| if (value.shape[-i - 1] != 1) & ( | |||
| value.shape[-i - 1] != tmp_result.shape[-i - 1] | |||
| ): | |||
| raise ValueError( | |||
| "cannot copy tensor with shape {} to subtensor with shape {}".format( | |||
| value.shape, tmp_result.shape | |||
| try: | |||
| value_shape = value._tuple_shape | |||
| tmp_result_shape = tmp_result._tuple_shape | |||
| except ValueError: | |||
| pass | |||
| else: | |||
| for i in range(min(len(value_shape), len(tmp_result_shape))): | |||
| if (value_shape[-i - 1] != 1) & ( | |||
| value_shape[-i - 1] != tmp_result_shape[-i - 1] | |||
| ): | |||
| raise ValueError( | |||
| "cannot copy tensor with shape {} to subtensor with shape {}".format( | |||
| value_shape, tmp_result_shape | |||
| ) | |||
| ) | |||
| ) | |||
| value = value._broadcast(tmp_result.shape) | |||
| if use_subtensor: | |||
| @@ -137,6 +137,13 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
| ndim = x.ndim | |||
| except AttributeError: | |||
| pass | |||
| except ValueError: | |||
| if dtype is not None and dtype != x.dtype: | |||
| x = astype(x, dtype) | |||
| if device is not None: | |||
| cn = as_device(device).to_c() | |||
| (x,) = apply(builtin.Copy(comp_node=cn), x) | |||
| return x | |||
| else: | |||
| if ndim != 0 and ndim != 1: | |||
| raise ValueError("ndim != 1 or 0, get : %d" % ndim) | |||
| @@ -148,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
| raise TypeError | |||
| if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | |||
| x = concatenate(x, device=device) | |||
| x = concatenate(x, device=device) if len(x) > 1 else x[0] | |||
| if dtype is not None: | |||
| x = astype(x, dtype) | |||
| return x | |||
| @@ -849,8 +849,15 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
| return list(map(int, axis)) | |||
| axis = get_axes() | |||
| ndim = inp.ndim + len(axis) | |||
| axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
| try: | |||
| ndim = inp.ndim + len(axis) | |||
| axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
| except ValueError: | |||
| if any([ind < 0 for ind in axis]): | |||
| raise IndexError( | |||
| "Does not support negative index when tensor's ndim is unknown" | |||
| ) | |||
| axis = sorted(axis) | |||
| assert axis, "axis could not be empty" | |||
| if inp._isscalar(): | |||
| assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0]) | |||
| @@ -384,6 +384,11 @@ PyObject* TensorWrapper::shape() { | |||
| TensorShape shape; | |||
| if (m_tensor->m_var) { // get shape from m_var | |||
| auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); | |||
| auto&& type = mgr.get_infer_type(m_tensor->m_var); | |||
| using InferType = cg::static_infer::InferType; | |||
| if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) { | |||
| Py_RETURN_NONE; | |||
| } | |||
| auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var); | |||
| if (!tshp) { | |||
| Py_RETURN_NONE; | |||
| @@ -878,6 +883,24 @@ void init_tensor(py::module m) { | |||
| ->static_infer_manager(); | |||
| return mgr.infer_shape_fallible(v->m_node); | |||
| }) | |||
| .def("numpy", [](PySymbolVar* v){ | |||
| auto&& mgr = v->m_node->owner_graph()->static_infer_manager(); | |||
| auto&& type = mgr.get_infer_type(v->m_node); | |||
| using InferType = cg::static_infer::InferType; | |||
| if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { | |||
| throw py::value_error("value invalid!"); | |||
| } | |||
| auto* val = mgr.infer_value_fallible(v->m_node); | |||
| if (!val) { | |||
| throw py::value_error("value invalid!"); | |||
| } | |||
| auto np_val = py::cast(*val).attr("numpy")(); | |||
| if (v->is_scalar) { | |||
| return py::object(py::array(np_val).squeeze()); | |||
| } | |||
| return np_val; | |||
| }) | |||
| .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||
| .def("_setscalar", | |||
| [](PySymbolVar* v) { return v->is_scalar = true; }) | |||