Browse Source

!1407 support mixed tensor index for tensor get item and set item and support in operator.

Merge pull request !1407 from zhangbuxue/support_mixed_tensor_for_tensor_get_item_and_tensor_set_item
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
ad279e90fd
14 changed files with 1023 additions and 234 deletions
  1. +1
    -1
      mindspore/_extends/parse/resources.py
  2. +2
    -0
      mindspore/ccsrc/ir/dtype_extends.cc
  3. +2
    -0
      mindspore/common/dtype.py
  4. +3
    -1
      mindspore/ops/composite/multitype_ops/__init__.py
  5. +154
    -0
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  6. +244
    -98
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  7. +25
    -14
      mindspore/ops/composite/multitype_ops/getitem_impl.py
  8. +101
    -0
      mindspore/ops/composite/multitype_ops/in_impl.py
  9. +83
    -61
      mindspore/ops/composite/multitype_ops/setitem_impl.py
  10. +0
    -1
      mindspore/ops/operations/array_ops.py
  11. +0
    -5
      tests/mindspore_test_framework/components/executor/exec_forward.py
  12. +23
    -4
      tests/ut/python/dtype/test_list.py
  13. +20
    -3
      tests/ut/python/dtype/test_tuple.py
  14. +365
    -46
      tests/ut/python/ops/test_tensor_slice.py

+ 1
- 1
mindspore/_extends/parse/resources.py View File

@@ -105,7 +105,7 @@ convert_object_map = {
T.ge: multitype_ops.greater_equal,
T.is_: F.is_,
T.is_not: F.is_not,
T.contains: F.in_dict,
T.contains: multitype_ops.in_,
T.not_contains: F.not_in_dict,

# system function


+ 2
- 0
mindspore/ccsrc/ir/dtype_extends.cc View File

@@ -474,6 +474,8 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<RefKeyType, Type, std::shared_ptr<RefKeyType>>(m_sub, "RefKeyType").def(py::init());
(void)py::class_<RefType, Type, std::shared_ptr<RefType>>(m_sub, "RefType").def(py::init());
(void)py::class_<TypeAnything, Type, std::shared_ptr<TypeAnything>>(m_sub, "TypeAnything").def(py::init());
(void)py::class_<Slice, Type, std::shared_ptr<Slice>>(m_sub, "Slice").def(py::init());
(void)py::class_<Ellipsis, Type, std::shared_ptr<Ellipsis>>(m_sub, "Ellipsis").def(py::init());
}));

const TypePtr kTypeExternal = std::make_shared<External>();


+ 2
- 0
mindspore/common/dtype.py View File

@@ -95,6 +95,8 @@ string = typing.String()
type_refkey = typing.RefKeyType()
tensor_type = typing.TensorType
anything_type = typing.TypeAnything
slice_type = typing.Slice
ellipsis_type = typing.Ellipsis

