Browse Source

[GraphKernel][Ascend] Fix a bug in ToNz

tags/v1.6.0
hanhuifeng2020 4 years ago
parent
commit
cf0376a6a3
2 changed files with 13 additions and 4 deletions
  1. +3
    -0
      mindspore/_extends/graph_kernel/model/op_infer.py
  2. +10
    -4
      mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc

+ 3
- 0
mindspore/_extends/graph_kernel/model/op_infer.py View File

@@ -121,6 +121,9 @@ class _Elemwise(OpInfer):
@staticmethod
def defaultformat_to_nz(default_shape):
"""default format shape to fractal_Nz format shape"""
# As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape
if len(default_shape) == 1 and default_shape[0] == 1:
return default_shape
more_two_d_shape, two_d_shape = default_shape[:-2], default_shape[-2:]
# (32) or (1, 32) -> (2, 1, 1, 16)
if len(two_d_shape) == 1 or (len(two_d_shape) == 2 and two_d_shape[0] == 1):


+ 10
- 4
mindspore/ccsrc/backend/optimizer/graph_kernel/model/op_node.cc View File

@@ -203,6 +203,10 @@ DShape ToNz(const DShape &default_shape) {
auto len = default_shape.size();
DShape leading_shape;
DShape tail_shape;
if (default_shape.size() == 1 && default_shape[0] == 1) {
// # As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape
return default_shape;
}
if (default_shape.size() > nz_size) {
(void)leading_shape.insert(leading_shape.end(), default_shape.begin(), default_shape.end() - SizeToLong(nz_size));
}
@@ -408,10 +412,12 @@ DShape Conv2dOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
auto kernel_size = GetListInt(attrs.find("kernel_size")->second);
auto stride = GetListInt(attrs.find("stride")->second);
auto dilation = GetListInt(attrs.find("dilation")->second);
check_nd(pad_list, 4);
check_nd(kernel_size, 2);
check_nd(stride, 4);
check_nd(dilation, 4);
constexpr auto dim_len = 4;
check_nd(pad_list, dim_len);
constexpr auto kernel_len = 2;
check_nd(kernel_size, kernel_len);
check_nd(stride, dim_len);
check_nd(dilation, dim_len);
bool has_pad = false;
if (pad_list[0] != pad_list[1] || pad_list[2] != pad_list[3]) {
has_pad = true;


Loading…
Cancel
Save