Browse Source

!32475 set inputs bug fix 1.7

Merge pull request !32475 from Henry Shi/branch_sxy_7
r1.7
i-robot Gitee 3 years ago
parent
commit
a8cebd74fd
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 9 additions and 7 deletions
  1. +9
    -7
      mindspore/python/mindspore/nn/cell.py

+ 9
- 7
mindspore/python/mindspore/nn/cell.py View File

@@ -893,11 +893,11 @@ class Cell(Cell_):
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
>>> class reluNet(nn.Cell):
>>> def __init__(self):
>>> super(reluNet, self).__init__()
>>> self.relu = nn.ReLU()
>>> def construct(self, x):
>>> return self.relu(x)
... def __init__(self):
... super(reluNet, self).__init__()
... self.relu = nn.ReLU()
... def construct(self, x):
... return self.relu(x)
>>>
>>> net = reluNet()
>>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
@@ -909,9 +909,11 @@ class Cell(Cell_):
This is an experimental interface that is subject to change or deletion.
"""

for ele in self._dynamic_shape_inputs:
if isinstance(ele, (str, int, dict)):
raise TypeError(f"For element in 'set_inputs', the type must be Tensor,\
but got {type(ele)}.")
self._dynamic_shape_inputs = inputs
if isinstance(self._dynamic_shape_inputs[0], (str, int, dict)):
raise TypeError(f"For 'set_inputs, the type must be tuple, but got {type(self._dynamic_shape_inputs[0])}.")

def get_inputs(self):
"""


Loading…
Cancel
Save