number_type = (int8,
int16,


+ 3
- 1
mindspore/ops/composite/multitype_ops/__init__.py View File

@@ -37,6 +37,7 @@ from .logical_and_impl import logical_and
from .logical_or_impl import logical_or
from .logic_not_impl import logical_not
from .uadd_impl import uadd
from .in_impl import in_
__all__ = [
'add',
'sub',
@@ -59,5 +60,6 @@ __all__ = [
'setitem',
'logical_and',
'logical_or',
'logical_not'
'logical_not',
'in_'
]

+ 154
- 0
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""constexpr util"""
from . import _constexpr_utils as const_utils
from ... import functional as F
from ... import operations as P
from ...composite import base
from ....common import dtype as mstype

hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)


def broadcast(broadcast_shape, x):
"""Broadcast tensor to the required shape."""
if F.shape(x) == broadcast_shape:
return x
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
if multiples:
return F.tile(x, multiples)
return x


def transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
"""Transform indexing tensor to the required."""
x = broadcast(broadcast_shape, x)
return broadcast(final_shape, F.reshape(x, new_shape))


def generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(broadcast, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices


def generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
indexes_types = hyper_map(F.typeof, tuple_index)
int_positions = const_utils.get_pos_of_int_index(indexes_types)
for i in int_positions:
tuple_index = F.tuple_setitem(tuple_index, i, F.scalar_to_tensor(tuple_index[i], mstype.int32))
indexes_types = hyper_map(F.typeof, tuple_index)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
tensor_indexes = []
slice_indexes = []
for i in tensor_positions:
tensor_indexes.append(tuple_index[i])
for j in slice_positions:
slice_indexes.append(tuple_index[j])
data_shape = F.shape(data)
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name)

slice_number = 0
final_index_tensors = []
tuple_index_size = len(tuple_index)
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size):
if i in tensor_positions:
transform_tensor = transform_indexing_tensor(broadcast_shape,
final_shape,
index_tensor_new_shape,
tuple_index[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number,
final_shape,
indexes_shapes_info,
op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
if i == ellipsis_position:
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name)
for ele in ellipsis_tensors:
final_index_tensors.append(ele)
slice_number += ellipsis_occupied_dims
indices = pack(final_index_tensors)
return indices


def generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
data_dtype = F.dtype(data)
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)


def generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
if value_elements_type == const_utils.ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)

data_shape = F.shape(data)
index_shape = F.shape(index)
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)


def generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM)
if check_dtype_same:
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return broadcast(updates_shape, value)
return value

mindspore/ops/composite/multitype_ops/_utils.py → mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -15,19 +15,15 @@

"""constexpr util"""
from functools import reduce

import numpy as np

from ...primitive import constexpr
from ....common.tensor import Tensor
from ....common import dtype as mstype
from .... import log as logger
from ...._extends.utils import Slice, Ellipsis_
from ....common import dtype as mstype
from ....common.tensor import Tensor
from ....ops import _utils as op_utils
from ...composite import base
from .... import log as logger
from ... import functional as F
from ... import operations as P

hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)

ALL_TENSOR = 0
NO_TENSOR = 1
@@ -264,7 +260,7 @@ def tuple_index_elements_type(types, op_name):
return ALL_TENSOR
if tensors_number == 0:
return NO_TENSOR
raise IndexError(f"For '{op_name}', the index does not support mixed tensor.")
return CONTAIN_TENSOR


@constexpr
@@ -279,12 +275,12 @@ def check_value_elements(data_dtype, types):
tensors_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
f"in value tuple is not consistent with origin tensor data type '{data_dtype}'.")
f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
elif mstype.issubclass_(ele, data_dtype):
scalars_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
f"value tuple is not consistent with origin tensor data type '{data_dtype}'.")
f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
if tensors_number == len(types):
return ALL_TENSOR
if scalars_number == len(types):
@@ -299,51 +295,46 @@ def get_index_tensor_dtype(dtype):
return INT_
if dtype == mstype.bool_:
return BOOL_
raise TypeError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
raise IndexError(f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


@constexpr
def check_index_tensors_dtype(dtypes, op_name):
"""Check a tuple of tensor data type."""
if op_name == TENSOR_GETITEM:
valid_dtypes = (mstype.int32, mstype.int64)
elif op_name == TENSOR_SETITEM:
valid_dtypes = (mstype.int32,)
else:
raise ValueError("Unsupported operation.")
for ele in dtypes:
if ele in valid_dtypes and ele == dtypes[0]:
continue
raise TypeError(f"For '{op_name}', the index tensors data type must be same, "
f"and should be one of the following: {valid_dtypes}, but got {dtypes}.")
if not ele == mstype.int32:
raise IndexError(f"For '{op_name}', the all index tensor "
f"data types should be mstype.int32, but got {dtypes}.")
return True


@constexpr
def check_tensor_dtype_valid(dtype, valid_dtypes):
def check_index_tensor_dtype(dtype, op_name):
"""Check a tensor data type."""
if dtype in valid_dtypes:
if dtype == mstype.int32:
return True
raise TypeError(f"The index tensor data type must be one of "
f"the following: {valid_dtypes}, but got {dtype}.")
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")


@constexpr
def check_tensors_dtype_same(x_dtype, y_dtype, op_name):
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same."""
if x_dtype == y_dtype:
if value_dtype == data_dtype:
return True
raise TypeError(f"For '{op_name}', the value data type '{y_dtype}' "
f"is not consistent with origin tensor data type {x_dtype}.")
raise TypeError(f"For '{op_name}', the value data type '{value_dtype}' "
f"is not consistent with assigned tensor data type {data_dtype}.")


@constexpr
def broadcast_shapes(shapes, op_name):
"""Broadcasts a tuple of tensor."""
def generate_broadcast_shape(shapes, op_name):
"""Generate broadcast shape for a tuple of shape."""
broadcast_shape = shapes[0]
for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
try:
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
except ValueError as ex:
raise IndexError(ex)
return tuple(broadcast_shape)


@@ -366,14 +357,82 @@ def check_two_shapes_need_broadcast(shape_x, shape_y):

@constexpr
def compute_multiples(origin_shape, broadcast_shape):
"""Compute multiples between broadcast_shape with origin_shape."""
"""Compute multiples between origin shape with broadcast shape."""
len_gap = len(broadcast_shape) - len(origin_shape)
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))


