diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc index 26f0e70e88..c15611f92b 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc @@ -39,9 +39,9 @@ void CustomActor::Run(OpContext<DeviceTensor> *const ctx) { std::string error_info = "Launch custom kernel exception: " + node->fullname_with_scope(); SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*ctx), error_info); } - // update the output addr size after initop && updateop, because after the initop & updateop, the shape of output - // maybe changed - if (AnfUtils::GetCustomActorType(kernel_.lock()) == kInit || + // Update the output addr size after inferop && updateop, because after the inferop & updateop, the shape of output + // may have changed. + if (AnfUtils::GetCustomActorType(kernel_.lock()) == kInfer || AnfUtils::GetCustomActorType(kernel_.lock()) == kUpdate) { auto base_node = AnfUtils::GetCustomActorBaseNode(kernel_.lock()); auto kernel_info = dynamic_cast<device::KernelInfo *>(base_node->kernel_info()); diff --git a/tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py b/tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py index 28e0adae23..91fa92d33d 100644 --- a/tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py +++ b/tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py @@ -44,6 +44,19 @@ class UniqueSquare(nn.Cell): return self.square(x) +class UniqueSquareRelu(nn.Cell): + def __init__(self): + super(UniqueSquareRelu, self).__init__() + self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU") + self.square_cpu = P.Square().add_prim_attr("primitive_target", "CPU") + self.relu = P.ReLU() + + def construct(self, x): + x, _ = self.unique_cpu(x) + x = self.square_cpu(x) + return self.relu(x) + + class UniqueReshapeAdd(nn.Cell): def __init__(self): super(UniqueReshapeAdd, self).__init__() @@ -99,6
+112,24 @@ def test_unique_square(): assert (output.asnumpy() == expect).all() +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unique_square_relu(): + """ + Feature: Dynamic shape with heterogeneity. + Description: Test unique, square and relu kernels in dynamic shape with heterogeneity scenarios. + Expectation: The value and shape of output are the expected values. + """ + x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.float32) + net = UniqueSquareRelu() + output = net(x) + expect = np.array([1, 4, 9]) + assert (output.asnumpy() == expect).all() + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_arm_ascend_training