diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index 05765da243..939ddbc2b6 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -497,8 +497,6 @@ def column_stack(tup): return tup if not _check_is_list(tup) and not _check_is_tuple(tup): _raise_type_error("Tensor or, list or tuple of tensors are required, but got ", tup) - if not tup: - _raise_value_error("Need at least one tensor to concatenate.") trans_tup = () for tensor in tup: @@ -507,7 +505,9 @@ def column_stack(tup): if tensor.ndim == 1: tensor = F.expand_dims(tensor, 1) trans_tup += (tensor,) - return P.Concat(axis=1)(trans_tup) + if not trans_tup: + _raise_value_error("Need at least one tensor to concatenate.") + return P.Concat(1)(trans_tup) def vstack(tup): @@ -545,15 +545,15 @@ def vstack(tup): return tup if not _check_is_list(tup) and not _check_is_tuple(tup): _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) - if not tup: - _raise_value_error("Need at least one tensor to concatenate.") trans_tup = () for tensor in tup: if tensor.ndim <= 1: tensor = _expand(tensor, 2, 0) trans_tup += (tensor,) - return P.Concat(axis=0)(trans_tup) + if not trans_tup: + _raise_value_error("Need at least one tensor to concatenate.") + return P.Concat(0)(trans_tup) def hstack(tup): @@ -590,19 +590,18 @@ def hstack(tup): if _check_is_tensor(F.typeof(tup)): return tup if not _check_is_list(tup) and not _check_is_tuple(tup): - _raise_type_error(f"Tensor or, list or tuple of tensors are required, but got", tup) - if not tup: - _raise_value_error("Need at least one tensor to concatenate.") + _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) tuple_of_tensor = () for tensor in tup: if tensor.ndim < 1: tensor = F.expand_dims(tensor, 0) tuple_of_tensor += (tensor,) - + if not tuple_of_tensor: + _raise_value_error("Need at least one tensor to concatenate.") if tuple_of_tensor[0].ndim <= 1: - return P.Concat(axis=0)(tuple_of_tensor) - return P.Concat(axis=1)(tuple_of_tensor) + return P.Concat(0)(tuple_of_tensor) + return P.Concat(1)(tuple_of_tensor) def dstack(tup): @@ -641,8 +640,6 @@ def dstack(tup): return tup if not _check_is_list(tup) and not _check_is_tuple(tup): _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) - if not tup: - _raise_value_error("Need at least one tensor to concatenate.") trans_tup = () for tensor in tup: @@ -651,7 +648,9 @@ def dstack(tup): if tensor.ndim == 2: tensor = F.expand_dims(tensor, 2) trans_tup += (tensor,) - return P.Concat(axis=2)(trans_tup) + if not trans_tup: + _raise_value_error("Need at least one tensor to concatenate.") + return P.Concat(2)(trans_tup) def where(condition, x=None, y=None):