| @@ -16,45 +16,6 @@ | |||
| from mindspore._extends.graph_kernel.model import model_builder as builder | |||
| def _get_tile_output_shape(shape, multiples): | |||
| """compute output shape of tile""" | |||
| if multiples is None: | |||
| return shape | |||
| if not isinstance(shape, (list, tuple)): | |||
| raise TypeError("Input shape of Tile must be of type list or tuple") | |||
| if not isinstance(multiples, (list, tuple)): | |||
| raise TypeError("multiples of Tile must be of type list or tuple") | |||
| shape = list(shape) | |||
| multiples = list(multiples) | |||
| diff_len = len(multiples) - len(shape) | |||
| if diff_len < 0: | |||
| raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape)) | |||
| if diff_len > 0: | |||
| for _ in range(diff_len): | |||
| shape.insert(0, 1) | |||
| shape_compatible = True | |||
| output_shape = [] | |||
| input_reshape = [] | |||
| output_reshape = [] | |||
| for sh, mul in list(zip(shape, multiples)): | |||
| dim = sh * mul | |||
| output_shape.append(dim) | |||
| if sh == 1 or mul == 1: | |||
| input_reshape.append(sh) | |||
| output_reshape.append(dim) | |||
| else: | |||
| shape_compatible = False | |||
| input_reshape.append(1) | |||
| input_reshape.append(sh) | |||
| output_reshape.append(mul) | |||
| output_reshape.append(sh) | |||
| return output_shape, input_reshape, output_reshape, shape_compatible | |||
| def expand_tile(expand_info): | |||
| """Tile expander""" | |||
| @@ -65,7 +26,7 @@ def expand_tile(expand_info): | |||
| for item in attrs: | |||
| if 'multiples' in item: | |||
| multiples = item['multiples'] | |||
| output_shape, _, _, shape_compatible = _get_tile_output_shape(input_desc['shape'], multiples) | |||
| output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples) | |||
| graph_builder = builder.GraphBuilder() | |||
| # generate a graph. | |||
| @@ -18,6 +18,45 @@ import copy | |||
| from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy | |||
| def get_tile_output_shape(shape, multiples): | |||
| """compute output shape of tile""" | |||
| if multiples is None: | |||
| return shape | |||
| if not isinstance(shape, (list, tuple)): | |||
| raise TypeError("Input shape of Tile must be of type list or tuple") | |||
| if not isinstance(multiples, (list, tuple)): | |||
| raise TypeError("multiples of Tile must be of type list or tuple") | |||
| shape = list(shape) | |||
| multiples = list(multiples) | |||
| diff_len = len(multiples) - len(shape) | |||
| if diff_len < 0: | |||
| raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape)) | |||
| if diff_len > 0: | |||
| for _ in range(diff_len): | |||
| shape.insert(0, 1) | |||
| shape_compatible = True | |||
| output_shape = [] | |||
| input_reshape = [] | |||
| output_reshape = [] | |||
| for sh, mul in list(zip(shape, multiples)): | |||
| dim = sh * mul | |||
| output_shape.append(dim) | |||
| if sh == 1 or mul == 1: | |||
| input_reshape.append(sh) | |||
| output_reshape.append(dim) | |||
| else: | |||
| shape_compatible = False | |||
| input_reshape.append(1) | |||
| input_reshape.append(sh) | |||
| output_reshape.append(mul) | |||
| output_reshape.append(sh) | |||
| return output_shape, input_reshape, output_reshape, shape_compatible | |||
| class OpInfer: | |||
| """Op infer""" | |||
| @staticmethod | |||
| @@ -74,6 +113,7 @@ class OpInfer: | |||
| 'InplaceAssign': lambda inputs, attrs: inputs[2].shape, | |||
| 'Reshape': lambda inputs, attrs: attrs["shape"], | |||
| 'BroadcastTo': lambda inputs, attrs: attrs["shape"], | |||
| 'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0], | |||
| } | |||
| infer_dtype_func = { | |||
| # add special infer func here | |||
| @@ -47,7 +47,13 @@ def test_sqrt_grad(shape_x, shape_dout, dtype): | |||
| expect_np = expect.asnumpy().copy() | |||
| output_np = output.asnumpy().copy() | |||
| assert np.allclose(expect_np, output_np, 0.0001, 0.0001) | |||
| rtol = 0.0001 | |||
| atol = 0.0001 | |||
| if dtype == np.float16: | |||
| rtol = 0.001 | |||
| atol = 0.001 | |||
| assert np.allclose(expect_np, output_np, rtol, atol) | |||
| @pytest.mark.level0 | |||