def tile(broadcast_shape, x):
multiples = compute_multiples(F.shape(x), broadcast_shape)
return F.tile(x, multiples)
@constexpr
def compute_new_shape(origin_shape, indexes_shapes_info):
"""Compute new shape between origin shape with final shape."""
new_shape = []
for i in indexes_shapes_info:
if i == origin_shape:
new_shape.extend(origin_shape)
else:
new_shape.append(1)
return tuple(new_shape)


@constexpr
def convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name):
"""Convert an ellipsis to a list of tensor."""
tensor_list = []
dims_dealt_count = 0
while dims_dealt_count < ellipsis_occupied_dims:
shape = []
slice_count = 0
array = None
for ele in indexes_shapes_info:
if isinstance(ele, list):
if slice_count == slice_number:
array = np.array(ele, np.int32)
shape.append(len(ele))
else:
shape.append(1)
slice_count += 1
if isinstance(ele, tuple):
shape.extend([1] * len(ele))
if array is None:
raise ValueError(f"For '{op_name}', generate tensors from ellipsis failed.")
array = np.reshape(array, shape)
reps = compute_multiples(shape, final_shape)
tensor = Tensor(np.tile(array, reps))
tensor_list.append(tensor)
slice_number += 1
dims_dealt_count += 1
return tensor_list


@constexpr
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
"""Convert a slice to a tensor."""
shape = []
count = 0
array = None
for ele in indexes_shapes_info:
if isinstance(ele, list):
if count == slice_number:
array = np.array(ele, np.int32)
shape.append(len(ele))
else:
# When the slice is not the slice looking for, the shape is filled with 1.
shape.append(1)
count += 1
elif isinstance(ele, tuple):
shape.extend([1] * len(ele))
else:
shape.append(1)
if array is None:
raise ValueError(f"For '{op_name}', generate tensor from 'slice' failed.")
array = np.reshape(array, shape)
reps = compute_multiples(shape, final_shape)
tensor = Tensor(np.tile(array, reps))
return tensor


@constexpr
@@ -381,8 +440,8 @@ def check_shapes_same(value_shapes, op_name):
"""Check if the shapes in the tuple are consistent."""
for i, shape in enumerate(value_shapes):
if shape != value_shapes[0]:
raise ValueError(f"For '{op_name}', the {i}th tensor shape in value tuple "
f"is not same as the first tensor shape.")
raise ValueError(f"For '{op_name}', the {i}th tensor shape in "
f"value tuple is not same as the first tensor shape.")
return True


@@ -396,7 +455,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
f" is not consistent with tensor data type {data_dtype}.")
f" is not consistent with the assigned tensor data type {data_dtype}.")


@constexpr
@@ -404,8 +463,8 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value
"""Convert a tuple of scalar to a tensor."""
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
if len(value) != updates_shape[-1]:
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} in the updates tuple "
f"does not meet the requirements: {updates_shape[-1]}.")
raise ValueError(f"For '{TENSOR_SETITEM}', the number of elements : {len(value)} "
f"in the updates tuple does not meet the requirements: {updates_shape[-1]}.")
array = np.array(value, dtype=mstype.dtype_to_nptype(data_dtype))
reps = compute_multiples(updates_shape[-1:], updates_shape)
return Tensor(np.tile(array, reps))
@@ -430,58 +489,145 @@ def check_number_of_index_tensor(data_shape, tuple_len, op_name):
f"is greater than the dimension {len(data_shape)} of the operated tensor.")


def generate_indeices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = broadcast_shapes(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(tile, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
return indices


def generate_updates_from_scalar(data, indices, value, op_type):
"""Generate an updates tensor from a scalar."""
data_shape = F.shape(data)
indices_shape = F.shape(indices)
data_dtype = F.dtype(data)
return convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)


def generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = check_value_elements(data_dtype, value_types)
if value_elements_type == ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = check_shapes_same(value_shapes, TENSOR_SETITEM)
if shapes_same:
value = F.pack(value)
return generate_updates_from_tensor(data, index, value, op_type)

data_shape = F.shape(data)
index_shape = F.shape(index)
return convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)


def generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = check_tensors_dtype_same(data_dtype, value_dtype, TENSOR_SETITEM)
if check_dtype_same:
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = check_two_shapes_need_broadcast(updates_shape, value_shape)
if need_broadcast:
return tile(updates_shape, value)
return value
@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name):
"""
Generate index info which contain broadcast shape, final shape,
indexes shapes info, ellipsis size from a tuple of mixed tensors.
"""
check_index_tensors_dtype(tensor_indexes_dtypes, op_name)
data_rank = len(data_shape)
indexes_size = len(indexes_types)
if indexes_size > data_rank:
raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
indexes_info = {}
index_tensors_info = {}
ellipsis_num = 0
ellipsis_occupied_dims = 0
tensor_count = 0
slice_count = 0
for i, ele_type in enumerate(indexes_types):
if ellipsis_num == 0:
pos = i
else:
pos = i + ellipsis_occupied_dims - 1
if isinstance(ele_type, mstype.tensor_type):
indexes_info[pos] = tensor_indexes_shapes[tensor_count]
index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
tensor_count += 1
elif isinstance(ele_type, mstype.slice_type):
slice_obj = slice(slice_indexes[slice_count].start,
slice_indexes[slice_count].end,
slice_indexes[slice_count].step)
# Use list to represent slicing result.
indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
slice_count += 1
elif isinstance(ele_type, mstype.ellipsis_type):
if ellipsis_num != 0:
raise IndexError(f"For '{op_name}', the index could only contain one ellipsis.")
ellipsis_occupied_dims = data_rank - indexes_size + 1
for j in range(pos, pos + ellipsis_occupied_dims):
# Use list to represent slicing result.
indexes_info[j] = list(range(data_shape[j]))
ellipsis_num += 1
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.")
broadcast_shape, final_shape, indexes_shapes_info = \
_derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name)
return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims


def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
"""Determine whether the tensor in the index appears continuously."""
for i in range(len(index_tensor_info_key) - 1):
if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1:
return False
return True


def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
"""Derive the resulting shape information from the a tuple index of mixed tensors."""
index_tensor_info_key = list(index_tensors_info.keys())
index_tensor_info_value = list(index_tensors_info.values())
broadcast_shape = generate_broadcast_shape(index_tensor_info_value, op_name)
final_shape = []
indexes_shapes_info = []
mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key)
if mixed_tensors_continuous:
tensor_shape_dealt = False
for ele in indexes_info.values():
if isinstance(ele, list):
final_shape.append(len(ele))
indexes_shapes_info.append(ele)
elif isinstance(ele, tuple):
if not tensor_shape_dealt:
final_shape.extend(broadcast_shape)
indexes_shapes_info.append(broadcast_shape)
tensor_shape_dealt = True
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
else:
final_shape.extend(broadcast_shape)
indexes_shapes_info.append(broadcast_shape)
for ele in indexes_info.values():
if isinstance(ele, list):
final_shape.append(len(ele))
indexes_shapes_info.append(ele)
elif isinstance(ele, tuple):
continue
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)


@constexpr
def get_pos_of_int_index(indexes_types):
"""Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis."""
int_positions = []
for i, ele_type in enumerate(indexes_types):
if ele_type == mstype.int32:
int_positions.append(i)
return int_positions


@constexpr
def separate_mixed_tensors_index(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
tensor_positions = []
slice_positions = []
ellipsis_position = None
for i, ele_type in enumerate(indexes_types):
if isinstance(ele_type, mstype.tensor_type):
tensor_positions.append(i)
elif isinstance(ele_type, mstype.slice_type):
slice_positions.append(i)
elif isinstance(ele_type, mstype.ellipsis_type):
ellipsis_position = i
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'Slice', 'Ellipsis', but got {ele_type}.")

return tensor_positions, slice_positions, ellipsis_position


@constexpr
def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence."""
if x is None:
raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
"but the scalar is not.")
if y is None:
raise ValueError("Judge scalar in tuple or list require scalar and sequence should be constant, "
"but the sequence is not.")
if x in y:
return True
return False

+ 25
- 14
mindspore/ops/composite/multitype_ops/getitem_impl.py View File

@@ -14,11 +14,11 @@
# ============================================================================

"""Implementation for getitem."""
from . import _utils as multi_utils
from ..import base
from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from .. import base
from ... import functional as F
from ....common import dtype as mstype

getitem = base.MultitypeFuncGraph('getitem')
"""
@@ -227,7 +227,8 @@ def _tensor_getitem_by_tensor(data, tensor_index):
Outputs:
Tensor, element type is same as the element type of data.
"""
check_dtypes = multi_utils.check_tensor_dtype_valid(F.dtype(tensor_index), (mstype.int32, mstype.int64))
check_dtypes = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
const_utils.TENSOR_GETITEM)
result = None
if check_dtypes:
result = F.gather(data, tensor_index, 0)
@@ -246,14 +247,13 @@ def _tensor_getitem_by_tuple(data, tuple_index):
Outputs:
Tensor, element type is same as the element type of data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_GETITEM)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_slice(data, tuple_index)
if index_elements_type == multi_utils.ALL_TENSOR:
result = _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index)


@getitem.register("Tensor", "Ellipsis")
@@ -273,6 +273,17 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):

def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_GETITEM)
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result


def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result

+ 101
- 0
mindspore/ops/composite/multitype_ops/in_impl.py View File

@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""in_impl"""

from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base

in_ = base.MultitypeFuncGraph("in")
"""
in_ is a metafuncgraph object which will determine if a in b
using ".register" decorator
"""


@in_.register("Number", "Tuple")
def _number_in_tuple(x, y):
"""
Determine if a number in tuple.

Args:
x (Number): x
y (tuple): y

Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)


@in_.register("Number", "List")
def _number_in_list(x, y):
"""
Determine if a number in list.

Args:
x (Number): x
y (list): y

Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)


