Browse Source

Pre Merge pull request !16102 from looop5/expand_conv2d

pull/16102/MERGE
looop5 Gitee 5 years ago
parent
commit
f233cfffb3
5 changed files with 205 additions and 1 deletions
  1. +1
    -0
      mindspore/_extends/graph_kernel/expanders/__init__.py
  2. +111
    -0
      mindspore/_extends/graph_kernel/expanders/conv2d.py
  3. +2
    -0
      mindspore/_extends/graph_kernel/model/graph_split.py
  4. +3
    -0
      mindspore/_extends/graph_kernel/model/model.py
  5. +88
    -1
      mindspore/_extends/graph_kernel/model/op_infer.py

+ 1
- 0
mindspore/_extends/graph_kernel/expanders/__init__.py View File

@@ -51,3 +51,4 @@ from .lamb_apply_optimizer_assign import LambApplyOptimizerAssign
from .lamb_apply_weight_assign import LambApplyWeightAssign
from .softmax_grad_ext import SoftmaxGradExt
from .square_sum_v1 import SquareSumV1
from .conv2d import Conv2D

+ 111
- 0
mindspore/_extends/graph_kernel/expanders/conv2d.py View File

@@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for Conv2D"""
from mindspore._extends.graph_kernel.model.op_infer import check_nd, conv_had_pad
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD

M_ALIGN = 16
N_ALIGN = 16
K_ALIGN = 8
OUT_CHANNEL_ALIGN = 8


@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
"""Conv2D expander"""
def __init__(self, expand_info):
super().__init__(expand_info)
self.has_pad = False
self.can_optimize_to_batchmatmul = False

def _check(self):
type_0 = self.inputs[0]['data_type']
type_1 = self.inputs[1]['data_type']
if type_0 != "float16" or type_1 != "float16":
raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))

groups = self.attrs['groups']
group = self.attrs['group']
if groups != 1 or group != 1:
raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group))

dilation = self.attrs['dilation']
check_nd(dilation, 4)
if dilation != [1, 1, 1, 1]:
raise GKException("dilation should be all 1, but got {}".format(dilation))

pad_list = self.attrs['pad_list']
pad_mode = self.attrs['pad_mode']
check_nd(pad_list, 4)
self.has_pad = conv_had_pad(pad_list, pad_mode)

shape_0 = self.inputs[0]['shape']
shape_1 = self.inputs[1]['shape']
stride = self.attrs['stride']
check_nd(shape_0, 4)
check_nd(shape_1, 4)
check_nd(stride, 4)
n0, h0, w0, c0 = shape_0
n1, h1, w1, c1 = shape_1
if c0 != c1:
raise GKException("C channel of inputs should be same, but got {} and {}".format(c0, c1))
if self.has_pad:
h0 = h0 + pad_list[0] + pad_list[1]
w0 = w0 + pad_list[2] + pad_list[3]
n1 = ((n1 + OUT_CHANNEL_ALIGN - 1) // OUT_CHANNEL_ALIGN) * OUT_CHANNEL_ALIGN
m = n0 * h0 * w0
n = n1 * h1 * w1
k = c1
self.can_optimize_to_batchmatmul = False
if h1 == 1 and w1 == 1 and stride == [1, 1, 1, 1] and m % M_ALIGN == 0 and n % N_ALIGN == 0 and \
k % K_ALIGN == 0:
self.can_optimize_to_batchmatmul = True
if n0 < 128 and h0 % 2 != 0 and w0 % 2 != 0 and not self.can_optimize_to_batchmatmul:
raise GKException("Conv2D expander only processes when N({}) > 128, H({}) and W({}) are odd of first input \
or current Conv2D can be optimized to BatchMatMul.".format(n0, h0, w0))

def _expand(self, graph_builder):
input_0 = self.inputs[0]
input_1 = self.inputs[1]

pad_value = 0
if self.has_pad:
pad_list = self.attrs['pad_list']
pad_before = [0, pad_list[0], pad_list[2], 0]
pad_after = [0, pad_list[1], pad_list[3], 0]
input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': pad_before,
'tail': pad_after,
'pad_val': pad_value})
attrs = self.attrs
attrs['pad_list'] = [0, 0, 0, 0]
attrs['can_optimize_to_batchmatmul'] = self.can_optimize_to_batchmatmul

out_channel = ((input_1.shape[0] + OUT_CHANNEL_ALIGN - 1) // OUT_CHANNEL_ALIGN) * OUT_CHANNEL_ALIGN
if out_channel != input_1.shape[0]:
out_channel_pad = out_channel - input_1.shape[0]
pad_before = [0, 0, 0, 0]
pad_after = [out_channel_pad, 0, 0, 0]
unpad_after = [0, 0, 0, out_channel_pad]
input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': pad_before,
'tail': pad_after,
'pad_val': pad_value})
result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})
else:
result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
return result

+ 2
- 0
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -298,6 +298,8 @@ class GraphSplitGpu(GraphSplitByPattern):
REDUCE_FUSE_DEPTH = 20

def get_default_mode(self, op):
if op.prim in ["PadAkg", "UnPadAkg"]:
return self.Area.MODE_COMPOSITE
pattern = PrimLib.iter_type(op)
return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE



+ 3
- 0
mindspore/_extends/graph_kernel/model/model.py View File

@@ -187,6 +187,9 @@ class PrimLib:
'BroadcastTo': Prim(BROADCAST),
'MatMul': Prim(OPAQUE),
'TransData': Prim(OPAQUE),
'Conv2D': Prim(OPAQUE),
'PadAkg': Prim(OPAQUE),
'UnPadAkg': Prim(OPAQUE),
}

default_primtive = Prim(UNKNOWN)


+ 88
- 1
mindspore/_extends/graph_kernel/model/op_infer.py View File

@@ -14,7 +14,6 @@
# ===========================================================================
"""GraphKernel Op Infer"""


import copy
import sys
from functools import reduce
@@ -24,6 +23,7 @@ from .model import PrimLib, DataFormat as DF

def infer(op_name, inputs, attrs):
"""infer shape dtype and format"""

def _create_opinfer():
if hasattr(sys.modules[__name__], op_name):
op_cls = getattr(sys.modules[__name__], op_name)
@@ -38,6 +38,7 @@ def infer(op_name, inputs, attrs):
raise GKException("OpInfo does not support op {}".format(op_name))
op_cls = getattr(sys.modules[__name__], cls_name)
return op_cls(op_name, inputs, attrs)

return _create_opinfer().infer()


@@ -236,3 +237,89 @@ class Select(_Elemwise):

def _infer_type(self):
return self.inputs[1].dtype


def check_nd(data, nd):
if not isinstance(data, (list, tuple)) or len(data) != nd:
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))


def check_one(data):
if not isinstance(data, (list, tuple)):
raise GKException("input should be list or tuple")
for i, d in enumerate(data):
if d != 1:
raise GKException("value at index {} should be 1, but got {}.".format(i, d))


def conv_had_pad(pad_list, pad_mode):
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
return True
if pad_mode not in ["VALID", "valid"]:
for _, pad in enumerate(pad_list):
if pad != 0:
return True
return False


class Conv2D(OpInfer):
"""Conv2D infer"""
def _infer_shape(self):
shape_0 = list(self.inputs[0].shape)
shape_1 = list(self.inputs[1].shape)
check_nd(shape_0, 4)
check_nd(shape_1, 4)

format_0 = self.inputs[0].data_format
format_1 = self.inputs[1].data_format
if format_0 != DF.NHWC or format_1 != DF.NHWC:
raise GKException("Conv2D's inputs format must be NHWC, but got {} and {}".format(format_0, format_1))

n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
pad_list = self.attrs["pad_list"]
pad_mode = self.attrs["pad_mode"]
kernel_size = self.attrs["kernel_size"]
stride = self.attrs["stride"]
dilation = self.attrs["dilation"]
check_nd(pad_list, 4)
check_nd(kernel_size, 2)
check_nd(stride, 4)
check_nd(dilation, 4)

has_pad = conv_had_pad(pad_list, pad_mode)
if not has_pad:
pad_list = [0, 0, 0, 0]

k_h = (kernel_size[0] - 1) * dilation[-2] + 1
k_w = (kernel_size[1] - 1) * dilation[-1] + 1
out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
return [n, out_h, out_w, out_channel]


class PadAkg(OpInfer):
"""PadAkg infer"""
def _infer_shape(self):
shape = list(self.inputs[0].shape)
n = len(shape)
pad_before = list(self.attrs["head"])
pad_after = list(self.attrs["tail"])
if len(pad_before) != n or len(pad_after) != n:
raise GKException("Input dimension and pad mismatch: {}d vs {}d vs {}d"
.format(n, len(pad_before), len(pad_after)))
out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
return out_shape


class UnPadAkg(OpInfer):
"""UnPadAkg infer"""
def _infer_shape(self):
shape = list(self.inputs[0].shape)
n = len(shape)
unpad_after = list(self.attrs["tail"])
if len(unpad_after) != n:
raise GKException("Input dimension and pad mismatch: {}d vs {}d".format(n, len(unpad_after)))
out_shape = [shape[i] - unpad_after[i] for i in range(n)]
return out_shape

Loading…
Cancel
Save