Browse Source

get concat offset refactor

tags/v0.3.0-alpha
zhaozhenlong 6 years ago
parent
commit
0e98430f4d
1 changed files with 10 additions and 10 deletions
  1. +10
    -10
      mindspore/ops/_utils/utils.py

+ 10
- 10
mindspore/ops/_utils/utils.py View File

@@ -15,7 +15,7 @@

"""utils for operator"""

from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype

@@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
return broadcast_shape


def _get_concat_offset(x_shp, x_type, axis):
def _get_concat_offset(x_shp, x_type, axis, prim_name):
"""for concat and concatoffset check args and compute offset"""
validator.check_type("shape", x_shp, [tuple])
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
validator.check_value_type("shape", x_shp, [tuple], prim_name)
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
for j in range(rank_base):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
offset.append(all_shp)
all_shp += v[axis]
return offset, all_shp, axis

Loading…
Cancel
Save