@in_.register("String", "Tuple")
def _string_in_tuple(x, y):
"""
Determine if a str in a tuple.

Args:
x (str): x
y (tuple): y

Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)


@in_.register("String", "List")
def _string_in_list(x, y):
"""
Determine if a str in a list.

Args:
x (str): x
y (list): y

Returns:
bool, if x in y return true, x not in y return false.
"""
return const_utils.scalar_in_sequence(x, y)


@in_.register("String", "Dictionary")
def _str_in_dict(x, y):
"""
Determine if a str in dict.

Args:
x: str
y: dict

Returns:
bool, if x in y return true, x not in y return false.
"""
return F.in_dict(x, y)

+ 83
- 61
mindspore/ops/composite/multitype_ops/setitem_impl.py View File

@@ -15,10 +15,11 @@

"""Implementation for setitem."""

from . import _compile_utils as compile_utils
from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
from ....common import dtype as mstype
from ... import functional as F
from . import _utils as multi_utils

setitem = base.MultitypeFuncGraph('setitem')

@@ -139,8 +140,8 @@ def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == multi_utils.INT_:
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.INT_:
return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)

@@ -166,8 +167,8 @@ def _tensor_setitem_by_tensor_with_number(data, index, value):
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
tensor_dtype = multi_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == multi_utils.BOOL_:
tensor_dtype = const_utils.get_index_tensor_dtype(index_dtype)
if tensor_dtype == const_utils.BOOL_:
return _tensor_setitem_by_bool_tensor_with_scalar(data, index, value)
return _tensor_setitem_by_int_tensor_with_scalar(data, index, value)

@@ -190,17 +191,24 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_scalar(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_number(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)


@setitem.register("Tensor", "Tuple", "Tensor")
@@ -221,17 +229,24 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
result = None
if index_elements_type == multi_utils.NO_TENSOR:
result = _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tensor(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.NO_TENSOR:
return _tensor_assgin_tensor(data, tuple_index, value)
if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)


@setitem.register("Tensor", "Tuple", "Tuple")
@@ -253,15 +268,22 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
index_types = multi_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = multi_utils.tuple_index_elements_type(index_types, multi_utils.TENSOR_SETITEM)
result = None
if index_elements_type == multi_utils.ALL_TENSOR:
indices = multi_utils.generate_indeices_from_tuple_of_tensor(data, tuple_index, multi_utils.TENSOR_SETITEM)
updates = multi_utils.generate_updates_from_tuple(data, indices, value,
multi_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
result = F.scatter_nd_update(data, indices, updates)
return result
indexes_types = compile_utils.hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.ALL_TENSOR:
indices = compile_utils.generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
indices = compile_utils.generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = compile_utils.generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return F.scatter_nd_update(data, indices, updates)


@setitem.register("Tensor", "Tensor", "Tuple")
@@ -278,7 +300,7 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Tensor, element type and shape is same as data.
"""
index_dtype = F.dtype(index)
check_dtype = multi_utils.check_tensor_dtype_valid(index_dtype, (mstype.int32, mstype.int64))
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
@@ -331,14 +353,14 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):

def _tensor_assgin_number(data, input_slice, value):
"""Givens a scalar assign to tensor by slice"""
check_result = multi_utils.check_tensor_setitem_index(input_slice)
check_result = const_utils.check_tensor_setitem_index(input_slice)
result = None
if check_result:
data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape)
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_number(data, data_shape, input_slice, indices, value)
return result

@@ -347,7 +369,7 @@ def _tensor_assgin_number(data, input_slice, value):
def _tensor_setitem_with_int_v1(data, index, value):
"""Syntax: A[1] = 3"""
data_shape = F.shape(data)
indices = multi_utils.integer_to_indices(index, data_shape)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)


@@ -355,7 +377,7 @@ def _tensor_setitem_with_int_v1(data, index, value):
def _tensor_setitem_with_int_v2(data, index, value):
"""Syntax: A[1] = Tensor"""
data_shape = F.shape(data)
indices = multi_utils.integer_to_indices(index, data_shape)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)


