|
|
|
@@ -203,6 +203,10 @@ DShape ToNz(const DShape &default_shape) { |
|
|
|
auto len = default_shape.size();
|
|
|
|
DShape leading_shape;
|
|
|
|
DShape tail_shape;
|
|
|
|
if (default_shape.size() == 1 && default_shape[0] == 1) {
|
|
|
|
// # As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape
|
|
|
|
return default_shape;
|
|
|
|
}
|
|
|
|
if (default_shape.size() > nz_size) {
|
|
|
|
(void)leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - SizeToLong(nz_size));
|
|
|
|
}
|
|
|
|
@@ -408,10 +412,12 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) { |
|
|
|
auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
|
|
|
|
auto stride = GetListInt(attrs.find("stride")->second);
|
|
|
|
auto dilation = GetListInt(attrs.find("dilation")->second);
|
|
|
|
check_nd(pad_list, 4);
|
|
|
|
check_nd(kernel_size, 2);
|
|
|
|
check_nd(stride, 4);
|
|
|
|
check_nd(dilation, 4);
|
|
|
|
constexpr auto dim_len = 4;
|
|
|
|
check_nd(pad_list, dim_len);
|
|
|
|
constexpr auto kernel_len = 2;
|
|
|
|
check_nd(kernel_size, kernel_len);
|
|
|
|
check_nd(stride, dim_len);
|
|
|
|
check_nd(dilation, dim_len);
|
|
|
|
bool has_pad = false;
|
|
|
|
if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
|
|
|
|
has_pad = true;
|
|
|
|
|