|
|
|
@@ -180,7 +180,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): |
|
|
|
network (Cell): The training network. The network only supports single output. |
|
|
|
optimizer (Cell): Optimizer for updating the weights. |
|
|
|
scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value |
|
|
|
is Tensor type, Tensor with shape :math:`()`. |
|
|
|
is Tensor type, Tensor with shape :math:`()` or :math:`(1,)`. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. |
|
|
|
@@ -230,7 +230,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): |
|
|
|
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), |
|
|
|
name="scale_sense") |
|
|
|
elif isinstance(scale_sense, Tensor): |
|
|
|
self.scale_sense = Parameter(scale_sense, name='scale_sense') |
|
|
|
if scale_sense.shape == (1,) or scale_sense.shape == (): |
|
|
|
self.scale_sense = Parameter(scale_sense, name='scale_sense') |
|
|
|
else: |
|
|
|
raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape)) |
|
|
|
else: |
|
|
|
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) |
|
|
|
|
|
|
|
@@ -284,4 +287,4 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): |
|
|
|
if self.scale_sense and isinstance(sens, Tensor): |
|
|
|
self.scale_sense.set_data(sens) |
|
|
|
else: |
|
|
|
raise TypeError("The input type must be Tensor,but got {}".format(type(sens))) |
|
|
|
raise TypeError("The input type must be Tensor, but got {}".format(type(sens))) |