@@ -376,7 +398,7 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
data_size = F.size(data)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = multi_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
@@ -391,13 +413,13 @@ def _tensor_setitem_with_ellipsis_v2(data, index, value):
def _tensor_assgin_tensor(data, input_slice, value):
"""Assigns a tensor value to the tensor by slice."""
result = None
check_result = multi_utils.check_tensor_setitem_index(input_slice)
check_result = const_utils.check_tensor_setitem_index(input_slice)
if check_result:
data_shape = F.shape(data)
indices = multi_utils.slice2indices(input_slice, data_shape)
is_tuple_int = multi_utils.tuple_element_is_int(input_slice)
indices = const_utils.slice2indices(input_slice, data_shape)
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = multi_utils.integer_to_indices(input_slice, data_shape)
indices = const_utils.integer_to_indices(input_slice, data_shape)
result = _tensor_indices_tensor(data, data_shape, input_slice, indices, value)
return result

@@ -407,7 +429,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
@@ -415,7 +437,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
value_fill = None
value_size = F.size(value)

value_size = multi_utils.check_indices_value_size(indices_size, value_size)
value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
value_fill = F.fill(data_dtype, (indices_size,), 1)
value = F.cast(value, data_dtype)
@@ -432,7 +454,7 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
data_size = F.size(data)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = multi_utils.check_indices(indices_size, index)
indices_size = const_utils.check_indices(indices_size, index)
update = F.fill(mstype.int32, (indices_size,), 1)
condition_1d = F.scatter_nd(indices, update, (data_size,))
condition = F.reshape(condition_1d, data_shape)
@@ -445,16 +467,16 @@ def _tensor_indices_number(data, data_shape, index, indices, value):

