# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for Conv2D"""
from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD

# Alignment requirements on the (M, N, K) sizes for the MatMul rewrite.
M_ALIGN = 32
N_ALIGN = 32
K_ALIGN = 16
# Upper bounds applied only when the convolution is rewritten as a MatMul.
K_LIMIT = 800
MNK_LIMIT = 3 * (10 ** 10)
# Alignment requirements on the input shapes.
N0_CHANNEL_ALIGN = 32
N1_CHANNEL_ALIGN = 32
C_CHANNEL_ALIGN = 16
# Alignment requirement on N * H * W of the output (non-MatMul path only).
OUT_NHW_ALIGN = 128


@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
    """
    Conv2D expander

    Currently, only Conv2D that meets several conditions can be expanded, other cases will be skipped.
    Conditions to expand (all enforced by `_check`):
        both inputs are float16, and the formats of the inputs and the `format` attr
        pass `check_format_any` against NHWC.
        attr groups and group are 1.
        attr dilation are all 1.
        N of the first input is a multiple of 32.
        N of the second input is a multiple of 32.
        C of both inputs is the same and a multiple of 16.
        If the kernel is 1x1, stride is [1, 1, 1, 1] and M/N/K are aligned, the op is rewritten
        as a MatMul; then K must not exceed 800 and M * N * K must be below 3 * 10^10.
        Otherwise stride must be [1, 1, 2, 2] and output N*H*W must be a multiple of 128.
    """

    def __init__(self, expand_info):
        """Cache output type/format and initialise the state filled in by `_check`."""
        super().__init__(expand_info)
        self.dst_type = self.outputs[0]['data_type']
        self.dst_format = self.outputs[0]['format']
        self.has_pad = False
        self.can_optimize_to_matmul = False
        # Padded shapes; replaced by `_check` once the inputs are validated.
        self.shape_0_pad = self.inputs[0]['shape']
        self.shape_1_pad = self.inputs[1]['shape']
        # MatMul dimensions, computed by `_check`.
        self.m = 0
        self.n = 0
        self.k = 0

    def _optimize_to_matmul(self):
        """
        Tell whether this convolution can be emitted as a MatMul.

        Returns:
            bool, True when the kernel is 1x1, stride and dilation are all 1, and
            `self.m`, `self.n`, `self.k` satisfy M_ALIGN, N_ALIGN and K_ALIGN.
        """
        stride = self.attrs['stride']
        dilation = self.attrs['dilation']
        _, h, w, _ = self.inputs[1]['shape']
        is_pointwise = h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1]
        is_aligned = self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0
        return is_pointwise and is_aligned

    def _check(self):
        """
        Validate that this Conv2D can be expanded and compute the padded shapes.

        On success, sets `has_pad`, `m`, `n`, `k`, `can_optimize_to_matmul`,
        `shape_0_pad` and `shape_1_pad`.

        Raises:
            GraphKernelUnsupportedException: if any expansion condition is not met.
        """
        type_0 = self.inputs[0]['data_type']
        type_1 = self.inputs[1]['data_type']
        if type_0 != "float16" or type_1 != "float16":
            raise GKException(
                "inputs type should be float16, but got {} and {}".format(type_0, type_1))

        formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
        check_format_any(formats, DF.NHWC)

        groups = self.attrs['groups']
        group = self.attrs['group']
        if groups != 1 or group != 1:
            raise GKException(
                "groups and group should be both 1, but got {} and {}.".format(groups, group))

        dilation = self.attrs['dilation']
        check_nd(dilation, 4)
        if dilation != [1, 1, 1, 1]:
            raise GKException(
                "dilation should be all 1, but got {}".format(dilation))

        pad_list = self.attrs['pad_list']
        pad_mode = self.attrs['pad_mode']
        check_nd(pad_list, 4)
        self.has_pad = conv_had_pad(pad_list, pad_mode)

        shape_0 = self.inputs[0]['shape']
        shape_1 = self.inputs[1]['shape']
        stride = self.attrs['stride']
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)
        check_nd(stride, 4)
        n0, h0, w0, c0 = shape_0
        n1, h1, w1, c1 = shape_1
        if (n0 % N0_CHANNEL_ALIGN) != 0:
            raise GKException("N({}) channel of first input should be multiples of {}".format(n0, N0_CHANNEL_ALIGN))
        if (n1 % N1_CHANNEL_ALIGN) != 0:
            raise GKException("O({}) channel of second input should be multiples of {}".format(n1, N1_CHANNEL_ALIGN))
        if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0:
            raise GKException("C channel of inputs({}, {}) should be same and also be multiples of {}".format(
                c0, c1, C_CHANNEL_ALIGN))

        # The round-ups below are identities for shapes that passed the alignment
        # checks above; they are kept so the padded shapes stay well-defined.
        # n0 pad
        n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
        # h0, w0 pad
        if self.has_pad:
            h0 = h0 + pad_list[0] + pad_list[1]
            w0 = w0 + pad_list[2] + pad_list[3]
        # c0, c1 pad
        c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
        c1 = c0
        # n1 pad
        n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN

        # check if can optimize to matmul
        self.m, self.n, self.k = n0 * h0 * w0, n1, c1
        self.can_optimize_to_matmul = self._optimize_to_matmul()

        # requirements
        if self.can_optimize_to_matmul:
            if self.k > K_LIMIT:
                raise GKException(
                    "If transformed to MatMul, C0({}) should not be larger than {}".format(self.k, K_LIMIT))
            if self.m * self.n * self.k >= MNK_LIMIT:
                raise GKException("If transformed to MatMul, The total size({}) should not be larger than {}".format(
                    self.m * self.n * self.k, MNK_LIMIT))
        else:
            # Validate stride before dividing by it, so that an unsupported (e.g. zero)
            # stride is reported as GKException rather than ZeroDivisionError.
            if stride != [1, 1, 2, 2]:
                raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3]))
            out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
            if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0:
                raise GKException("N({}) * H({}) * W({}) of output should be multiplies of {}".format(
                    n0, out_h, out_w, OUT_NHW_ALIGN))

        self.shape_0_pad = [n0, h0, w0, c0]
        self.shape_1_pad = [n1, h1, w1, c1]

    def _expand(self, graph_builder):
        """
        Emit the expanded graph: optional padding, then MatMul or Conv2D, then optional unpadding.

        Args:
            graph_builder: builder whose `emit(op_name, inputs, attrs=...)` adds an op and
                returns its result.

        Returns:
            The result of the last emitted op.
        """
        # NOTE(review): inputs are read through `.shape` here but through `['shape']`
        # in `_check`; presumably the base class converts them in between - confirm.
        input_0 = self.inputs[0]
        input_1 = self.inputs[1]
        n0, _, _, c0 = input_0.shape
        n1, _, _, c1 = input_1.shape
        n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
        n1_p, _, _, c1_p = self.shape_1_pad

        pad_value = 0
        # input0 pad
        input_0_pad_before = [0, 0, 0, 0]
        input_0_pad_after = [0, 0, 0, 0]
        if self.has_pad:
            pad_list = self.attrs['pad_list']
            input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
            input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
        input_0_pad_after[0] = n0_p - n0
        input_0_pad_after[3] = c0_p - c0
        if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
            input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
                                                                     'tail': input_0_pad_after,
                                                                     'pad_val': pad_value})
        # input1 pad
        input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
        if input_1_pad_after != [0, 0, 0, 0]:
            input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
                                                                     'tail': input_1_pad_after,
                                                                     'pad_val': pad_value})
        if self.can_optimize_to_matmul:
            a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
            b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
            c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
                                                            'transpose_b': True,
                                                            'dst_type': self.dst_type})
            result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
                                                               'format': self.dst_format})
        else:
            # Work on a shallow copy so the expander's own attrs are not overwritten.
            attrs = dict(self.attrs)
            # Spatial padding has already been applied by the PadAkg op above.
            attrs['pad_list'] = [0, 0, 0, 0]
            attrs['dst_type'] = self.dst_type
            result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
        # unpad
        unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
        if unpad_after != [0, 0, 0, 0]:
            result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})
        return result