diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py index cb4eb50743..211ef25e66 100644 --- a/mindspore/_extends/graph_kernel/expanders/__init__.py +++ b/mindspore/_extends/graph_kernel/expanders/__init__.py @@ -51,3 +51,4 @@ from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign from .lamb_apply_weight_assign import LambApplyWeightAssign from .softmax_grad_ext import SoftmaxGradExt from .square_sum_v1 import SquareSumV1 +from .conv2d import Conv2D diff --git a/mindspore/_extends/graph_kernel/expanders/conv2d.py b/mindspore/_extends/graph_kernel/expanders/conv2d.py new file mode 100644 index 0000000000..640722bad2 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/conv2d.py @@ -0,0 +1,111 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
    """Conv2D expander.

    Rewrites a Conv2D node into an optional `PadAkg` on the feature map, an
    optional `PadAkg` on the weight's first dimension (rounded up to a
    multiple of OUT_CHANNEL_ALIGN), the `Conv2D` itself with zeroed
    `pad_list`, and a trailing `UnPadAkg` that strips the extra channels.
    """

    def __init__(self, expand_info):
        super().__init__(expand_info)
        # Both flags are computed by _check() and consumed by _expand().
        self.has_pad = False
        self.can_optimize_to_batchmatmul = False

    @staticmethod
    def _align_out_channel(channel):
        """Round `channel` up to the next multiple of OUT_CHANNEL_ALIGN."""
        return ((channel + OUT_CHANNEL_ALIGN - 1) // OUT_CHANNEL_ALIGN) * OUT_CHANNEL_ALIGN

    def _check(self):
        """Validate inputs/attrs and record `has_pad` / `can_optimize_to_batchmatmul`.

        Raises:
            GKException: if dtypes are not float16, groups/group are not 1,
                dilation is not all ones, any shape/attr has the wrong rank,
                the last dimensions of both inputs differ, or the shape falls
                into the unsupported case checked at the end.
        """
        type_0 = self.inputs[0]['data_type']
        type_1 = self.inputs[1]['data_type']
        if type_0 != "float16" or type_1 != "float16":
            raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))

        groups = self.attrs['groups']
        group = self.attrs['group']
        if groups != 1 or group != 1:
            raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group))

        dilation = self.attrs['dilation']
        check_nd(dilation, 4)
        # check_nd accepts tuples as well as lists, so normalise before
        # comparing: (1, 1, 1, 1) != [1, 1, 1, 1] is True in Python.
        if list(dilation) != [1, 1, 1, 1]:
            raise GKException("dilation should be all 1, but got {}".format(dilation))

        pad_list = self.attrs['pad_list']
        pad_mode = self.attrs['pad_mode']
        check_nd(pad_list, 4)
        self.has_pad = conv_had_pad(pad_list, pad_mode)

        shape_0 = self.inputs[0]['shape']
        shape_1 = self.inputs[1]['shape']
        stride = self.attrs['stride']
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)
        check_nd(stride, 4)
        n0, h0, w0, c0 = shape_0
        n1, h1, w1, c1 = shape_1
        if c0 != c1:
            raise GKException("C channel of inputs should be same, but got {} and {}".format(c0, c1))
        if self.has_pad:
            # pad_list is used as [h_before, h_after, w_before, w_after].
            h0 = h0 + pad_list[0] + pad_list[1]
            w0 = w0 + pad_list[2] + pad_list[3]

        # Sizes after the padding that _expand() will apply.
        n1 = self._align_out_channel(n1)
        m = n0 * h0 * w0
        n = n1 * h1 * w1
        k = c1
        self.can_optimize_to_batchmatmul = (
            h1 == 1 and w1 == 1
            and list(stride) == [1, 1, 1, 1]  # tuple strides must qualify too
            and m % M_ALIGN == 0 and n % N_ALIGN == 0 and k % K_ALIGN == 0
        )

        # NOTE(review): the condition rejects only inputs with N < 128 whose
        # (padded) H and W are both odd; confirm this matches the backend's
        # actual limitation. The message below describes the condition as coded.
        if n0 < 128 and h0 % 2 != 0 and w0 % 2 != 0 and not self.can_optimize_to_batchmatmul:
            raise GKException("Conv2D expander only processes when N({}) >= 128, or H({}) or W({}) of first input "
                              "is even, or current Conv2D can be optimized to BatchMatMul.".format(n0, h0, w0))

    def _expand(self, graph_builder):
        """Emit the (optionally padded) Conv2D and return its result."""
        input_0 = self.inputs[0]
        input_1 = self.inputs[1]

        pad_value = 0
        if self.has_pad:
            pad_list = self.attrs['pad_list']
            pad_before = [0, pad_list[0], pad_list[2], 0]
            pad_after = [0, pad_list[1], pad_list[3], 0]
            input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': pad_before,
                                                                     'tail': pad_after,
                                                                     'pad_val': pad_value})

        # Work on a shallow copy so the expander's own attrs are not mutated.
        attrs = dict(self.attrs)
        attrs['pad_list'] = [0, 0, 0, 0]
        attrs['can_optimize_to_batchmatmul'] = self.can_optimize_to_batchmatmul

        weight_dim0 = input_1.shape[0]
        out_channel_pad = self._align_out_channel(weight_dim0) - weight_dim0
        if out_channel_pad:
            input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
                                                                     'tail': [out_channel_pad, 0, 0, 0],
                                                                     'pad_val': pad_value})
        result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
        if out_channel_pad:
            # The padded weight dimension shows up as the last output dimension.
            result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': [0, 0, 0, out_channel_pad]})
        return result