def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = multi_utils.generate_updates_from_tuple(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
updates = compile_utils.generate_updates_from_tuple(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
result = F.scatter_update(data, index, updates)
return result


def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
updates = multi_utils.generate_updates_from_scalar(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
updates = compile_utils.generate_updates_from_scalar(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)


@@ -462,7 +484,7 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):
"""Set a tensor item by a bool tensor with a scalar."""
index_shape = F.shape(index)
shape = F.shape(data)
shape = multi_utils.check_equal(
shape = const_utils.check_equal(
shape, index_shape, "The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
dtype = F.dtype(data)
u = F.fill(dtype, shape, value)
@@ -471,8 +493,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):

def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
"""Set a tensor item by a int tensor with a tensor."""
updates = multi_utils.generate_updates_from_tensor(data, index, value,
multi_utils.SET_ITEM_BY_ONE_TENSOR)
updates = compile_utils.generate_updates_from_tensor(data, index, value,
const_utils.SET_ITEM_BY_ONE_TENSOR)
return F.scatter_update(data, index, updates)


@@ -480,10 +502,10 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
"""Set a tensor item by a bool tensor with a tensor."""
index_shape = F.shape(index)
data_shape = F.shape(data)
data_shape = multi_utils.check_equal(data_shape, index_shape,
data_shape = const_utils.check_equal(data_shape, index_shape,
"The tensor(shape={}) and tensor index(shape={}) should be the same shape.")
size = F.size(value)
size = multi_utils.check_equal(1, size,
size = const_utils.check_equal(1, size,
"When assign value is a tensor, its size should be {}, but current size is {}.")
dtype = F.dtype(data)
u_cast = F.cast(value, dtype)


+ 0
- 1
mindspore/ops/operations/array_ops.py View File

@@ -1419,7 +1419,6 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name):
validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
rank_base = len(x_shape[0])
N = len(x_shape)
out_shape = x_shape[0]


+ 0
- 5
tests/mindspore_test_framework/components/executor/exec_forward.py View File

@@ -33,9 +33,4 @@ class IdentityEC(IExectorComponent):
keyword.desc_inputs: self.inputs[keyword.desc_inputs],
keyword.result: self.function[keyword.block](*self.inputs[keyword.desc_inputs])
}
print("buxue------------------------------------------------")
print("inputs")
print(ret[keyword.desc_inputs])
print("outputs")
print(ret[keyword.result])
return ret

tests/ut/python/ops/test_list.py → tests/ut/python/dtype/test_list.py View File

@@ -19,9 +19,9 @@ import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config


@@ -133,7 +133,7 @@ def test_list_append_2():


class ListOperate(nn.Cell):
def __init__(self,):
def __init__(self, ):
super(ListOperate, self).__init__()

def construct(self, t, l):
@@ -152,6 +152,20 @@ class ListOperate(nn.Cell):
return x


class InListNet(nn.Cell):
def __init__(self, ):
super(InListNet, self).__init__()
self.list_ = [1, 2, 3, 4, 5, "ok"]

def construct(self, x):
ret = x
if 2 in self.list_:
ret = x + x
if "ok" in self.list_:
ret = x - x
return ret


class AxisListNet(nn.Cell):
def __init__(self):
super(AxisListNet, self).__init__()
@@ -204,10 +218,15 @@ test_case_ops = [
('AxisListDefault', {
'block': AxisListDefaultNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
('InList', {
'block': InListNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
]

test_case_lists = [test_case_ops]
test_exec_case = functools.reduce(lambda x, y: x + y, test_case_lists)


# use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm


tests/ut/python/ops/test_tuple.py → tests/ut/python/dtype/test_tuple.py View File

@@ -19,9 +19,9 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
from tests.ut.python.ut_filter import non_graph_engine
from tests.mindspore_test_framework.mindspore_test import mindspore_test
from tests.mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
@@ -52,6 +52,20 @@ class NestTupleGraphNet(nn.Cell):
return self.layers[0][1](x)


class InTupleNet(nn.Cell):
def __init__(self, ):
super(InTupleNet, self).__init__()
self.tuple_ = (1, 2, 3, 4, 5, "ok")

def construct(self, x):
ret = x
if 2 in self.tuple_:
ret = x + x
if "ok" in self.tuple_:
ret = x - x
return ret


test_case_ops = [
('TupleGraph', {
'block': TupleGraphNet(),
@@ -59,6 +73,9 @@ test_case_ops = [
('NestTupleGraph', {
'block': NestTupleGraphNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('InTuple', {
'block': InTupleNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
]

test_case_lists = [test_case_ops]

+ 365
- 46
tests/ut/python/ops/test_tensor_slice.py View File

@@ -176,12 +176,134 @@ class TensorGetItemByThreeTensors(Cell):
return ret


class TensorGetItemByMixedTensors(Cell):
class TensorGetItemByMixedTensors_0(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors, self).__init__()
super(TensorGetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))

def construct(self, tensor, index_0, index_1):
ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
return ret


class TensorGetItemByMixedTensors_1(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))

def construct(self, tensor, index_0, index_1):
ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
return ret


class TensorGetItemByMixedTensors_2(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))

def construct(self, tensor, index_0, index_1):
ret = tensor[0, index_0, index_1, ..., 3] + self.const
return ret


class TensorGetItemByMixedTensors_3(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_3, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))

def construct(self, tensor, index_0, index_1):
ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
return ret


class TensorGetItemByMixedTensors_4(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_4, self).__init__()
self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))

def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
return ret


class TensorGetItemByMixedTensors_5(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_5, self).__init__()
self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))

def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
return ret


class TensorGetItemByMixedTensors_6(Cell):
def __init__(self):
super(TensorGetItemByMixedTensors_6, self).__init__()
self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))

def construct(self, tensor, index_0, index_1, index_2):
ret = tensor[..., index_0, index_1, index_2, 3] + self.const
return ret


class TensorSetItemByMixedTensors_0(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_0, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
mstype.float32),
name="x")
self.value = value

def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
ret = self.param + self.const
return ret


class TensorSetItemByMixedTensors_1(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_1, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
name="x")
self.value = value

def construct(self, index_0, index_1, index_2):
self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
ret = self.param + self.const
return ret


class TensorSetItemByMixedTensors_2(Cell):
def __init__(self, value):
super(TensorSetItemByMixedTensors_2, self).__init__()
self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
name="x")
self.value = value

def construct(self, index_0, index_1, index_2):
self.param[..., index_0, index_1, index_2, 3] = self.value
ret = self.param + self.const
return ret


class TensorGetItemByMixedTensorsTypeError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsTypeError, self).__init__()

def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:6]
ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
return ret


class TensorGetItemByMixedTensorsNumberError(Cell):
def __init__(self):
super(TensorGetItemByMixedTensorsNumberError, self).__init__()

def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
return ret


@@ -189,7 +311,7 @@ class TensorSetItemByOneTensorWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value

def construct(self, index):
@@ -202,7 +324,7 @@ class TensorSetItemByOneTensorWithTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

def construct(self, index, value):
self.param[index] = value
@@ -214,7 +336,7 @@ class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value

def construct(self, index):
@@ -227,7 +349,7 @@ class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x")

def construct(self, index, value_0, value_1, value_2):
self.param[index] = (value_0, value_1, value_2)
@@ -239,7 +361,7 @@ class TensorSetItemByTensorsWithNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value

def construct(self, index_0, index_1, index_2):
@@ -252,7 +374,7 @@ class TensorSetItemByTensorsWithTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

def construct(self, index_0, index_1, index_2, value):
self.param[index_0, index_1, index_2] = value
@@ -264,7 +386,7 @@ class TensorSetItemByTensorsWithTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

def construct(self, index_0, index_1, index_2, index_3, value):
self.param[index_0, index_1, index_2, index_3] = value
@@ -276,7 +398,7 @@ class TensorSetItemByTensorsWithTupleOfNumber(Cell):
def __init__(self, value):
super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = value

