|
|
|
@@ -100,12 +100,14 @@ class Conv2D(Expander): |
|
|
|
check_nd(stride, 4) |
|
|
|
n0, h0, w0, c0 = shape_0 |
|
|
|
n1, h1, w1, c1 = shape_1 |
|
|
|
if n0 < N0_CHANNEL_ALIGN: |
|
|
|
raise GKException("N({}) channel of first input should >= {}".format(n0, N0_CHANNEL_ALIGN)) |
|
|
|
if n0 <= N0_CHANNEL_ALIGN: |
|
|
|
raise GKException("N({}) channel of first input should > {}".format(n0, N0_CHANNEL_ALIGN)) |
|
|
|
if n1 < N1_CHANNEL_ALIGN: |
|
|
|
raise GKException("N({}) channel of second input should >= {}".format(n1, N1_CHANNEL_ALIGN)) |
|
|
|
if c0 != c1 or c0 < C_CHANNEL_ALIGN: |
|
|
|
raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN)) |
|
|
|
if stride != [1, 1, 2, 2]: |
|
|
|
raise GKException("Stride H and W should be [2, 2] but got [{}, {}]".format(stride[2], stride[3])) |
|
|
|
# n0 pad |
|
|
|
n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN |
|
|
|
# h0, w0 pad |
|
|
|
|