From: @dayschan
Tag: tags/v1.2.0-rc1
--- a/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py
+++ b/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py
@@ -20,24 +20,30 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
 @VLD.add_format(DF.DEFAULT)
 @VLD.add_format(DF.NHWC)
 @VLD.add_format(DF.NCHW)
+@VLD.add_format(DF.FRAC_NZ)
 class BiasAddGrad(Expander):
     """BiasAddGrad expander"""

     def _expand(self, graph_builder):
-        input_x = self.inputs[0]
+        x = self.inputs[0]
         reduce_axis = ()
-        if input_x.data_format == 'NHWC':
+        if x.data_format == DF.NHWC:
             reduce_axis = (0, 1, 2)
-        elif input_x.data_format == 'NCHW':
+        elif x.data_format == DF.NCHW:
             reduce_axis = (0, 2, 3)
-        # DefaultFormat shape's length should be from 2 to 4
+        elif x.data_format == DF.FRAC_NZ:
+            reduce_axis = (-2, -3)
         else:
-            if len(input_x.shape) == 2:
+            # DefaultFormat shape's length should be from 2 to 4
+            if len(x.shape) == 2:
                 reduce_axis = (0,)
-            elif len(input_x.shape) == 3:
+            elif len(x.shape) == 3:
                 reduce_axis = (0, 1)
             else:
                 reduce_axis = (0, 2, 3)
-        result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        if x.data_format == DF.FRAC_NZ:
+            out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
+            result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
         return result
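Note for reviewers: the FRAC_NZ branch is easiest to check with concrete numbers. A minimal walk-through, assuming the usual fractal layout (..., N1, M1, M0, N0) for an original (M, N) matrix (the layout itself is my assumption, not stated in this diff):

```python
# Worked example of the FRAC_NZ branch above (hypothetical shapes).
# Assume x stores an (M=32, N=64) matrix as (N1, M1, M0, N0) = (4, 2, 16, 16).
x_shape = [4, 2, 16, 16]

# ReduceSum over reduce_axis=(-2, -3) with keep_dims=False drops M1 and M0,
# i.e. indices 1 and 2 of a rank-4 shape, leaving (N1, N0).
reduced = [s for i, s in enumerate(x_shape) if i not in (1, 2)]
assert reduced == [4, 16]

# The trailing Reshape folds (N1, N0) back into the flat bias length N.
out_shape = x_shape[:-4] + [x_shape[-1] * x_shape[-4]]
assert out_shape == [64]  # N = N0 * N1 = 16 * 4
```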
--- a/mindspore/_extends/graph_kernel/expanders/tile.py
+++ b/mindspore/_extends/graph_kernel/expanders/tile.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for Tile"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
@@ -27,8 +27,9 @@ class Tile(Expander):
         input_x = self.inputs[0]
         multiples = self.attrs['multiples']
-        output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples)
-        if shape_compatible:
+        tile_infer = TileInfer(self.name, self.inputs, self.attrs)
+        output_shape, _, _ = tile_infer.infer()
+        if tile_infer.broadcast_compatible:
             result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         else:
             result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
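Note for reviewers: the BroadcastTo-vs-Tile decision now comes from TileInfer.broadcast_compatible. The rule is per-dimension: a dim can be broadcast when the input extent or the multiple is 1. A standalone restatement of that predicate, with hypothetical shapes (not the class itself):

```python
# Minimal re-statement of the compatibility rule used by TileInfer.
def is_broadcast_compatible(shape, multiples):
    # left-pad shape with 1s, as Tile does for a rank mismatch
    shape = [1] * (len(multiples) - len(shape)) + list(shape)
    return all(sh == 1 or mul == 1 for sh, mul in zip(shape, multiples))

assert is_broadcast_compatible([1, 4], [3, 1])      # -> emit BroadcastTo([3, 4])
assert not is_broadcast_compatible([2, 3], [2, 1])  # -> emit a real Tile -> [4, 3]
```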
--- a/mindspore/_extends/graph_kernel/model/model_builder.py
+++ b/mindspore/_extends/graph_kernel/model/model_builder.py
@@ -15,139 +15,8 @@
 """GraphKernel model builder"""

 import copy
-from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat
-
-
-def get_tile_output_shape(shape, multiples):
-    """compute output shape of tile"""
-    if multiples is None:
-        return shape
-    if not isinstance(shape, (list, tuple)):
-        raise TypeError("Input shape of Tile must be of type list or tuple")
-    if not isinstance(multiples, (list, tuple)):
-        raise TypeError("multiples of Tile must be of type list or tuple")
-
-    shape = list(shape)
-    multiples = list(multiples)
-    diff_len = len(multiples) - len(shape)
-    if diff_len < 0:
-        raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-    if diff_len > 0:
-        for _ in range(diff_len):
-            shape.insert(0, 1)
-    shape_compatible = True
-    output_shape = []
-    input_reshape = []
-    output_reshape = []
-    for sh, mul in list(zip(shape, multiples)):
-        dim = sh * mul
-        output_shape.append(dim)
-        if sh == 1 or mul == 1:
-            input_reshape.append(sh)
-            output_reshape.append(dim)
-        else:
-            shape_compatible = False
-            input_reshape.append(1)
-            input_reshape.append(sh)
-            output_reshape.append(mul)
-            output_reshape.append(sh)
-
-    return output_shape, input_reshape, output_reshape, shape_compatible
-
-
-class OpInfer:
-    """Op infer"""
-
-    @staticmethod
-    def default_reduce_infer(inputs, attrs):
-        """Default reduce infer"""
-        shape = copy.deepcopy(inputs[0].shape)
-        if attrs['keep_dims']:
-            for i in attrs['reduce_axis']:
-                shape[i] = 1
-            return shape
-
-        real_shape = []
-        for i, _ in enumerate(shape):
-            if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
-                real_shape.append(shape[i])
-        return real_shape
-
-    @staticmethod
-    def default_elementwise_infer(inputs, attrs):
-        """Default elementwise infer"""
-        shape = (1,)
-        max_flatten_shape = 1
-        for t in inputs:
-            flatten_shape = 1
-            for s in t.shape:
-                flatten_shape *= s
-            if flatten_shape >= max_flatten_shape:
-                max_flatten_shape = flatten_shape
-                shape = t.shape
-        return shape
-
-    default_infer_shape_func = [
-        None,
-        None,
-        default_elementwise_infer.__func__,
-        lambda inputs, attrs: max([t.shape for t in inputs]),
-        default_reduce_infer.__func__,
-        None,
-        lambda inputs, attrs: [1],  # control op
-    ]
-
-    @staticmethod
-    def default_infer_dtype_func(inputs, attrs):
-        """Infer dtype"""
-        return inputs[0].dtype
-
-    @staticmethod
-    def default_infer_format_func(inputs, attrs):
-        """Infer format"""
-        result = inputs[0].data_format
-        # default_format and other_format results in other_format
-        for input_tensor in inputs[1:]:
-            data_format = input_tensor.data_format
-            if data_format != DataFormat.DEFAULT:
-                if result not in [DataFormat.DEFAULT, data_format]:
-                    raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
-                result = data_format
-        return result
-
-    infer_shape_func = {
-        # add special infer func here
-        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
-        'Reshape': lambda inputs, attrs: attrs["shape"],
-        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
-        'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
-        'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1),
-    }
-    infer_dtype_func = {
-        # add special infer func here
-        'Cast': lambda inputs, attrs: attrs['dst_type'],
-        'Less': lambda inputs, attrs: "bool",
-        'LessEqual': lambda inputs, attrs: "bool",
-        'Equal': lambda inputs, attrs: "bool",
-        'Greater': lambda inputs, attrs: "bool",
-        'GreaterEqual': lambda inputs, attrs: "bool",
-    }
-    infer_format_func = {
-        # add special infer func here
-        'Reshape': lambda inputs, attrs: "DefaultFormat",
-    }
-
-    @classmethod
-    def infer(cls, prim_name, inputs, attrs):
-        prim = PrimLib.primtives[prim_name]
-        infer_shape = cls.infer_shape_func.get(
-            prim_name, cls.default_infer_shape_func[prim.iter_type])
-        infer_dtype = cls.infer_dtype_func.get(
-            prim_name, cls.default_infer_dtype_func)
-        infer_format = cls.infer_format_func.get(
-            prim_name, cls.default_infer_format_func)
-        return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
+from . import op_infer
+from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy


 class GraphBuilder:
@@ -229,7 +98,7 @@ class GraphBuilder:
         if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
         tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
+        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
         return output
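Note for reviewers: emit() now delegates to the module-level op_infer.infer() instead of the deleted OpInfer classmethod. The resolution order matters: an exact class-name match in op_infer wins, then the primitive's iter_type selects a generic _Elemwise/_Reduce. A simplified, self-contained sketch of that lookup (the constants and dicts here are stand-ins, not the real PrimLib values):

```python
# Simplified model of the lookup performed by op_infer.infer().
ELEMWISE, REDUCE = 2, 4  # stand-ins for PrimLib iter_type constants

class _Elemwise: ...
class _Reduce: ...
class Tile: ...          # special-case infer class, matched by name

def resolve(op_name, iter_type):
    special = {"Tile": Tile}  # stands in for hasattr(sys.modules[__name__], op_name)
    common = {ELEMWISE: _Elemwise, REDUCE: _Reduce}
    if op_name in special:
        return special[op_name]
    cls = common.get(iter_type)
    if cls is None:
        raise ValueError("OpInfo does not support op {}".format(op_name))
    return cls

assert resolve("Tile", ELEMWISE) is Tile      # name match beats iter_type
assert resolve("Add", ELEMWISE) is _Elemwise  # generic elementwise fallback
```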
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/model/op_infer.py
@@ -0,0 +1,275 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""GraphKernel Op Infer"""
+
+import copy
+import sys
+from functools import reduce
+from .model import GraphKernelUnsupportedException as GKException
+from .model import PrimLib, DataFormat as DF
+
+
+def infer(op_name, inputs, attrs):
+    """infer shape dtype and format"""
+    def _create_opinfer():
+        if hasattr(sys.modules[__name__], op_name):
+            op_cls = getattr(sys.modules[__name__], op_name)
+            return op_cls(op_name, inputs, attrs)
+        # common infer
+        class_name_map = {
+            PrimLib.ELEMWISE: "_Elemwise",
+            PrimLib.REDUCE: "_Reduce",
+        }
+        cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
+        if not cls_name:
+            raise GKException("OpInfo does not support op {}".format(op_name))
+        op_cls = getattr(sys.modules[__name__], cls_name)
+        return op_cls(op_name, inputs, attrs)
+    return _create_opinfer().infer()
+
+
+class OpInfer:
+    """
+    OpInfer is the base class for inferring operator info in the GraphKernel model builder.
+
+    Three methods should be overridden to define an operator's infer logic:
+    _infer_shape(), _infer_type() and _infer_format().
+    """
+
+    def __init__(self, name, inputs, attrs):
+        self.name = name
+        self.inputs = inputs
+        self.attrs = attrs
+
+    def infer(self):
+        """Infer shape, type and format by op inputs"""
+        self._check()
+        return self._infer_shape(), self._infer_type(), self._infer_format()
+
+    def _infer_shape(self):
+        return self.inputs[0].shape
+
+    def _infer_type(self):
+        return self.inputs[0].dtype
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+    def _check(self):
+        self._check_shape()
+        self._check_type()
+        self._check_format()
+
+    def _check_shape(self):
+        pass
+
+    def _check_type(self):
+        """check that all dtypes are the same"""
+        dtype = self.inputs[0].dtype
+        for i, t in enumerate(self.inputs[1:]):
+            if t.dtype != dtype:
+                raise GKException(
+                    "Incompatible dtype between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
+
+    def _check_format(self):
+        """check that formats are compatible; only DefaultFormat is compatible with others"""
+        result = self.inputs[0].data_format
+        i = 0
+        for j, t in enumerate(self.inputs[1:]):
+            if t.data_format != result:
+                if DF.DEFAULT not in (result, t.data_format):
+                    raise GKException("Incompatible format between input {}({}) and {}({})".format(
+                        i, result, j + 1, t.data_format))
+                if result == DF.DEFAULT:
+                    result = t.data_format
+                    i = j + 1
+
+
+class _Elemwise(OpInfer):
+    """Common infer for elementwise operators"""
+
+    def _infer_shape(self):
+        """returns the input shape with the largest flattened size"""
+        shape = (1,)
+        max_flatten_size = 1
+        for t in self.inputs:
+            flatten_size = reduce(lambda x, y: x * y, t.shape)
+            if flatten_size >= max_flatten_size:
+                max_flatten_size = flatten_size
+                shape = t.shape
+        return shape
+
+    def _infer_format(self):
+        for tensor in self.inputs:
+            if tensor.data_format != DF.DEFAULT:
+                return tensor.data_format
+        return DF.DEFAULT
+
+
+class _Reduce(OpInfer):
+    """Common infer for reduction operators"""
+
+    def _check(self):
+        super()._check()
+        # check that the reduce axes are within the range [-len, len)
+        shape_len = len(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+        if isinstance(axis, int):
+            axis = [axis]
+        if not all([(-shape_len <= i < shape_len) for i in axis]):
+            raise GKException(
+                "reduce_axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
+
+    def _infer_shape(self):
+        shape = copy.deepcopy(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+        if isinstance(axis, int):
+            axis = [axis]
+        if any([i < 0 for i in axis]):
+            # normalize negative axes to non-negative ones
+            axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
+            self.attrs['reduce_axis'] = sorted(axis)
+        if self.attrs['keep_dims']:
+            for i in axis:
+                shape[i] = 1
+            return shape
+        real_shape = []
+        for i, s in enumerate(shape):
+            if i not in axis:
+                real_shape.append(s)
+        return real_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class _Reshape(OpInfer):
+    """Common infer for reshape operators, should not be instantiated"""
+
+    def _infer_shape(self):
+        raise GKException("_infer_shape should be implemented by subclass")
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class Reshape(_Reshape):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+
+class ExpandDims(_Reshape):
+    def _infer_shape(self):
+        # list.insert mutates in place and returns None, so build the list first
+        shape = list(self.inputs[0].shape)
+        shape.insert(self.attrs["axis"], 1)
+        return shape
+
+
+class Cast(_Elemwise):
+    def _infer_type(self):
+        return self.attrs["dst_type"]
+
+
+class InplaceAssign(_Elemwise):
+    def _infer_shape(self):
+        return [1] if self.attrs["fake_output"] else self.inputs[2].shape
+
+    def _infer_type(self):
+        return self.inputs[2].dtype
+
+    def _infer_format(self):
+        return DF.DEFAULT if self.attrs["fake_output"] else self.inputs[2].data_format
+
+
+class BroadcastTo(OpInfer):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+
+class Tile(OpInfer):
+    """Op Tile"""
+
+    def __init__(self, op_name, inputs, attrs):
+        super().__init__(op_name, inputs, attrs)
+        self.input_reshape = None
+        self.output_reshape = None
+        self.broadcast_compatible = True
+
+    def _infer_shape(self):
+        shape = self.inputs[0].shape
+        multiples = self.attrs["multiples"]
+        shape = list(shape)
+        multiples = list(multiples)
+        diff_len = len(multiples) - len(shape)
+        if diff_len < 0:
+            raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
+        if diff_len > 0:
+            for _ in range(diff_len):
+                shape.insert(0, 1)
+        self.broadcast_compatible = True
+        output_shape = []
+        self.input_reshape = []
+        self.output_reshape = []
+        for sh, mul in list(zip(shape, multiples)):
+            dim = sh * mul
+            output_shape.append(dim)
+            if sh == 1 or mul == 1:
+                # this dim can be produced by a plain broadcast
+                self.input_reshape.append(sh)
+                self.output_reshape.append(dim)
+            else:
+                # both extent and multiple exceed 1: record the reshape
+                # decomposition and mark the op as a real Tile
+                self.broadcast_compatible = False
+                self.input_reshape.append(1)
+                self.input_reshape.append(sh)
+                self.output_reshape.append(mul)
+                self.output_reshape.append(sh)
+        return output_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class _CompareOp(_Elemwise):
+    """Compare operators"""
+
+    def _infer_type(self):
+        return "bool"
+
+
+class Less(_CompareOp):
+    pass
+
+
+class LessEqual(_CompareOp):
+    pass
+
+
+class Equal(_CompareOp):
+    pass
+
+
+class Greater(_CompareOp):
+    pass
+
+
+class GreaterEqual(_CompareOp):
+    pass
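Note for reviewers: because infer() resolves classes by name via sys.modules, supporting a new op only needs a class with the op's name in this module. A hypothetical extension (Select is not part of this change) that also shows why _check_type sometimes needs overriding:

```python
# Hypothetical extension, not part of this PR: a Select op whose result
# dtype follows the data inputs rather than the bool condition.
class Select(_Elemwise):
    def _check_type(self):
        # skip the inherited all-same-dtype check: cond is bool while x/y share a dtype
        if self.inputs[1].dtype != self.inputs[2].dtype:
            raise GKException("x and y of Select must have the same dtype")

    def _infer_type(self):
        # inputs are (cond, x, y); the output dtype is x's, not cond's bool
        return self.inputs[1].dtype
```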