| @@ -51,3 +51,4 @@ from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign | |||||
| from .lamb_apply_weight_assign import LambApplyWeightAssign | from .lamb_apply_weight_assign import LambApplyWeightAssign | ||||
| from .softmax_grad_ext import SoftmaxGradExt | from .softmax_grad_ext import SoftmaxGradExt | ||||
| from .square_sum_v1 import SquareSumV1 | from .square_sum_v1 import SquareSumV1 | ||||
| from .conv2d import Conv2D | |||||
| @@ -0,0 +1,111 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # =========================================================================== | |||||
| """generate json desc for Conv2D""" | |||||
| from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad | |||||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||||
| from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException | |||||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||||
| M_ALIGN = 16 | |||||
| N_ALIGN = 16 | |||||
| K_ALIGN = 8 | |||||
| OUT_CHANNEL_ALIGN = 8 | |||||
| @VLD.add_format(DF.NHWC, DF.NHWC) | |||||
| @VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation') | |||||
| class Conv2D(Expander): | |||||
| """Conv2D expander""" | |||||
| def __init__(self, expand_info): | |||||
| super().__init__(expand_info) | |||||
| self.has_pad = False | |||||
| self.can_optimize_to_batchmatmul = False | |||||
| def _check(self): | |||||
| type_0 = self.inputs[0]['data_type'] | |||||
| type_1 = self.inputs[1]['data_type'] | |||||
| if type_0 != "float16" or type_1 != "float16": | |||||
| raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1)) | |||||
| groups = self.attrs['groups'] | |||||
| group = self.attrs['group'] | |||||
| if groups != 1 or group != 1: | |||||
| raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group)) | |||||
| dilation = self.attrs['dilation'] | |||||
| check_nd(dilation, 4) | |||||
| if dilation != [1, 1, 1, 1]: | |||||
| raise GKException("dilation should be all 1, but got {}".format(dilation)) | |||||
| pad_list = self.attrs['pad_list'] | |||||
| pad_mode = self.attrs['pad_mode'] | |||||
| check_nd(pad_list, 4) | |||||
| self.has_pad = conv_had_pad(pad_list, pad_mode) | |||||
| shape_0 = self.inputs[0]['shape'] | |||||
| shape_1 = self.inputs[1]['shape'] | |||||
| stride = self.attrs['stride'] | |||||
| check_nd(shape_0, 4) | |||||
| check_nd(shape_1, 4) | |||||
| check_nd(stride, 4) | |||||
| n0, h0, w0, c0 = shape_0 | |||||
| n1, h1, w1, c1 = shape_1 | |||||
| if c0 != c1: | |||||
| raise GKException("C channel of inputs should be same, but got {} and {}".format(c0, c1)) | |||||
| if self.has_pad: | |||||
| h0 = h0 + pad_list[0] + pad_list[1] | |||||
| w0 = w0 + pad_list[2] + pad_list[3] | |||||
| n1 = ((n1 + OUT_CHANNEL_ALIGN - 1) // OUT_CHANNEL_ALIGN) * OUT_CHANNEL_ALIGN | |||||
| m = n0 * h0 * w0 | |||||
| n = n1 * h1 * w1 | |||||
| k = c1 | |||||
| self.can_optimize_to_batchmatmul = False | |||||
| if h1 == 1 and w1 == 1 and stride == [1, 1, 1, 1] and m % M_ALIGN == 0 and n % N_ALIGN == 0 and \ | |||||
| k % K_ALIGN == 0: | |||||
| self.can_optimize_to_batchmatmul = True | |||||
| if n0 < 128 and h0 % 2 != 0 and w0 % 2 != 0 and not self.can_optimize_to_batchmatmul: | |||||
| raise GKException("Conv2D expander only processes when N({}) > 128, H({}) and W({}) are odd of first input \ | |||||
| or current Conv2D can be optimized to BatchMatMul.".format(n0, h0, w0)) | |||||
| def _expand(self, graph_builder): | |||||
| input_0 = self.inputs[0] | |||||
| input_1 = self.inputs[1] | |||||
| pad_value = 0 | |||||
| if self.has_pad: | |||||
| pad_list = self.attrs['pad_list'] | |||||
| pad_before = [0, pad_list[0], pad_list[2], 0] | |||||
| pad_after = [0, pad_list[1], pad_list[3], 0] | |||||
| input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': pad_before, | |||||
| 'tail': pad_after, | |||||
| 'pad_val': pad_value}) | |||||
| attrs = self.attrs | |||||
| attrs['pad_list'] = [0, 0, 0, 0] | |||||
| attrs['can_optimize_to_batchmatmul'] = self.can_optimize_to_batchmatmul | |||||
| out_channel = ((input_1.shape[0] + OUT_CHANNEL_ALIGN - 1) // OUT_CHANNEL_ALIGN) * OUT_CHANNEL_ALIGN | |||||
| if out_channel != input_1.shape[0]: | |||||
| out_channel_pad = out_channel - input_1.shape[0] | |||||
| pad_before = [0, 0, 0, 0] | |||||
| pad_after = [out_channel_pad, 0, 0, 0] | |||||
| unpad_after = [0, 0, 0, out_channel_pad] | |||||
| input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': pad_before, | |||||
| 'tail': pad_after, | |||||
| 'pad_val': pad_value}) | |||||
| result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs) | |||||
| result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after}) | |||||
| else: | |||||
| result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs) | |||||
| return result | |||||
| @@ -298,6 +298,8 @@ class GraphSplitGpu(GraphSplitByPattern): | |||||
| REDUCE_FUSE_DEPTH = 20 | REDUCE_FUSE_DEPTH = 20 | ||||
| def get_default_mode(self, op): | def get_default_mode(self, op): | ||||
| if op.prim in ["PadAkg", "UnPadAkg"]: | |||||
| return self.Area.MODE_COMPOSITE | |||||
| pattern = PrimLib.iter_type(op) | pattern = PrimLib.iter_type(op) | ||||
| return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE | return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE | ||||
| @@ -187,6 +187,9 @@ class PrimLib: | |||||
| 'BroadcastTo': Prim(BROADCAST), | 'BroadcastTo': Prim(BROADCAST), | ||||
| 'MatMul': Prim(OPAQUE), | 'MatMul': Prim(OPAQUE), | ||||
| 'TransData': Prim(OPAQUE), | 'TransData': Prim(OPAQUE), | ||||
| 'Conv2D': Prim(OPAQUE), | |||||
| 'PadAkg': Prim(OPAQUE), | |||||
| 'UnPadAkg': Prim(OPAQUE), | |||||
| } | } | ||||
| default_primtive = Prim(UNKNOWN) | default_primtive = Prim(UNKNOWN) | ||||
| @@ -14,7 +14,6 @@ | |||||
| # =========================================================================== | # =========================================================================== | ||||
| """GraphKernel Op Infer""" | """GraphKernel Op Infer""" | ||||
| import copy | import copy | ||||
| import sys | import sys | ||||
| from functools import reduce | from functools import reduce | ||||
| @@ -24,6 +23,7 @@ from .model import PrimLib, DataFormat as DF | |||||
| def infer(op_name, inputs, attrs): | def infer(op_name, inputs, attrs): | ||||
| """infer shape dtype and format""" | """infer shape dtype and format""" | ||||
| def _create_opinfer(): | def _create_opinfer(): | ||||
| if hasattr(sys.modules[__name__], op_name): | if hasattr(sys.modules[__name__], op_name): | ||||
| op_cls = getattr(sys.modules[__name__], op_name) | op_cls = getattr(sys.modules[__name__], op_name) | ||||
| @@ -38,6 +38,7 @@ def infer(op_name, inputs, attrs): | |||||
| raise GKException("OpInfo does not support op {}".format(op_name)) | raise GKException("OpInfo does not support op {}".format(op_name)) | ||||
| op_cls = getattr(sys.modules[__name__], cls_name) | op_cls = getattr(sys.modules[__name__], cls_name) | ||||
| return op_cls(op_name, inputs, attrs) | return op_cls(op_name, inputs, attrs) | ||||
| return _create_opinfer().infer() | return _create_opinfer().infer() | ||||
| @@ -236,3 +237,89 @@ class Select(_Elemwise): | |||||
| def _infer_type(self): | def _infer_type(self): | ||||
| return self.inputs[1].dtype | return self.inputs[1].dtype | ||||
| def check_nd(data, nd): | |||||
| if not isinstance(data, (list, tuple)) or len(data) != nd: | |||||
| raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data)) | |||||
| def check_one(data): | |||||
| if not isinstance(data, (list, tuple)): | |||||
| raise GKException("input should be list or tuple") | |||||
| for i, d in enumerate(data): | |||||
| if d != 1: | |||||
| raise GKException("value at index {} should be 1, but got {}.".format(i, d)) | |||||
| def conv_had_pad(pad_list, pad_mode): | |||||
| if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4: | |||||
| raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list)) | |||||
| if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]: | |||||
| return True | |||||
| if pad_mode not in ["VALID", "valid"]: | |||||
| for _, pad in enumerate(pad_list): | |||||
| if pad != 0: | |||||
| return True | |||||
| return False | |||||
| class Conv2D(OpInfer): | |||||
| """Conv2D infer""" | |||||
| def _infer_shape(self): | |||||
| shape_0 = list(self.inputs[0].shape) | |||||
| shape_1 = list(self.inputs[1].shape) | |||||
| check_nd(shape_0, 4) | |||||
| check_nd(shape_1, 4) | |||||
| format_0 = self.inputs[0].data_format | |||||
| format_1 = self.inputs[1].data_format | |||||
| if format_0 != DF.NHWC or format_1 != DF.NHWC: | |||||
| raise GKException("Conv2D's inputs format must be NHWC, but got {} and {}".format(format_0, format_1)) | |||||
| n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0] | |||||
| pad_list = self.attrs["pad_list"] | |||||
| pad_mode = self.attrs["pad_mode"] | |||||
| kernel_size = self.attrs["kernel_size"] | |||||
| stride = self.attrs["stride"] | |||||
| dilation = self.attrs["dilation"] | |||||
| check_nd(pad_list, 4) | |||||
| check_nd(kernel_size, 2) | |||||
| check_nd(stride, 4) | |||||
| check_nd(dilation, 4) | |||||
| has_pad = conv_had_pad(pad_list, pad_mode) | |||||
| if not has_pad: | |||||
| pad_list = [0, 0, 0, 0] | |||||
| k_h = (kernel_size[0] - 1) * dilation[-2] + 1 | |||||
| k_w = (kernel_size[1] - 1) * dilation[-1] + 1 | |||||
| out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1 | |||||
| out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1 | |||||
| return [n, out_h, out_w, out_channel] | |||||
| class PadAkg(OpInfer): | |||||
| """PadAkg infer""" | |||||
| def _infer_shape(self): | |||||
| shape = list(self.inputs[0].shape) | |||||
| n = len(shape) | |||||
| pad_before = list(self.attrs["head"]) | |||||
| pad_after = list(self.attrs["tail"]) | |||||
| if len(pad_before) != n or len(pad_after) != n: | |||||
| raise GKException("Input dimension and pad mismatch: {}d vs {}d vs {}d" | |||||
| .format(n, len(pad_before), len(pad_after))) | |||||
| out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)] | |||||
| return out_shape | |||||
| class UnPadAkg(OpInfer): | |||||
| """UnPadAkg infer""" | |||||
| def _infer_shape(self): | |||||
| shape = list(self.inputs[0].shape) | |||||
| n = len(shape) | |||||
| unpad_after = list(self.attrs["tail"]) | |||||
| if len(unpad_after) != n: | |||||
| raise GKException("Input dimension and pad mismatch: {}d vs {}d".format(n, len(unpad_after))) | |||||
| out_shape = [shape[i] - unpad_after[i] for i in range(n)] | |||||
| return out_shape | |||||