From 2ba3e962e148b1216ecad32dc6694626685bd544 Mon Sep 17 00:00:00 2001 From: Jiaqi Date: Fri, 18 Sep 2020 14:38:47 +0800 Subject: [PATCH] add shape check --- mindspore/__init__.py | 2 +- mindspore/nn/wrap/loss_scale.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mindspore/__init__.py b/mindspore/__init__.py index b5085c01a7..0cac28e33b 100755 --- a/mindspore/__init__.py +++ b/mindspore/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""MindSpore package.""" +""".. MindSpore package.""" from ._version_check import check_version_and_env_config from . import common, train diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index 47a66d7ffd..22c3114fb2 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -180,7 +180,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): network (Cell): The training network. The network only supports single output. optimizer (Cell): Optimizer for updating the weights. scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value - is Tensor type, Tensor with shape :math:`()`. + is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`. Inputs: - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. @@ -230,7 +230,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), name="scale_sense") elif isinstance(scale_sense, Tensor): - self.scale_sense = Parameter(scale_sense, name='scale_sense') + if scale_sense.shape == (1,) or scale_sense.shape == (): + self.scale_sense = Parameter(scale_sense, name='scale_sense') + else: + raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape)) else: raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) @@ -284,4 +287,4 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): if self.scale_sense and isinstance(sens, Tensor): self.scale_sense.set_data(sens) else: - raise TypeError("The input type must be Tensor,but got {}".format(type(sens))) + raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))