Browse Source

!6474 add shape check

Merge pull request !6474 from lijiaqi/valiation
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
05027e6c41
2 changed files with 7 additions and 4 deletions
  1. +1
    -1
      mindspore/__init__.py
  2. +6
    -3
      mindspore/nn/wrap/loss_scale.py

+ 1
- 1
mindspore/__init__.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""MindSpore package."""
""".. MindSpore package."""


from ._version_check import check_version_and_env_config from ._version_check import check_version_and_env_config
from . import common, train from . import common, train


+ 6
- 3
mindspore/nn/wrap/loss_scale.py View File

@@ -180,7 +180,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
network (Cell): The training network. The network only supports single output. network (Cell): The training network. The network only supports single output.
optimizer (Cell): Optimizer for updating the weights. optimizer (Cell): Optimizer for updating the weights.
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value
is Tensor type, Tensor with shape :math:`()`.
is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`.


Inputs: Inputs:
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.
@@ -230,7 +230,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
name="scale_sense") name="scale_sense")
elif isinstance(scale_sense, Tensor): elif isinstance(scale_sense, Tensor):
self.scale_sense = Parameter(scale_sense, name='scale_sense')
if scale_sense.shape == (1,) or scale_sense.shape == ():
self.scale_sense = Parameter(scale_sense, name='scale_sense')
else:
raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
else: else:
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))


@@ -284,4 +287,4 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
if self.scale_sense and isinstance(sens, Tensor): if self.scale_sense and isinstance(sens, Tensor):
self.scale_sense.set_data(sens) self.scale_sense.set_data(sens)
else: else:
raise TypeError("The input type must be Tensor,but got {}".format(type(sens)))
raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))

Loading…
Cancel
Save