Browse Source

!8113 complete element_wise augassign with out he supprot Ascend OP

Merge pull request !8113 from yepei6/master_augassign_element
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
9a2da99df2
3 changed files with 85 additions and 19 deletions
  1. +2
    -1
      mindspore/common/dtype.py
  2. +34
    -7
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  3. +49
    -11
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py

+ 2
- 1
mindspore/common/dtype.py View File

@@ -38,7 +38,7 @@ __dtype__ = [
"number", "tensor",
"string", "type_none",
"tensor_type",
"Type"
"Type", "Int"
]

__method__ = [
@@ -104,6 +104,7 @@ tuple_type = typing.Tuple
index_slices = typing.RowTensorType()
sparse_tensor = typing.SparseTensorType()
undetermined = typing.UndeterminedType()
Int = typing.Int

number_type = (int8,
int16,


+ 34
- 7
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -271,7 +271,7 @@ def tensor_index_by_tuple(data, tuple_index):
if len(tuple_index) == 1:
return data[tuple_index[0]]
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:
return _tensor_index_by_tuple_slice(data, tuple_index)
if index_elements_type == const_utils.ALL_TENSOR:
@@ -437,13 +437,16 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
data[tuple_index[0]] = value
return data
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
tuple_index = const_utils.convert_int_to_slice(tuple_index)
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
@@ -498,16 +501,37 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
data_shape = data.shape
tuple_index_new = ()
for i, index in enumerate(tuple_index):
if isinstance(index, mstype.Int):
if index < -data_shape[i] or index >= data_shape[i]:
const_utils.raise_index_error("The index is out of the data's special dimension range.")
elif index < 0:
tuple_index_new += (tuple_index[i]+data_shape[i],)
else:
tuple_index_new += (tuple_index[i],)
else:
tuple_index_new += (tuple_index[i],)

indexes_types = hyper_map(F.typeof, tuple_index_new)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
tuple_index_new,
const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
tuple_index_new = const_utils.convert_int_to_slice(tuple_index_new)
new_shape = ()
for _ in tuple_index_new:
new_shape += (1,)
new_shape += value.shape
value = F.reshape(value, new_shape)
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
tuple_index_new,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tensor(data,
indices,
@@ -522,13 +546,16 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
data[tuple_index[0]] = value
return data
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
tuple_index = const_utils.convert_int_to_slice(tuple_index)
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)


+ 49
- 11
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -28,6 +28,9 @@ ALL_TENSOR = 0
NO_TENSOR = 1
CONTAIN_TENSOR = 2
ALL_SCALAR = 3
ALL_INT = 4
NO_INT = 5
CONTAIN_INT = 6

INT_ = 0
BOOL_ = 1
@@ -72,6 +75,35 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
value_shape, data_shape))


@constexpr
def restrict_int_index(data_shape, tuple_indexes):
"""
Check the int index of tuple_indexes if value of index is out of the corresponding data shape
and turn the negtive int index to positive int index.

Inputs:
data_shape: the shape of data.
tuple_indexes(tuple[mstype.int32]): the tuple of index which will be used in setitem or getitem.

Outputs:
tuple_indexes_new(tuple[mstype.int32]): same purpose with tuple_indexes but only contain positive.
"""
if tuple_indexes is None:
return tuple_indexes
tuple_indexes_new = ()
for i, index in enumerate(tuple_indexes):
if isinstance(index, mstype.Int):
if index < -data_shape[i] or index >= data_shape[i]:
const_utils.raise_index_error("The index is out of the data's special dimension range.")
elif index < 0:
tuple_indexes_new += (tuple_indexes[i]+data_shape[i],)
else:
tuple_indexes_new += (tuple_indexes[i],)
else:
tuple_indexes_new += (tuple_indexes[i],)
return tuple_indexes_new


@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Checks tuple index type of tensor assignment."""
@@ -276,17 +308,17 @@ def tuple_element_is_int(indexs):


@constexpr
def tuple_index_elements_type(types, op_name):
"""Judges the type of all elements of the tuple."""
tensors_number = 0
for ele in types:
if isinstance(ele, mstype.tensor_type):
tensors_number += 1
if tensors_number == len(types):
return ALL_TENSOR
if tensors_number == 0:
return NO_TENSOR
return CONTAIN_TENSOR
def tuple_index_tensor_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type."""
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
return ALL_TENSOR if tensor_cnt == len(types) else NO_TENSOR if tensor_cnt == 0 else CONTAIN_TENSOR
@constexpr
def tuple_index_int_cnt(types, op_name):
"""count the int type of types which contains the tuple elements' type."""
int_cnt = sum(isinstance(ele, mstype.Int) for ele in types)
return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT


@constexpr
@@ -406,6 +438,12 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
return tuple(new_shape)


@constexpr
def convert_int_to_slice(tuple_indexes):
tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)
return tuple_indexes_new


@constexpr
def convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,


Loading…
Cancel
Save