| @@ -179,7 +179,7 @@ class TrainOneStepWithLossScaleCell(Cell): | |||||
| network (Cell): The training network. The network only supports single output. | network (Cell): The training network. The network only supports single output. | ||||
| optimizer (Cell): Optimizer for updating the weights. | optimizer (Cell): Optimizer for updating the weights. | ||||
| scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value | scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value | ||||
| is Tensor type, Tensor with shape :math:`()`. Default: None. | |||||
| is Tensor type, Tensor with shape :math:`()`. | |||||
| Inputs: | Inputs: | ||||
| - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. | - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. | ||||
| @@ -189,6 +189,7 @@ class TrainOneStepWithLossScaleCell(Cell): | |||||
| - **loss** (Tensor) - Tensor with shape :math:`()`. | - **loss** (Tensor) - Tensor with shape :math:`()`. | ||||
| - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. | - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. | ||||
| - **loss scaling value** (Tensor) - Tensor with shape :math:`()` | |||||
| Examples: | Examples: | ||||
| >>> net_with_loss = Net() | >>> net_with_loss = Net() | ||||
| @@ -203,7 +204,7 @@ class TrainOneStepWithLossScaleCell(Cell): | |||||
| >>> output = train_network(inputs, label, scaling_sens) | >>> output = train_network(inputs, label, scaling_sens) | ||||
| """ | """ | ||||
| def __init__(self, network, optimizer, scale_sense=None): | |||||
| def __init__(self, network, optimizer, scale_sense): | |||||
| super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) | ||||
| self.network = network | self.network = network | ||||
| self.network.set_grad() | self.network.set_grad() | ||||
| @@ -236,14 +237,15 @@ class TrainOneStepWithLossScaleCell(Cell): | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | ||||
| self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE | self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE | ||||
| self.scale_sense = None | |||||
| self.loss_scaling_manager = None | self.loss_scaling_manager = None | ||||
| if isinstance(scale_sense, Cell): | if isinstance(scale_sense, Cell): | ||||
| self.loss_scaling_manager = scale_sense | self.loss_scaling_manager = scale_sense | ||||
| self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), | self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), | ||||
| name="scale_sense") | name="scale_sense") | ||||
| if isinstance(scale_sense, Tensor): | |||||
| elif isinstance(scale_sense, Tensor): | |||||
| self.scale_sense = Parameter(scale_sense, name='scale_sense') | self.scale_sense = Parameter(scale_sense, name='scale_sense') | ||||
| else: | |||||
| raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) | |||||
| @C.add_flags(has_effect=True) | @C.add_flags(has_effect=True) | ||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| @@ -293,4 +295,6 @@ class TrainOneStepWithLossScaleCell(Cell): | |||||
| """If the user has set the sens in the training process and wants to reassign the value, he can call | """If the user has set the sens in the training process and wants to reassign the value, he can call | ||||
| this function again to make modification, and sens needs to be of type Tensor.""" | this function again to make modification, and sens needs to be of type Tensor.""" | ||||
| if self.scale_sense and isinstance(sens, Tensor): | if self.scale_sense and isinstance(sens, Tensor): | ||||
| self.self.scale_sense.set_data(sens) | |||||
| self.scale_sense.set_data(sens) | |||||
| else: | |||||
| raise TypeError("The input type must be Tensor,but got {}".format(type(sens))) | |||||