|
|
|
@@ -18,14 +18,14 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF |
|
|
|
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException |
|
|
|
from ._utils import Expander, ExpanderInfoValidator as VLD |
|
|
|
|
|
|
|
M_ALIGN = 16 |
|
|
|
N_ALIGN = 16 |
|
|
|
K_ALIGN = 8 |
|
|
|
M_ALIGN = 32 |
|
|
|
N_ALIGN = 32 |
|
|
|
K_ALIGN = 16 |
|
|
|
K_LIMIT = 800 |
|
|
|
MNK_LIMIT = 3 * (10 ** 10) |
|
|
|
N0_CHANNEL_ALIGN = 16 |
|
|
|
N1_CHANNEL_ALIGN = 16 |
|
|
|
C_CHANNEL_ALIGN = 8 |
|
|
|
N0_CHANNEL_ALIGN = 32 |
|
|
|
N1_CHANNEL_ALIGN = 32 |
|
|
|
C_CHANNEL_ALIGN = 16 |
|
|
|
OUT_NHW_ALIGN = 128 |
|
|
|
|
|
|
|
|
|
|
|
@@ -63,8 +63,7 @@ class Conv2D(Expander): |
|
|
|
dilation = self.attrs['dilation'] |
|
|
|
_, h, w, _ = self.inputs[1]['shape'] |
|
|
|
if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \ |
|
|
|
self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0 and \ |
|
|
|
self.k <= K_LIMIT and self.m * self.n * self.k < MNK_LIMIT: |
|
|
|
self.m % M_ALIGN == 0 and self.n % N_ALIGN == 0 and self.k % K_ALIGN == 0: |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
@@ -72,7 +71,8 @@ class Conv2D(Expander): |
|
|
|
type_0 = self.inputs[0]['data_type'] |
|
|
|
type_1 = self.inputs[1]['data_type'] |
|
|
|
if type_0 != "float16" or type_1 != "float16": |
|
|
|
raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1)) |
|
|
|
raise GKException( |
|
|
|
"inputs type should be float16, but got {} and {}".format(type_0, type_1)) |
|
|
|
|
|
|
|
formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']] |
|
|
|
check_format_any(formats, DF.NHWC) |
|
|
|
@@ -80,12 +80,14 @@ class Conv2D(Expander): |
|
|
|
groups = self.attrs['groups'] |
|
|
|
group = self.attrs['group'] |
|
|
|
if groups != 1 or group != 1: |
|
|
|
raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group)) |
|
|
|
raise GKException( |
|
|
|
"groups and group should be both 1, but got {} and {}.".format(groups, group)) |
|
|
|
|
|
|
|
dilation = self.attrs['dilation'] |
|
|
|
check_nd(dilation, 4) |
|
|
|
if dilation != [1, 1, 1, 1]: |
|
|
|
raise GKException("dilation should be all 1, but got {}".format(dilation)) |
|
|
|
raise GKException( |
|
|
|
"dilation should be all 1, but got {}".format(dilation)) |
|
|
|
|
|
|
|
pad_list = self.attrs['pad_list'] |
|
|
|
pad_mode = self.attrs['pad_mode'] |
|
|
|
@@ -100,16 +102,16 @@ class Conv2D(Expander): |
|
|
|
check_nd(stride, 4) |
|
|
|
n0, h0, w0, c0 = shape_0 |
|
|
|
n1, h1, w1, c1 = shape_1 |
|
|
|
if n0 <= N0_CHANNEL_ALIGN: |
|
|
|
raise GKException("N({}) channel of first input should > {}".format(n0, N0_CHANNEL_ALIGN)) |
|
|
|
if n1 < N1_CHANNEL_ALIGN: |
|
|
|
raise GKException("N({}) channel of second input should >= {}".format(n1, N1_CHANNEL_ALIGN)) |
|
|
|
if c0 != c1 or c0 < C_CHANNEL_ALIGN: |
|
|
|
raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN)) |
|
|
|
if stride != [1, 1, 2, 2]: |
|
|
|
raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3])) |
|
|
|
if (n0 % N0_CHANNEL_ALIGN) != 0: |
|
|
|
raise GKException("N({}) channel of first input should be multiples of {}".format(n0, N0_CHANNEL_ALIGN)) |
|
|
|
if (n1 % N1_CHANNEL_ALIGN) != 0: |
|
|
|
raise GKException("O({}) channel of second input should be multiples of {}".format(n1, N1_CHANNEL_ALIGN)) |
|
|
|
if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0: |
|
|
|
raise GKException("C channel of inputs({}, {}) should be same and also be multiples of {}".format( |
|
|
|
c0, c1, C_CHANNEL_ALIGN)) |
|
|
|
# n0 pad |
|
|
|
n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN |
|
|
|
n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // |
|
|
|
N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN |
|
|
|
# h0, w0 pad |
|
|
|
if self.has_pad: |
|
|
|
h0 = h0 + pad_list[0] + pad_list[1] |
|
|
|
@@ -118,16 +120,29 @@ class Conv2D(Expander): |
|
|
|
c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN |
|
|
|
c1 = c0 |
|
|
|
# n1 pad |
|
|
|
n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN |
|
|
|
n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // |
|
|
|
N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN |
|
|
|
|
|
|
|
# check if can optimize to matmul |
|
|
|
self.m, self.n, self.k = n0 * h0 * w0, n1, c1 |
|
|
|
self.can_optimize_to_matmul = self._optimize_to_matmul() |
|
|
|
|
|
|
|
out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1 |
|
|
|
if not self.can_optimize_to_matmul and n0 * out_h * out_w % OUT_NHW_ALIGN != 0: |
|
|
|
raise GKException("N({}) * H({}) * W({}) of Conv2d output should be multiplies of {}" |
|
|
|
.format(n0, out_h, out_w, OUT_NHW_ALIGN)) |
|
|
|
# requirements |
|
|
|
if self.can_optimize_to_matmul: |
|
|
|
if self.k > K_LIMIT: |
|
|
|
raise GKException( |
|
|
|
"If transformed to MatMul, C0({}) should not be larger than {}".format(self.k, K_LIMIT)) |
|
|
|
if self.m * self.n * self.k >= MNK_LIMIT: |
|
|
|
raise GKException("If transformed to MatMul, The total size({}) should not be larger than {}".format( |
|
|
|
self.m * self.n * self.k, MNK_LIMIT)) |
|
|
|
else: |
|
|
|
out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1 |
|
|
|
if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0: |
|
|
|
raise GKException("N({}) * H({}) * W({}) of output should be multiplies of {}".format( |
|
|
|
n0, out_h, out_w, OUT_NHW_ALIGN)) |
|
|
|
if stride != [1, 1, 2, 2]: |
|
|
|
raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3])) |
|
|
|
|
|
|
|
self.shape_0_pad = [n0, h0, w0, c0] |
|
|
|
self.shape_1_pad = [n1, h1, w1, c1] |
|
|
|
|
|
|
|
|