|
|
|
@@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell): |
|
|
|
network (Cell): The training network. The network only supports single output. |
|
|
|
optimizer (Cell): Optimizer for updating the weights. |
|
|
|
scale_sense (Union[Tensor, Cell]): If this value is Cell type, it is the loss scaling update logic cell. If this value |
|
|
|
is Tensor type, Tensor with shape :math:`()`. Default: None. |
|
|
|
is Tensor type, Tensor with shape :math:`()`. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. |
|
|
|
@@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell): |
|
|
|
|
|
|
|
- **loss** (Tensor) - Tensor with shape :math:`()`. |
|
|
|
- **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. |
|
|
|
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net_with_loss = Net() |
|
|
|
@@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell): |
|
|
|
>>> output = train_network(inputs, label, scaling_sens) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network, optimizer, scale_sense=None): |
|
|
|
def __init__(self, network, optimizer, scale_sense): |
|
|
|
super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
@@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell): |
|
|
|
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) |
|
|
|
self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE |
|
|
|
|
|
|
|
self.scale_sense = None |
|
|
|
self.loss_scaling_manager = None |
|
|
|
if isinstance(scale_sense, Cell): |
|
|
|
self.loss_scaling_manager = scale_sense |
|
|
|
self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), |
|
|
|
name="scale_sense") |
|
|
|
if isinstance(scale_sense, Tensor): |
|
|
|
elif isinstance(scale_sense, Tensor): |
|
|
|
self.scale_sense = Parameter(scale_sense, name='scale_sense') |
|
|
|
else: |
|
|
|
raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) |
|
|
|
|
|
|
|
@C.add_flags(has_effect=True) |
|
|
|
def construct(self, *inputs): |
|
|
|
@@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell): |
|
|
|
"""If the user has set the sens in the training process and wants to reassign the value, he can call |
|
|
|
this function again to modify it; sens must be of type Tensor.""" |
|
|
|
if self.scale_sense and isinstance(sens, Tensor): |
|
|
|
self.self.scale_sense.set_data(sens) |
|
|
|
self.scale_sense.set_data(sens) |
|
|
|
else: |
|
|
|
raise TypeError("The input type must be Tensor,but got {}".format(type(sens))) |