|
|
|
@@ -13,7 +13,7 @@ |
|
|
|
# limitations under the License. |
|
|
|
# =========================================================================== |
|
|
|
"""generate json desc for Tile""" |
|
|
|
from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer |
|
|
|
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException |
|
|
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF |
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD |
|
|
|
|
|
|
|
@@ -23,14 +23,32 @@ from ._utils import Expander, ExpanderInfoValidator as VLD |
|
|
|
class Tile(Expander): |
|
|
|
"""Tile expander""" |
|
|
|
|
|
|
|
def _get_output_shape(self): |
|
|
|
"""Get output shape""" |
|
|
|
shape = self.inputs[0].shape |
|
|
|
multiples = self.attrs["multiples"] |
|
|
|
|
|
|
|
shape = list(shape) |
|
|
|
multiples = list(multiples) |
|
|
|
diff_len = len(multiples) - len(shape) |
|
|
|
if diff_len < 0: |
|
|
|
raise GKException("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape)) |
|
|
|
if diff_len > 0: |
|
|
|
for _ in range(diff_len): |
|
|
|
shape.insert(0, 1) |
|
|
|
|
|
|
|
output_shape = [] |
|
|
|
|
|
|
|
for sh, mul in list(zip(shape, multiples)): |
|
|
|
if sh != 1 and mul != 1: |
|
|
|
raise GKException("Tile op in expander only Support Automatic Broadcast!") |
|
|
|
dim = sh * mul |
|
|
|
output_shape.append(dim) |
|
|
|
return output_shape |
|
|
|
|
|
|
|
def _expand(self, graph_builder): |
|
|
|
input_x = self.inputs[0] |
|
|
|
multiples = self.attrs['multiples'] |
|
|
|
|
|
|
|
tile_infer = TileInfer(self.name, self.inputs, self.attrs) |
|
|
|
output_shape, _, _ = tile_infer.infer() |
|
|
|
if tile_infer.broadcast_compatible: |
|
|
|
result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) |
|
|
|
else: |
|
|
|
result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples}) |
|
|
|
output_shape = self._get_output_shape() |
|
|
|
|
|
|
|
result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) |
|
|
|
return result |