diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 765bbc22bf..b9292f9db0 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -267,11 +267,87 @@ py::object BuildValue(const ValuePtr &value_ptr) { return ValuePtrToPyData(value_ptr); } } + +py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) { + auto arg_tuple = dyn_cast(abs_base); + size_t len = arg_tuple->size(); + py::tuple shape_tuple(len); + py::tuple dtype_tuple(len); + py::tuple min_shape_tuple(len); + py::tuple max_shape_tuple(len); + auto dic = py::dict(); + bool dyn_shape = false; + + for (size_t i = 0; i < len; i++) { + auto arg = arg_tuple->elements()[i]; + py::dict out = ConvertAbstractToPython(arg); + shape_tuple[i] = out[ATTR_SHAPE]; + dtype_tuple[i] = out[ATTR_DTYPE]; + + // Elements in tuple is tensor, which shape is dynamic. + if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) { + min_shape_tuple[i] = out[ATTR_MIN_SHAPE]; + max_shape_tuple[i] = out[ATTR_MAX_SHAPE]; + dyn_shape = true; + } else { + min_shape_tuple[i] = out[ATTR_SHAPE]; + max_shape_tuple[i] = out[ATTR_SHAPE]; + } + } + dic[ATTR_SHAPE] = shape_tuple; + dic[ATTR_DTYPE] = dtype_tuple; + dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue()); + + if (dyn_shape) { + dic[ATTR_MIN_SHAPE] = min_shape_tuple; + dic[ATTR_MAX_SHAPE] = max_shape_tuple; + } + + return dic; +} + +py::dict AbstractListToPython(const AbstractBasePtr &abs_base) { + auto arg_list = dyn_cast(abs_base); + size_t len = arg_list->size(); + py::list shape_list(len); + py::list dtype_list(len); + py::list min_shape_list(len); + py::list max_shape_list(len); + auto dic = py::dict(); + bool dyn_shape = false; + + for (size_t i = 0; i < len; i++) { + py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); + shape_list[i] = out[ATTR_SHAPE]; + dtype_list[i] = out[ATTR_DTYPE]; + + // Elements in list is tensor, which shape is dynamic. + if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) { + min_shape_list[i] = out[ATTR_MIN_SHAPE]; + max_shape_list[i] = out[ATTR_MAX_SHAPE]; + dyn_shape = true; + } else { + min_shape_list[i] = out[ATTR_SHAPE]; + max_shape_list[i] = out[ATTR_SHAPE]; + } + } + + if (dyn_shape) { + dic[ATTR_MIN_SHAPE] = min_shape_list; + dic[ATTR_MAX_SHAPE] = max_shape_list; + } + + dic[ATTR_SHAPE] = shape_list; + dic[ATTR_DTYPE] = dtype_list; + dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue()); + + return dic; +} } // end anonymous namespace py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { MS_EXCEPTION_IF_NULL(abs_base); - py::dict dic; + auto dic = py::dict(); if (abs_base->isa()) { auto arg_tensor = dyn_cast(abs_base); dic[ATTR_SHAPE] = arg_tensor->shape()->shape(); @@ -311,33 +387,9 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic[ATTR_DTYPE] = py::ellipsis(); dic[ATTR_VALUE] = py::ellipsis(); } else if (abs_base->isa()) { - auto arg_tuple = dyn_cast(abs_base); - size_t len = arg_tuple->size(); - py::tuple shape_tuple(len); - py::tuple dtype_tuple(len); - - for (size_t i = 0; i < len; i++) { - py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]); - shape_tuple[i] = out[ATTR_SHAPE]; - dtype_tuple[i] = out[ATTR_DTYPE]; - } - dic[ATTR_SHAPE] = shape_tuple; - dic[ATTR_DTYPE] = dtype_tuple; - dic[ATTR_VALUE] = BuildValue(arg_tuple->BuildValue()); + return AbstractTupleToPython(abs_base); } else if (abs_base->isa()) { - auto arg_list = dyn_cast(abs_base); - size_t len = arg_list->size(); - py::list shape_list(len); - py::list dtype_list(len); - - for (size_t i = 0; i < len; i++) { - py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); - shape_list[i] = out[ATTR_SHAPE]; - dtype_list[i] = out[ATTR_DTYPE]; - } - dic[ATTR_SHAPE] = shape_list; - dic[ATTR_DTYPE] = dtype_list; - dic[ATTR_VALUE] = BuildValue(arg_list->BuildValue()); + return AbstractListToPython(abs_base); } else if (abs_base->isa()) { dic[ATTR_SHAPE] = py::none(); dic[ATTR_DTYPE] = py::none(); diff --git a/tests/ut/python/ops/test_dynamic_shape.py b/tests/ut/python/ops/test_dynamic_shape.py index 9cc4673fdf..5813540fd8 100755 --- a/tests/ut/python/ops/test_dynamic_shape.py +++ b/tests/ut/python/ops/test_dynamic_shape.py @@ -108,3 +108,21 @@ def test_gatherv2(): y = Tensor(np.ones([8], dtype=np.int32)) net = Net() net(x, y) + + +def test_addn(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.unq = P.Unique() + self.addn = P.AddN() + + def construct(self, x): + u, _ = self.unq(x) + u = self.addn((u, u, u)) + z = self.addn([u, u]) + return z + + y = Tensor(np.ones([8], dtype=np.int32)) + net = Net() + net(y)