Browse Source

fix reshape-reshape case

tags/v0.5.0-beta
yao_yf 5 years ago
parent
commit
96c9569dca
2 changed files with 25 additions and 2 deletions
  1. +3
    -1
      mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h
  2. +22
    -1
      tests/ut/python/parallel/test_auto_parallel_reshape.py

+ 3
- 1
mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h View File

@@ -88,7 +88,9 @@ class TwoReshapeEliminater : public AnfVisitor {

auto fg = node->func_graph();
if (fg != nullptr && x_ != nullptr && shape_ != nullptr) {
return fg->NewCNode({NewValueNode(prim_), x_, shape_});
auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_});
new_node->set_abstract(node->abstract());
return new_node;
}
return nullptr;
}


+ 22
- 1
tests/ut/python/parallel/test_auto_parallel_reshape.py View File

@@ -45,7 +45,6 @@ class GradWrap(nn.Cell):
return C.grad_all(self.network)(x)


# core dump, step_auto_parallel should SetInputs for transpose axis
def test_reshape_matmul():
class Net(nn.Cell):
def __init__(self):
@@ -68,6 +67,28 @@ def test_reshape_matmul():
net.set_auto_parallel()
_executor.compile(net, x)

def test_reshape_reshape():
    """Compile a net that chains two Reshape ops, under auto-parallel mode.

    Exercises the reshape-of-reshape elimination path: ReLU output is
    reshaped to (64, 28) and then immediately reshaped again to (64, 28, 1).
    The test passes if graph compilation completes without error.
    """
    class ReshapeTwiceNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu = P.ReLU()

        def construct(self, x):
            activated = self.relu(x)
            flat = self.reshape(activated, (64, 28))
            return self.reshape(flat, (64, 28, 1))

    device_count = 8
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)
    input_tensor = Tensor(np.ones([8 * device_count, 28, 1, 1]), dtype=ms.float32)

    parallel_net = GradWrap(NetWithLoss(ReshapeTwiceNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    parallel_net.set_auto_parallel()
    _executor.compile(parallel_net, input_tensor)


def test_reshape_auto_1():
class Net(nn.Cell):


Loading…
Cancel
Save