From: @wangshuide2020
Reviewed-by: @liangchenghui, @youui
Signed-off-by: @liangchenghui
Tag: tags/v1.1.0
@@ -291,11 +291,12 @@ class MSSSIM(Cell):
Examples:
>>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
- >>> img1 = Tensor(np.random.random((1,3,128,128)))
- >>> img2 = Tensor(np.random.random((1,3,128,128)))
+ >>> np.random.seed(0)
+ >>> img1 = Tensor(np.random.random((1, 3, 128, 128)))
+ >>> img2 = Tensor(np.random.random((1, 3, 128, 128)))
>>> output = net(img1, img2)
>>> print(output)
- [0.22965115]
+ [0.20607519]
"""
def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
             filter_sigma=1.5, k1=0.01, k2=0.03):
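The only substantive change to the MSSSIM example is the added np.random.seed(0): with a fixed seed, np.random.random returns the same tensors on every run, so the value the doctest prints is stable. A minimal NumPy-only sketch of that guarantee (illustrative, not part of the patch):

    import numpy as np

    np.random.seed(0)
    a = np.random.random((1, 3, 128, 128))
    np.random.seed(0)            # re-seeding restarts the random stream
    b = np.random.random((1, 3, 128, 128))
    assert np.array_equal(a, b)  # same seed, same draws, same printed output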
@@ -285,11 +285,12 @@ class BatchNorm1d(_BatchNorm):
Examples:
>>> net = nn.BatchNorm1d(num_features=4)
+ >>> np.random.seed(0)
>>> input = Tensor(np.random.randint(0, 255, [2, 4]), mindspore.float32)
>>> output = net(input)
>>> print(output)
- [[210.99895  136.99931   89.99955  240.9988  ]
-  [ 87.99956  157.9992    89.99955   42.999786]]
+ [[171.99915   46.999763 116.99941  191.99904 ]
+  [ 66.999664 250.99875  194.99902  102.99948 ]]
"""
def __init__(self,
@@ -370,15 +371,18 @@ class BatchNorm2d(_BatchNorm):
Examples:
>>> net = nn.BatchNorm2d(num_features=3)
+ >>> np.random.seed(0)
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mindspore.float32)
>>> output = net(input)
>>> print(output)
- [[[[128.99936   53.99973 ]
-    [191.99904  183.99908 ]]
-   [[146.99927  182.99908 ]
-    [184.99907  120.9994  ]]
-   [[ 33.99983  234.99883 ]
-    [188.99905   11.99994 ]]]]
+ [[[[171.99915   46.999763]
+    [116.99941  191.99904 ]]
+   [[ 66.999664 250.99875 ]
+    [194.99902  102.99948 ]]
+   [[  8.999955 210.99895 ]
+    [ 20.999895 241.9988  ]]]]
"""
def __init__(self,
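In both BatchNorm examples the printed values are just the inputs scaled down by roughly a factor of 1 - 5e-6, which is what an untrained BatchNorm computes from its initial moving statistics. A quick NumPy check of the BatchNorm2d numbers, assuming the defaults moving_mean=0, moving_var=1, gamma=1, beta=0 and eps=1e-5:

    import numpy as np

    np.random.seed(0)
    x = np.random.randint(0, 255, [1, 3, 2, 2]).astype(np.float32)
    eps, moving_mean, moving_var = 1e-5, 0.0, 1.0
    # gamma=1 and beta=0 are omitted since they change nothing here
    y = ((x - moving_mean) / np.sqrt(moving_var + eps)).astype(np.float32)
    print(y)  # 172 -> 171.99915, 47 -> 46.999763, ... as in the docstring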
@@ -455,9 +459,34 @@ class GlobalBatchNorm(_BatchNorm):
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
- >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)
- >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
- >>> global_bn_op(input)
+ >>> # This example should be run with multiple processes. Refer to the run_distribute_train.sh
+ >>> import os
+ >>> import numpy as np
+ >>> from mindspore.communication import init
+ >>> from mindspore import context
+ >>> from mindspore.context import ParallelMode
+ >>> from mindspore import nn, Tensor
+ >>> from mindspore.common import dtype as mstype
+ >>>
+ >>> device_id = int(os.environ["DEVICE_ID"])
+ >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
+ >>>                     device_id=int(device_id))
+ >>> init()
+ >>> context.reset_auto_parallel_context()
+ >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
+ >>> np.random.seed(0)
+ >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=2)
+ >>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mstype.float32)
+ >>> output = global_bn_op(input)
+ >>> print(output)
+ [[[[171.99915   46.999763]
+    [116.99941  191.99904 ]]
+   [[ 66.999664 250.99875 ]
+    [194.99902  102.99948 ]]
+   [[  8.999955 210.99895 ]
+    [ 20.999895 241.9988  ]]]]
"""
def __init__(self,
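The rewritten GlobalBatchNorm example only works inside a distributed launch: it reads DEVICE_ID from the environment and calls init(). When experimenting interactively it can help to fail fast rather than hit a bare KeyError; a purely illustrative guard, not part of the patch:

    import os

    if "DEVICE_ID" not in os.environ:
        raise SystemExit("launch via the distributed script so DEVICE_ID is set")
    device_id = int(os.environ["DEVICE_ID"])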
@@ -248,6 +248,16 @@ class GetNextSingleOp(Cell):
queue_name (str): Queue name to fetch the data.
For detailed information, refer to `ops.operations.GetNext`.
Examples:
+ >>> # Refer to dataset_helper.py for detail usage.
+ >>> data_set = get_dataset()
+ >>> dataset_shapes = data_set.output_shapes()
+ >>> np_types = data_set.output_types()
+ >>> dataset_types = convert_type(dataset_shapes, np_types)
+ >>> queue_name = data_set.__TRANSFER_DATASET__.queue_name
+ >>> getnext_op = GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
+ >>> getnext_op()
"""
def __init__(self, dataset_types, dataset_shapes, queue_name):
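The GetNextSingleOp example leans on helpers that live in dataset_helper.py (get_dataset, convert_type) without defining them. A minimal stand-in for get_dataset, assuming any small two-column dataset is acceptable; this GeneratorDataset is an illustration, not the helper's real implementation:

    import numpy as np
    import mindspore.dataset as ds

    def get_dataset():
        # four samples: a (2, 2) float32 feature tensor and an int32 label
        data = [(np.ones((2, 2), np.float32), np.array(1, np.int32)) for _ in range(4)]
        return ds.GeneratorDataset(data, column_names=["data", "label"])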
@@ -246,16 +246,19 @@ class DistributedGradReducer(Cell):
ValueError: If degree is not an int or is less than 0.
Examples:
- >>> from mindspore.communication import init, get_group_size
+ >>> # This example should be run with multiple processes. Refer to the run_distribute_train.sh
+ >>> import os
+ >>> import numpy as np
+ >>> from mindspore.communication import init
>>> from mindspore.ops import composite as C
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import Parameter, Tensor
>>> from mindspore import nn
>>> from mindspore import ParameterTuple
- >>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
- >>>                                        _get_parallel_mode)
+ >>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
+ >>> from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean)
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
@@ -295,12 +298,28 @@ class DistributedGradReducer(Cell):
>>>         grads = self.grad_reducer(grads)
>>>         return F.depend(loss, self.optimizer(grads))
>>>
- >>> network = Net()
- >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
- >>> train_cell = TrainingWrapper(network, optimizer)
- >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
- >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
+ >>> class Net(nn.Cell):
+ >>>     def __init__(self, in_features, out_features):
+ >>>         super(Net, self).__init__()
+ >>>         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+ >>>                                 name='weight')
+ >>>         self.matmul = P.MatMul()
+ >>>
+ >>>     def construct(self, x):
+ >>>         output = self.matmul(x, self.weight)
+ >>>         return output
+ >>>
+ >>> size, in_features, out_features = 16, 16, 10
+ >>> network = Net(in_features, out_features)
+ >>> loss = nn.MSELoss()
+ >>> net_with_loss = WithLossCell(network, loss)
+ >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> train_cell = TrainingWrapper(net_with_loss, optimizer)
+ >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
+ >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> grads = train_cell(inputs, label)
+ >>> print(grads)
+ 256.0
"""
def __init__(self, parameters, mean=True, degree=None):
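The printed 256.0 is the loss that TrainingWrapper hands back through F.depend: the all-ones input times the all-ones weight gives a prediction of 16 in every entry, and the mean squared error against a zero label is 16**2 = 256. The same arithmetic explains the 256.0 printed by the TrainOneStepWithLossScaleCell example further down. A plain NumPy check:

    import numpy as np

    x = np.ones((16, 16), np.float32)       # inputs
    w = np.ones((16, 10), np.float32)       # Net's all-ones weight
    pred = x @ w                            # every entry is 16.0
    label = np.zeros((16, 10), np.float32)
    print(np.mean((pred - label) ** 2))     # 256.0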
@@ -76,16 +76,30 @@ class DynamicLossScaleUpdateCell(Cell):
Tensor, a scalar Tensor with shape :math:`()`.
Examples:
- >>> net_with_loss = Net()
- >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> import numpy as np
+ >>> from mindspore import Tensor, Parameter, nn
+ >>> from mindspore.ops import operations as P
+ >>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
+ >>>
+ >>> class Net(nn.Cell):
+ >>>     def __init__(self, in_features, out_features):
+ >>>         super(Net, self).__init__()
+ >>>         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+ >>>                                 name='weight')
+ >>>         self.matmul = P.MatMul()
+ >>>
+ >>>     def construct(self, x):
+ >>>         output = self.matmul(x, self.weight)
+ >>>         return output
+ >>>
+ >>> in_features, out_features = 16, 10
+ >>> net = Net(in_features, out_features)
+ >>> loss = nn.MSELoss()
+ >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> train_network.set_train()
>>>
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scale_sense=scaling_sens)
"""
def __init__(self,
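For context on what the manager built here does each step: dynamic loss scaling shrinks the scale when gradients overflow and grows it again after a run of clean steps. A sketch of that usual rule in plain Python (the real cell operates on Tensors inside the graph, and its exact lower bound may differ):

    def update_loss_scale(scale, overflow, good_steps,
                          scale_factor=2, scale_window=1000):
        if overflow:
            return max(scale / scale_factor, 1.0), 0  # shrink, restart the window
        good_steps += 1
        if good_steps >= scale_window:
            return scale * scale_factor, 0            # grow after a quiet window
        return scale, good_steps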
@@ -142,16 +156,30 @@ class FixedLossScaleUpdateCell(Cell):
loss_scale_value (float): Initializes loss scale.
Examples:
- >>> net_with_loss = Net()
- >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> import numpy as np
+ >>> from mindspore import Tensor, Parameter, nn
+ >>> from mindspore.ops import operations as P
+ >>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
+ >>>
+ >>> class Net(nn.Cell):
+ >>>     def __init__(self, in_features, out_features):
+ >>>         super(Net, self).__init__()
+ >>>         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+ >>>                                 name='weight')
+ >>>         self.matmul = P.MatMul()
+ >>>
+ >>>     def construct(self, x):
+ >>>         output = self.matmul(x, self.weight)
+ >>>         return output
+ >>>
+ >>> in_features, out_features = 16, 10
+ >>> net = Net(in_features, out_features)
+ >>> loss = nn.MSELoss()
+ >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> train_network.set_train()
>>>
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
>>> output = train_network(inputs, label, scale_sense=scaling_sens)
"""
def __init__(self, loss_scale_value):
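FixedLossScaleUpdateCell, by contrast, never changes the scale, so the whole scheme reduces to multiplying the loss up before the backward pass and dividing the gradients back down afterwards. Sketched under that description (not MindSpore's actual graph code):

    LOSS_SCALE = 2.0 ** 12

    def scale_loss(loss):
        return loss * LOSS_SCALE                 # amplify before backprop

    def unscale_grads(grads):
        return [g / LOSS_SCALE for g in grads]   # restore true magnitudes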
@@ -193,21 +221,45 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
- **loss scaling value** (Tensor) - Tensor with shape :math:`()`
Examples:
- >>> #1) when the type scale_sense is Cell:
- >>> net_with_loss = Net()
- >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> import numpy as np
+ >>> from mindspore import Tensor, Parameter, nn
+ >>> from mindspore.ops import operations as P
+ >>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
+ >>> from mindspore.common import dtype as mstype
+ >>>
+ >>> class Net(nn.Cell):
+ >>>     def __init__(self, in_features, out_features):
+ >>>         super(Net, self).__init__()
+ >>>         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
+ >>>                                 name='weight')
+ >>>         self.matmul = P.MatMul()
+ >>>
+ >>>     def construct(self, x):
+ >>>         output = self.matmul(x, self.weight)
+ >>>         return output
+ >>>
+ >>> size, in_features, out_features = 16, 16, 10
+ >>> #1) when the type of scale_sense is Cell:
+ >>> net = Net(in_features, out_features)
+ >>> loss = nn.MSELoss()
+ >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> train_network.set_train()
>>>
- >>> #2) when the type scale_sense is Tensor:
- >>> net_with_loss = Net()
- >>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
- >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
- >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
- >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mindspore.float32)
+ >>> #2) when the type of scale_sense is Tensor:
+ >>> net = Net(in_features, out_features)
+ >>> loss = nn.MSELoss()
+ >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> net_with_loss = WithLossCell(net, loss)
+ >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
+ >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
+ >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)
+ >>> print(output[0])
+ 256.0
"""
def __init__(self, network, optimizer, scale_sense):
    super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
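print(output[0]) picks the loss out of the cell's return value, and with the all-ones inputs above it is the same 256.0 derived earlier. Continuing from the doctest's output variable, and assuming the documented three-element result (loss, overflow flag, current loss scaling value):

    loss, overflow, scaling_sens_out = output  # the cell returns a 3-tuple of Tensors
    print(loss)      # 256.0
    print(overflow)  # False unless the scaled gradients overflowed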