From 0e98430f4d5b4b06cd19b73fa9a1abfc0253ede2 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 15:55:58 +0800 Subject: [PATCH] get concat offset refactor --- mindspore/ops/_utils/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py index fbd81c4f0d..90496afc9b 100644 --- a/mindspore/ops/_utils/utils.py +++ b/mindspore/ops/_utils/utils.py @@ -15,7 +15,7 @@ """utils for operator""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype @@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): return broadcast_shape -def _get_concat_offset(x_shp, x_type, axis): +def _get_concat_offset(x_shp, x_type, axis, prim_name): """for concat and concatoffset check args and compute offset""" - validator.check_type("shape", x_shp, [tuple]) - validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) - validator.check_subclass("shape0", x_type[0], mstype.tensor) - validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) + validator.check_value_type("shape", x_shp, [tuple], prim_name) + validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) + validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) + validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name) rank_base = len(x_shp[0]) - validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) + validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) if axis < 0: axis = axis + rank_base all_shp = x_shp[0][axis] offset = [0,] for i in range(1, len(x_shp)): v = x_shp[i] - validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) - validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) + validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) + validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name) for j in range(rank_base): if j != axis and v[j] != x_shp[0][j]: - raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) + raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element") offset.append(all_shp) all_shp += v[axis] return offset, all_shp, axis