def construct(self, index_0, index_1, index_2):
@@ -289,7 +411,7 @@ class TensorSetItemByTensorsWithTupleOfTensor(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
@@ -301,7 +423,7 @@ class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
def __init__(self):
super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

def construct(self, index_0, index_1, index_2, value_0, value_1):
self.param[index_0, index_1, index_2] = (value_0, value_1)
@@ -313,7 +435,7 @@ class TensorSetItemByMixedTensors(Cell):
def __init__(self):
super(TensorSetItemByMixedTensors, self).__init__()
self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
self.value = 99.0

def construct(self, index_0, index_1):
@@ -538,11 +660,11 @@ def test_tensor_assign_bool_index():
net1(Ta, Tb, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Td, Tc, u_tensor)
with pytest.raises(TypeError):
with pytest.raises(IndexError):
net1(Ta, u_tensor, Tc, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Td, u_tensor)
with pytest.raises(TypeError):
with pytest.raises(IndexError):
net1(Ta, Tb, Ta, u_tensor)
with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error)
@@ -620,22 +742,67 @@ test_cases = [
}),
('TensorGetItemByOneTensor', {
'block': TensorGetItemByOneTensor(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
}),
('TensorGetItemByTwoTensors', {
'block': TensorGetItemByTwoTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByThreeTensors', {
'block': TensorGetItemByThreeTensors(),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_0', {
'block': TensorGetItemByMixedTensors_0(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_1', {
'block': TensorGetItemByMixedTensors_1(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_2', {
'block': TensorGetItemByMixedTensors_2(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_3', {
'block': TensorGetItemByMixedTensors_3(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_4', {
'block': TensorGetItemByMixedTensors_4(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_5', {
'block': TensorGetItemByMixedTensors_5(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors_6', {
'block': TensorGetItemByMixedTensors_6(),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumber', {
'block': TensorSetItemByOneTensorWithNumber(value=0.0),
'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
@@ -683,46 +850,143 @@ test_cases = [
Tensor(np.zeros((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)), mstype.float32),
Tensor(np.ones((4, 5)) * 2, mstype.float32)],
})
}),
('TensorSetItemByMixedTensorsWithNumber_0', {
'block': TensorSetItemByMixedTensors_0(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_0', {
'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumber_0', {
'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensor_0', {
'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)),
Tensor(np.zeros((4, 5, 3, 9), np.float32)),
Tensor(np.ones((4, 5, 3, 9), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumber_1', {
'block': TensorSetItemByMixedTensors_1(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_1', {
'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumber_1', {
'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensor_1', {
'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumber_2', {
'block': TensorSetItemByMixedTensors_2(value=88.0),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumber_2', {
'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensor_2', {
'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float32)),
Tensor(np.zeros((4, 5), np.float32)),
Tensor(np.ones((4, 5), np.float32)))),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
]

raise_error_set = [
('TensorGetItemByOneTensorDtypeError', {
'block': (TensorGetItemByOneTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByOneTensor(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
}),
('TensorGetItemByTwoTensorsShapeError', {
'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
}),
('TensorGetItemByTwoTensorsDtypeError', {
'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
}),
('TensorGetItemByThreeTensorsShapeError', {
'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
}),
('TensorGetItemByThreeTensorsDtypeError', {
'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int64),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensors', {
'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
('TensorGetItemByMixedTensorsNumberError', {
'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)],
Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsTypeError', {
'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsDtypeError', {
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)],
}),
('TensorGetItemByMixedTensorsShapeError', {
'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)],
}),
('TensorSetItemByOneTensorWithNumberTypeError', {
'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
@@ -760,21 +1024,21 @@ raise_error_set = [
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorShapeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTensorTypeError', {
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
Tensor(np.zeros((4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTensorNumberError', {
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
@@ -782,19 +1046,19 @@ raise_error_set = [
Tensor(np.zeros((2, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
@@ -802,7 +1066,7 @@ raise_error_set = [
Tensor(np.ones((4, 5)), mstype.float32)],
}),
('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
@@ -810,10 +1074,65 @@ raise_error_set = [
Tensor(np.ones((4, 5)), mstype.int32),
Tensor(np.ones((4, 5)) * 2, mstype.int32)],
}),
('TensorSetItemByMixedTensors', {
'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
('TensorSetItemByMixedTensorsWithNumberValueTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithNumberIndexTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)],
}),
('TensorSetItemByMixedTensorsWithTensorValueDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensorValueShapeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))),
{'exception': ValueError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
{'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.float32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.int32)))),
{'exception': TypeError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
}),
('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', {
'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.zeros((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.float32)),
Tensor(np.ones((5, 2, 6), np.int32)))),
{'exception': IndexError}),
'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32),
Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
})
]



Loading…
Cancel
Save