Browse Source

add Tile infer shape function

tags/v1.1.0
looop5 5 years ago
parent
commit
8bbe723603
3 changed files with 48 additions and 41 deletions
  1. +1
    -40
      mindspore/_extends/graph_kernel/expanders/tile.py
  2. +40
    -0
      mindspore/_extends/graph_kernel/model/model_builder.py
  3. +7
    -1
      tests/st/ops/graph_kernel/test_sqrt_grad.py

+ 1
- 40
mindspore/_extends/graph_kernel/expanders/tile.py View File

@@ -16,45 +16,6 @@
from mindspore._extends.graph_kernel.model import model_builder as builder


def _get_tile_output_shape(shape, multiples):
"""compute output shape of tile"""

if multiples is None:
return shape
if not isinstance(shape, (list, tuple)):
raise TypeError("Input shape of Tile must be of type list or tuple")
if not isinstance(multiples, (list, tuple)):
raise TypeError("multiples of Tile must be of type list or tuple")

shape = list(shape)
multiples = list(multiples)
diff_len = len(multiples) - len(shape)
if diff_len < 0:
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
if diff_len > 0:
for _ in range(diff_len):
shape.insert(0, 1)

shape_compatible = True
output_shape = []
input_reshape = []
output_reshape = []
for sh, mul in list(zip(shape, multiples)):
dim = sh * mul
output_shape.append(dim)
if sh == 1 or mul == 1:
input_reshape.append(sh)
output_reshape.append(dim)
else:
shape_compatible = False
input_reshape.append(1)
input_reshape.append(sh)
output_reshape.append(mul)
output_reshape.append(sh)

return output_shape, input_reshape, output_reshape, shape_compatible


def expand_tile(expand_info):
"""Tile expander"""

@@ -65,7 +26,7 @@ def expand_tile(expand_info):
for item in attrs:
if 'multiples' in item:
multiples = item['multiples']
output_shape, _, _, shape_compatible = _get_tile_output_shape(input_desc['shape'], multiples)
output_shape, _, _, shape_compatible = builder.get_tile_output_shape(input_desc['shape'], multiples)
graph_builder = builder.GraphBuilder()

# generate a graph.


+ 40
- 0
mindspore/_extends/graph_kernel/model/model_builder.py View File

@@ -18,6 +18,45 @@ import copy
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy


def get_tile_output_shape(shape, multiples):
"""compute output shape of tile"""

if multiples is None:
return shape
if not isinstance(shape, (list, tuple)):
raise TypeError("Input shape of Tile must be of type list or tuple")
if not isinstance(multiples, (list, tuple)):
raise TypeError("multiples of Tile must be of type list or tuple")

shape = list(shape)
multiples = list(multiples)
diff_len = len(multiples) - len(shape)
if diff_len < 0:
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
if diff_len > 0:
for _ in range(diff_len):
shape.insert(0, 1)

shape_compatible = True
output_shape = []
input_reshape = []
output_reshape = []
for sh, mul in list(zip(shape, multiples)):
dim = sh * mul
output_shape.append(dim)
if sh == 1 or mul == 1:
input_reshape.append(sh)
output_reshape.append(dim)
else:
shape_compatible = False
input_reshape.append(1)
input_reshape.append(sh)
output_reshape.append(mul)
output_reshape.append(sh)

return output_shape, input_reshape, output_reshape, shape_compatible


class OpInfer:
"""Op infer"""
@staticmethod
@@ -74,6 +113,7 @@ class OpInfer:
'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
'Reshape': lambda inputs, attrs: attrs["shape"],
'BroadcastTo': lambda inputs, attrs: attrs["shape"],
'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
}
infer_dtype_func = {
# add special infer func here


+ 7
- 1
tests/st/ops/graph_kernel/test_sqrt_grad.py View File

@@ -47,7 +47,13 @@ def test_sqrt_grad(shape_x, shape_dout, dtype):
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()

assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
rtol = 0.0001
atol = 0.0001
if dtype == np.float16:
rtol = 0.001
atol = 0.001

assert np.allclose(expect_np, output_np, rtol, atol)


@pytest.mark.level0


Loading…
Cancel
Save