|
|
|
@@ -1043,6 +1043,7 @@ class Conv2D(PrimitiveWithInfer): |
|
|
|
self.add_prim_attr('data_format', self.format) |
|
|
|
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) |
|
|
|
self.group = validator.check_positive_int(group, 'group', self.name) |
|
|
|
self.add_prim_attr('groups', self.group) |
|
|
|
self.add_prim_attr('offset_a', 0) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape=None): |
|
|
|
@@ -1587,6 +1588,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): |
|
|
|
self.add_prim_attr('pad_mode', pad_mode) |
|
|
|
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) |
|
|
|
self.group = validator.check_positive_int(group, 'group', self.name) |
|
|
|
self.add_prim_attr('groups', self.group) |
|
|
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) |
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC": |
|
|
|
raise ValueError("NHWC format only support in GPU target.") |
|
|
|
|