| @@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| network (Cell): The training network. The network only supports single output. | |||
| optimizer (Cell): Optimizer for updating the weights. | |||
| scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value | |||
| is Tensor type, Tensor with shape :math:`()`. Default: None. | |||
| is Tensor type, Tensor with shape :math:`()`. | |||
| Inputs: | |||
| - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. | |||
| @@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| - **loss** (Tensor) - Tensor with shape :math:`()`. | |||
| - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. | |||
| - **loss scaling value** (Tensor) - Tensor with shape :math:`()` | |||
| Examples: | |||
| >>> net_with_loss = Net() | |||
| @@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| >>> output = train_network(inputs, label, scaling_sens) | |||
| """ | |||
| def __init__(self, network, optimizer, scale_sense=None): | |||
| def __init__(self, network, optimizer, scale_sense): | |||
| super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.network.set_grad() | |||
| @@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE | |||
| self.scale_sense = None | |||
| self.loss_scaling_manager = None | |||
| if isinstance(scale_sense, Cell): | |||
| self.loss_scaling_manager = scale_sense | |||
| self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), | |||
| name="scale_sense") | |||
| if isinstance(scale_sense, Tensor): | |||
| elif isinstance(scale_sense, Tensor): | |||
| self.scale_sense = Parameter(scale_sense, name='scale_sense') | |||
| else: | |||
| raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) | |||
| @C.add_flags(has_effect=True) | |||
| def construct(self, *inputs): | |||
| @@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell): | |||
| """If the user has set the sens in the training process and wants to reassign the value, he can call | |||
| this function again to make modification, and sens needs to be of type Tensor.""" | |||
| if self.scale_sense and isinstance(sens, Tensor): | |||
| self.self.scale_sense.set_data(sens) | |||
| self.scale_sense.set_data(sens) | |||
| else: | |||
| raise TypeError("The input type must be Tensor,but got {}".format(type(sens))) | |||