
refactor Tile op expander and open it on GPU

pull/15642/head
zengzitao, 4 years ago
commit 8dcff8d83c
4 changed files with 29 additions and 56 deletions
  1. mindspore/_extends/graph_kernel/expanders/tile.py (+27 -9)
  2. mindspore/_extends/graph_kernel/model/op_infer.py (+0 -45)
  3. mindspore/_extends/graph_kernel/splitter.py (+1 -1)
  4. mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc (+1 -1)
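
Summary: Tile's shape inference moves out of the shared op_infer.py model and into the expander itself; the expander now handles only broadcast-compatible multiples (every dimension pair contains a 1), raising GraphKernelUnsupportedException otherwise; and prim::kPrimTile is lifted out of the Ascend-only #if ENABLE_D list so the expansion also runs on GPU.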

mindspore/_extends/graph_kernel/expanders/tile.py (+27 -9)

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for Tile"""
-from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer
+from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD


@@ -23,14 +23,32 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
 class Tile(Expander):
     """Tile expander"""
 
+    def _get_output_shape(self):
+        """Get output shape"""
+        shape = self.inputs[0].shape
+        multiples = self.attrs["multiples"]
+
+        shape = list(shape)
+        multiples = list(multiples)
+        diff_len = len(multiples) - len(shape)
+        if diff_len < 0:
+            raise GKException("Dimensions of multiples {} < dimensions of input {} in Tile".format(multiples, shape))
+        if diff_len > 0:
+            for _ in range(diff_len):
+                shape.insert(0, 1)
+
+        output_shape = []
+
+        for sh, mul in zip(shape, multiples):
+            if sh != 1 and mul != 1:
+                raise GKException("Tile op in expander only supports automatic broadcast!")
+            dim = sh * mul
+            output_shape.append(dim)
+        return output_shape
+
     def _expand(self, graph_builder):
         input_x = self.inputs[0]
-        multiples = self.attrs['multiples']
-
-        tile_infer = TileInfer(self.name, self.inputs, self.attrs)
-        output_shape, _, _ = tile_infer.infer()
-        if tile_infer.broadcast_compatible:
-            result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
-        else:
-            result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
+        output_shape = self._get_output_shape()
+
+        result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         return result
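
For context, the rewritten expander accepts Tile only when each (dimension, multiple) pair contains a 1; tiling is then exactly a broadcast. A minimal NumPy sketch of that equivalence (NumPy is used purely for illustration and is not part of the expander):

import numpy as np

x = np.arange(3).reshape(1, 3)            # input shape (1, 3)
multiples = (2, 1)                        # broadcast-compatible: each pair has a 1

tiled = np.tile(x, multiples)             # what the original Tile op computes
broadcast = np.broadcast_to(x, (2, 3))    # what the emitted BroadcastTo computes

assert (tiled == broadcast).all()         # same values, same shape (2, 3)

With multiples = (2, 2) on the same input, the pair (3, 2) contains no 1, so _get_output_shape raises GKException and the node is presumably left to the original Tile kernel instead of being expanded.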

mindspore/_extends/graph_kernel/model/op_infer.py (+0 -45)

@@ -198,51 +198,6 @@ class BroadcastTo(OpInfer):
         return self.inputs[0].data_format
 
 
-class Tile(OpInfer):
-    """Op Tile"""
-
-    def __init__(self, op_name, inputs, attrs):
-        super().__init__(op_name, inputs, attrs)
-        self.input_reshape = None
-        self.output_reshape = None
-        self.broadcast_compatible = True
-
-    def _infer_shape(self):
-        shape = self.inputs[0].shape
-        multiples = self.attrs["multiples"]
-
-        shape = list(shape)
-        multiples = list(multiples)
-        diff_len = len(multiples) - len(shape)
-        if diff_len < 0:
-            raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-        if diff_len > 0:
-            for _ in range(diff_len):
-                shape.insert(0, 1)
-
-        self.broadcast_compatible = True
-        output_shape = []
-        self.input_reshape = []
-        self.output_reshape = []
-        for sh, mul in list(zip(shape, multiples)):
-            dim = sh * mul
-            output_shape.append(dim)
-            if sh == 1 or mul == 1:
-                self.input_reshape.append(sh)
-                self.output_reshape.append(dim)
-            else:
-                self.broadcast_compatible = False
-                self.input_reshape.append(1)
-                self.input_reshape.append(sh)
-                self.output_reshape.append(mul)
-                self.output_reshape.append(sh)
-
-        return output_shape
-
-    def _infer_format(self):
-        return DF.DEFAULT
-
-
 class _CompareOp(_Elemwise):
     """Compare operators"""




mindspore/_extends/graph_kernel/splitter.py (+1 -1)

@@ -61,7 +61,7 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
         f.write("********** main graph: {} **********\n".format(graph_desc.name))
         f.write("input json:\n{}\n".format(graph_json))
         f.write("graph desc:\n{}\n".format(str(graph_desc)))
-        if len(subgraphs) > 1 or subgraphs[0].stitch_info is not None:
+        if len(subgraphs) > 1:
            for i, g in enumerate(subgraphs):
                f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
                f.write("{}\n".format(str(g)))


mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc (+1 -1)

@@ -50,8 +50,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
     prim::kPrimLayerNorm,
     prim::kPrimLayerNormGrad,
     prim::kPrimExpandDims,
-#if ENABLE_D
     prim::kPrimTile,
+#if ENABLE_D
     prim::kPrimSqrtGrad,
     prim::kPrimClipByNormNoDivSum,
     prim::kLambApplyOptimizerAssign,
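
Net effect of this hunk: prim::kPrimTile moves above the #if ENABLE_D guard, from the Ascend-only portion of the expand-op list into the part compiled for every backend; together with the Python changes above, this is what opens the Tile expander on GPU.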

