|
|
|
@@ -15,7 +15,7 @@ |
|
|
|
|
|
|
|
"""utils for operator""" |
|
|
|
|
|
|
|
from ..._checkparam import ParamValidator as validator |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ...common import dtype as mstype |
|
|
|
|
|
|
|
@@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): |
|
|
|
return broadcast_shape |
|
|
|
|
|
|
|
|
|
|
|
def _get_concat_offset(x_shp, x_type, axis): |
|
|
|
def _get_concat_offset(x_shp, x_type, axis, prim_name): |
|
|
|
"""for concat and concatoffset check args and compute offset""" |
|
|
|
validator.check_type("shape", x_shp, [tuple]) |
|
|
|
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) |
|
|
|
validator.check_subclass("shape0", x_type[0], mstype.tensor) |
|
|
|
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) |
|
|
|
validator.check_value_type("shape", x_shp, [tuple], prim_name) |
|
|
|
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) |
|
|
|
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) |
|
|
|
validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name) |
|
|
|
rank_base = len(x_shp[0]) |
|
|
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) |
|
|
|
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) |
|
|
|
if axis < 0: |
|
|
|
axis = axis + rank_base |
|
|
|
all_shp = x_shp[0][axis] |
|
|
|
offset = [0,] |
|
|
|
for i in range(1, len(x_shp)): |
|
|
|
v = x_shp[i] |
|
|
|
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) |
|
|
|
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) |
|
|
|
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) |
|
|
|
validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name) |
|
|
|
for j in range(rank_base): |
|
|
|
if j != axis and v[j] != x_shp[0][j]: |
|
|
|
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) |
|
|
|
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element") |
|
|
|
offset.append(all_shp) |
|
|
|
all_shp += v[axis] |
|
|
|
return offset, all_shp, axis |