From 56c669013d9c61c12c51b8597aabfb5acce37f27 Mon Sep 17 00:00:00 2001 From: Henry Date: Sat, 2 Apr 2022 11:06:14 +0800 Subject: [PATCH] set inputs bug fix --- mindspore/python/mindspore/nn/cell.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index 036021f756..b7e2f66978 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -893,11 +893,11 @@ class Cell(Cell_): >>> >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") >>> class reluNet(nn.Cell): - >>> def __init__(self): - >>> super(reluNet, self).__init__() - >>> self.relu = nn.ReLU() - >>> def construct(self, x): - >>> return self.relu(x) + ... def __init__(self): + ... super(reluNet, self).__init__() + ... self.relu = nn.ReLU() + ... def construct(self, x): + ... return self.relu(x) >>> >>> net = reluNet() >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32) @@ -909,9 +909,11 @@ class Cell(Cell_): This is an experimental interface that is subject to change or deletion. """ + for ele in self._dynamic_shape_inputs: + if isinstance(ele, (str, int, dict)): + raise TypeError(f"For element in 'set_inputs', the type must be Tensor,\ + but got {type(ele)}.") self._dynamic_shape_inputs = inputs - if isinstance(self._dynamic_shape_inputs[0], (str, int, dict)): - raise TypeError(f"For 'set_inputs, the type must be tuple, but got {type(self._dynamic_shape_inputs[0])}.") def get_inputs(self): """