|
|
|
@@ -0,0 +1,111 @@ |
|
|
|
# Copyright 2021 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
# You may obtain a copy of the License at |
|
|
|
# |
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# =========================================================================== |
|
|
|
"""generate json desc for Conv2D""" |
|
|
|
from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad |
|
|
|
from mindspore._extends.graph_kernel.model.model import DataFormat as DF |
|
|
|
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException |
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD |
|
|
|
|
|
|
|
M_ALIGN = 16 |
|
|
|
N_ALIGN = 16 |
|
|
|
K_ALIGN = 8 |
|
|
|
OUT_CHANNEL_ALIGN = 8 |
|
|
|
|
|
|
|
|
|
|
|
@VLD.add_format(DF.NHWC, DF.NHWC) |
|
|
|
@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation') |
|
|
|
class Conv2D(Expander): |
|
|
|
"""Conv2D expander""" |
|
|
|
def __init__(self, expand_info): |
|
|
|
super().__init__(expand_info) |
|
|
|
self.has_pad = False |
|
|
|
self.can_optimize_to_batchmatmul = False |
|
|
|
|
|
|
|
def _check(self): |
|
|
|
type_0 = self.inputs[0]['data_type'] |
|
|
|
type_1 = self.inputs[1]['data_type'] |
|
|
|
if type_0 != "float16" or type_1 != "float16": |
|
|
|
raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1)) |
|
|
|
|
|
|
|
groups = self.attrs['groups'] |
|
|
|
group = self.attrs['group'] |
|
|
|
if groups != 1 or group != 1: |
|
|
|
raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group)) |
|
|
|
|
|
|
|
dilation = self.attrs['dilation'] |
|
|
|
check_nd(dilation, 4) |
|
|
|
if dilation != [1, 1, 1, 1]: |
|
|
|
raise GKException("dilation should be all 1, but got {}".format(dilation)) |
|
|
|
|
|
|
|
pad_list = self.attrs['pad_list'] |
|
|
|
pad_mode = self.attrs['pad_mode'] |
|
|
|
check_nd(pad_list, 4) |
|
|
|
self.has_pad = conv_had_pad(pad_list, pad_mode) |
|
|
|
|
|
|
|
shape_0 = self.inputs[0]['shape'] |
|
|
|
shape_1 = self.inputs[1]['shape'] |
|
|
|
stride = self.attrs['stride'] |
|
|
|
check_nd(shape_0, 4) |
|
|
|
check_nd(shape_1, 4) |
|
|
|
check_nd(stride, 4) |
|
|
|
n0, h0, w0, c0 = shape_0 |
|
|
|
n1, h1, w1, c1 = shape_1 |
|
|
|
if c0 != c1: |
|
|
|
raise GKException("C channel of inputs should be same, but got {} and {}".format(c0, c1)) |
|
|
|
if self.has_pad: |
|
|
|
h0 = h0 + pad_list[0] + pad_list[1] |
|
|
|
w0 = w0 + pad_list[2] + pad_list[3] |
|
|
|
n1 = ((n1 + OUT_CHANNEL_ALIGN - 1) // OUT_CHANNEL_ALIGN) * OUT_CHANNEL_ALIGN |
|
|
|
m = n0 * h0 * w0 |
|
|
|
n = n1 * h1 * w1 |
|
|
|
k = c1 |
|
|
|
self.can_optimize_to_batchmatmul = False |
|
|
|
if h1 == 1 and w1 == 1 and stride == [1, 1, 1, 1] and m % M_ALIGN == 0 and n % N_ALIGN == 0 and \ |
|
|
|
k % K_ALIGN == 0: |
|
|
|
self.can_optimize_to_batchmatmul = True |
|
|
|
if n0 < 128 and h0 % 2 != 0 and w0 % 2 != 0 and not self.can_optimize_to_batchmatmul: |
|
|
|
raise GKException("Conv2D expander only processes when N({}) > 128, H({}) and W({}) are odd of first input \ |
|
|
|
or current Conv2D can be optimized to BatchMatMul.".format(n0, h0, w0)) |
|
|
|
|
|
|
|
def _expand(self, graph_builder): |
|
|
|
input_0 = self.inputs[0] |
|
|
|
input_1 = self.inputs[1] |
|
|
|
|
|
|
|
pad_value = 0 |
|
|
|
if self.has_pad: |
|
|
|
pad_list = self.attrs['pad_list'] |
|
|
|
pad_before = [0, pad_list[0], pad_list[2], 0] |
|
|
|
pad_after = [0, pad_list[1], pad_list[3], 0] |
|
|
|
input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': pad_before, |
|
|
|
'tail': pad_after, |
|
|
|
'pad_val': pad_value}) |
|
|
|
attrs = self.attrs |
|
|
|
attrs['pad_list'] = [0, 0, 0, 0] |
|
|
|
attrs['can_optimize_to_batchmatmul'] = self.can_optimize_to_batchmatmul |
|
|
|
|
|
|
|
out_channel = ((input_1.shape[0] + OUT_CHANNEL_ALIGN - 1) // OUT_CHANNEL_ALIGN) * OUT_CHANNEL_ALIGN |
|
|
|
if out_channel != input_1.shape[0]: |
|
|
|
out_channel_pad = out_channel - input_1.shape[0] |
|
|
|
pad_before = [0, 0, 0, 0] |
|
|
|
pad_after = [out_channel_pad, 0, 0, 0] |
|
|
|
unpad_after = [0, 0, 0, out_channel_pad] |
|
|
|
input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': pad_before, |
|
|
|
'tail': pad_after, |
|
|
|
'pad_val': pad_value}) |
|
|
|
result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs) |
|
|
|
result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after}) |
|
|
|
else: |
|
|
|
result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs) |
|
|
|
return result |