def get_default_mode(self, op):
    """Return the default area mode for `op`.

    `PadAkg` and `UnPadAkg` are always composite; otherwise the mode is
    basic only when the op's iteration type is RESHAPE.
    """
    if op.prim in ("PadAkg", "UnPadAkg"):
        return self.Area.MODE_COMPOSITE
    if PrimLib.iter_type(op) == PrimLib.RESHAPE:
        return self.Area.MODE_BASIC
    return self.Area.MODE_COMPOSITE
def check_nd(data, nd):
    """Raise GKException unless `data` is a list or tuple of exactly `nd` items."""
    if not isinstance(data, (list, tuple)) or len(data) != nd:
        raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))


def check_one(data):
    """Raise GKException unless `data` is a list or tuple whose items all equal 1."""
    if not isinstance(data, (list, tuple)):
        raise GKException("input should be list or tuple")
    for i, d in enumerate(data):
        if d != 1:
            raise GKException("value at index {} should be 1, but got {}.".format(i, d))


def conv_had_pad(pad_list, pad_mode):
    """Tell whether a convolution with these pad attrs needs explicit padding.

    Args:
        pad_list: list or tuple of four pad amounts.
        pad_mode: padding mode string; "VALID"/"valid" is special-cased.

    Returns:
        True if the two pads of either pair (items 0/1 or items 2/3) differ,
        or if the mode is not VALID and any pad is non-zero; False otherwise.

    Raises:
        GKException: if `pad_list` is not a list or tuple of length 4.
    """
    if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
        raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
    # Asymmetric pads count as padding regardless of the mode.
    if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
        return True
    if pad_mode in ("VALID", "valid"):
        return False
    return any(pad != 0 for pad in pad_list)


class Conv2D(OpInfer):
    """Conv2D infer"""

    def _infer_shape(self):
        """Return the output shape as [n, out_h, out_w, out_channel].

        Raises:
            GKException: if an input shape or attr has the wrong rank, or
                either input's data format is not NHWC.
        """
        shape_0 = list(self.inputs[0].shape)
        shape_1 = list(self.inputs[1].shape)
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)

        format_0 = self.inputs[0].data_format
        format_1 = self.inputs[1].data_format
        if format_0 != DF.NHWC or format_1 != DF.NHWC:
            raise GKException("Conv2D's inputs format must be NHWC, but got {} and {}".format(format_0, format_1))

        n, h, w = shape_0[0], shape_0[1], shape_0[2]
        out_channel = shape_1[0]
        pad_list = self.attrs["pad_list"]
        pad_mode = self.attrs["pad_mode"]
        kernel_size = self.attrs["kernel_size"]
        stride = self.attrs["stride"]
        dilation = self.attrs["dilation"]
        check_nd(pad_list, 4)
        check_nd(kernel_size, 2)
        check_nd(stride, 4)
        check_nd(dilation, 4)

        # When no explicit padding applies, pad_list is ignored.
        if not conv_had_pad(pad_list, pad_mode):
            pad_list = [0, 0, 0, 0]

        # Effective kernel extent once dilation is applied.
        k_h = (kernel_size[0] - 1) * dilation[-2] + 1
        k_w = (kernel_size[1] - 1) * dilation[-1] + 1
        out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
        out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
        return [n, out_h, out_w, out_channel]
class PadAkg(OpInfer):
    """PadAkg infer"""

    def _infer_shape(self):
        """Return the input shape grown by the `head` and `tail` attrs per dimension."""
        in_shape = list(self.inputs[0].shape)
        head = list(self.attrs["head"])
        tail = list(self.attrs["tail"])
        rank = len(in_shape)
        if not (len(head) == rank and len(tail) == rank):
            raise GKException("Input dimension and pad mismatch: {}d vs {}d vs {}d"
                              .format(rank, len(head), len(tail)))
        return [dim + before + after for dim, before, after in zip(in_shape, head, tail)]


class UnPadAkg(OpInfer):
    """UnPadAkg infer"""

    def _infer_shape(self):
        """Return the input shape shrunk by the `tail` attr per dimension."""
        in_shape = list(self.inputs[0].shape)
        tail = list(self.attrs["tail"])
        rank = len(in_shape)
        if len(tail) != rank:
            raise GKException("Input dimension and pad mismatch: {}d vs {}d".format(rank, len(tail)))
        return [dim - removed for dim, removed in zip(in_shape, tail)]