Browse Source

add expander checkings

tags/v1.4.0
Yang Jiao 4 years ago
parent
commit
582669e41c
3 changed files with 18 additions and 3 deletions
  1. +14
    -0
      mindspore/_extends/graph_kernel/expanders/_utils.py
  2. +2
    -1
      mindspore/_extends/graph_kernel/expanders/square_sum_v1.py
  3. +2
    -2
      mindspore/_extends/graph_kernel/model/op_infer.py

+ 14
- 0
mindspore/_extends/graph_kernel/expanders/_utils.py View File

@@ -47,8 +47,10 @@ class Expander:
graph_scope.set_input(*self.inputs)
outputs = self._expand(graph_builder)
if isinstance(outputs, (list, tuple)):
self._check_output_same(outputs)
graph_scope.set_output(*outputs)
else:
self._check_output_same([outputs])
graph_scope.set_output(outputs)

graph = graph_builder.get()[0]
@@ -58,6 +60,18 @@ class Expander:
def _check(self):
"""Check inputs"""

def _check_output_same(self, outputs):
for index, value in enumerate(self.outputs):
if list(outputs[index].shape) != list(value['shape']):
raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
self.__class__.__name__, list(outputs[index].shape), list(value['shape'])))
if outputs[index].dtype != value['data_type']:
raise GKException("{} 's output data_type {} is wrong. Expected: {}".format(
self.__class__.__name__, outputs[index].dtype, value['data_type']))
if outputs[index].data_format != value['format']:
raise GKException("{} 's output format {} is wrong. Expected: {}".format(
self.__class__.__name__, outputs[index].data_format, value['format']))

@abstractmethod
def _expand(self, graph_builder):
"""Expand operator, this function should be overridden in subclass"""


+ 2
- 1
mindspore/_extends/graph_kernel/expanders/square_sum_v1.py View File

@@ -17,6 +17,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis


@VLD.add_format(DF.FRAC_NZ)
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
@@ -47,7 +48,7 @@ class SquareSumV1(Expander):
axis = to_frac_z_axis(ori_shape, axis)

square_res = graph_builder.emit('Mul', [x, x])
result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': axis, 'keep_dims': True})
result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': axis, 'keep_dims': False})
if x.data_format == DF.FRAC_NZ:
result = graph_builder.emit('Reshape', [result], attrs={'shape': ori_reduced_shape})
return result

+ 2
- 2
mindspore/_extends/graph_kernel/model/op_infer.py View File

@@ -231,13 +231,13 @@ class Cast(_Elemwise):

class InplaceAssign(_Elemwise):
def _infer_shape(self):
return [1] if self.attrs["fake_output"] else self.inputs[2].shape
return self.inputs[2].shape

def _infer_type(self):
return self.inputs[2].dtype

def _infer_format(self):
return DF.DEFAULT if self.attrs["fake_output"] else self.inputs[2].data_format
return self.inputs[2].data_format


class BroadcastTo(OpInfer):


Loading…
Cancel
Save