|
|
|
@@ -161,7 +161,7 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0): |
|
|
|
elif isinstance(stride, tuple) and len(stride) == 2: |
|
|
|
stride_h = stride[0] |
|
|
|
stride_w = stride[1] |
|
|
|
elif isinstance(stride, tuple) and len(stride) == 3: |
|
|
|
elif isinstance(stride, tuple) and len(stride) == 4: |
|
|
|
stride_h = stride[2] |
|
|
|
stride_w = stride[3] |
|
|
|
else: |
|
|
|
@@ -328,7 +328,7 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1): |
|
|
|
elif isinstance(stride, tuple) and len(stride) == 2: |
|
|
|
stride_h = stride[0] |
|
|
|
stride_w = stride[1] |
|
|
|
elif isinstance(stride, tuple) and len(stride) == 3: |
|
|
|
elif isinstance(stride, tuple) and len(stride) == 4: |
|
|
|
stride_h = stride[2] |
|
|
|
stride_w = stride[3] |
|
|
|
else: |
|
|
|
@@ -340,7 +340,7 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1): |
|
|
|
elif isinstance(dilation, tuple) and len(dilation) == 2: |
|
|
|
dilation_h = dilation[0] |
|
|
|
dilation_w = dilation[1] |
|
|
|
elif isinstance(dilation, tuple) and len(dilation) == 3: |
|
|
|
elif isinstance(dilation, tuple) and len(dilation) == 4: |
|
|
|
dilation_h = dilation[2] |
|
|
|
dilation_w = dilation[3] |
|
|
|
else: |
|
|
|
|