| @@ -1185,8 +1185,19 @@ class ExecuteOrderGenerator { | |||||
| MS_EXCEPTION_IF_NULL(target); | MS_EXCEPTION_IF_NULL(target); | ||||
| auto para = param_write_times.find(target); | auto para = param_write_times.find(target); | ||||
| if (para != param_write_times.end() && para->second == 1) { | if (para != param_write_times.end() && para->second == 1) { | ||||
| // If target only write once, replace target with source and erase assign node. | |||||
| // Check source of the Assign. | |||||
| auto &source = node->inputs().at(kAssignSourceIndex); | auto &source = node->inputs().at(kAssignSourceIndex); | ||||
| MS_EXCEPTION_IF_NULL(source); | |||||
| if (source->isa<Parameter>()) { | |||||
| auto it = param_write_times.find(source); | |||||
| if (it != param_write_times.end() && it->second > 0) { | |||||
| // Skip if Assign source is a parameter and be written in other place. | |||||
| ++iter; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| // If target only write once, and source not be written, | |||||
| // replace target with source and erase the Assign node. | |||||
| auto kg = target->func_graph()->cast<KernelGraphPtr>(); | auto kg = target->func_graph()->cast<KernelGraphPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(kg); | MS_EXCEPTION_IF_NULL(kg); | ||||
| kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source)); | kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source)); | ||||
| @@ -1429,6 +1429,33 @@ def test_if_cast(): | |||||
| np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy()) | np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy()) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_while_forward(): | |||||
| class MyWhileNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.max = P.ReduceMax() | |||||
| def construct(self, idx, end, x): | |||||
| while idx < end: | |||||
| part = x[idx, :, :] | |||||
| max_num = self.max(part) | |||||
| x[idx, :, 0:2] = max_num | |||||
| idx = idx + 1 | |||||
| return x | |||||
| net = MyWhileNet() | |||||
| idx = Tensor(np.array(0), dtype=ms.int32) | |||||
| end = Tensor(np.array(2), dtype=ms.int32) | |||||
| x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) | |||||
| output = net(idx, end, x) | |||||
| expect = np.array([[[3, 3], [3, 3]], [[7, 7], [7, 7]]], dtype=np.int32) | |||||
| assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001) | |||||
| @pytest.mark.skip(reason="not supported yet") | @pytest.mark.skip(reason="not supported yet") | ||||
| def test_multi_add_assign(): | def test_multi_add_assign(): | ||||
| class Net(Cell): | class Net(Cell): | ||||