

mindspore.nn.DistributedGradReducer
===================================

.. py:class:: mindspore.nn.DistributedGradReducer(parameters, mean=True, degree=None, fusion_type=1, group=GlobalComm.WORLD_COMM_GROUP)
A distributed optimizer.

In data parallel mode, aggregates the gradients of all devices using AllReduce.

**Parameters:**

- **parameters** (list) - The parameters to be updated.
- **mean** (bool) - When mean is True, the gradients are averaged after the AllReduce. Default: True.
- **degree** (int) - The mean coefficient, usually equal to the number of devices. Default: None.
- **fusion_type** (int) - The fusion type of the AllReduce operator. Default: 1.
- **group** (str) - The communication group of the AllReduce operator; to use a custom communication group, call the create_group API first. Default: GlobalComm.WORLD_COMM_GROUP.

**Raises:**

- **ValueError** - If degree is not an int or is less than 0.

**Supported Platforms:**

``Ascend`` ``GPU``
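Conceptually, with ``mean=True`` the gradient every device receives is the element-wise sum across devices (the AllReduce) divided by ``degree``. A minimal NumPy sketch of that behaviour, simulating two devices in one process (plain Python, not MindSpore API calls):

```python
import numpy as np

# Hypothetical stand-in for data-parallel training: each "device"
# holds its own local gradient for the same parameter.
local_grads = [
    np.array([1.0, 2.0], dtype=np.float32),  # device 0
    np.array([3.0, 4.0], dtype=np.float32),  # device 1
]
degree = len(local_grads)  # number of devices

# AllReduce(sum): every device ends up with the element-wise sum.
reduced = np.sum(local_grads, axis=0)

# With mean=True the reducer divides the sum by `degree`,
# so every device sees the same averaged gradient.
averaged = reduced / degree
print(averaged)  # [2. 3.]
```

In the real reducer this sum and division happen collectively across processes; the sketch only illustrates the arithmetic that ``mean`` and ``degree`` control.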
**Examples:**
>>> # This example should be run with multiple processes.
>>> # Please refer to "Tutorials > Distributed Training" on mindspore.cn.
>>> import numpy as np
>>> from mindspore.communication import init
>>> from mindspore import ops
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import Parameter, Tensor
>>> from mindspore import nn
>>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> init()
>>> context.reset_auto_parallel_context()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
>>>
>>> class TrainingWrapper(nn.Cell):
...     def __init__(self, network, optimizer, sens=1.0):
...         super(TrainingWrapper, self).__init__(auto_prefix=False)
...         self.network = network
...         self.network.add_flags(defer_inline=True)
...         self.weights = optimizer.parameters
...         self.optimizer = optimizer
...         self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
...         self.sens = sens
...         self.reducer_flag = False
...         self.grad_reducer = None
...         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
...         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
...             self.reducer_flag = True
...         if self.reducer_flag:
...             mean = context.get_auto_parallel_context("gradients_mean")
...             degree = context.get_auto_parallel_context("device_num")
...             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
...
...     def construct(self, *args):
...         weights = self.weights
...         loss = self.network(*args)
...         sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
...         grads = self.grad(self.network, weights)(*args, sens)
...         if self.reducer_flag:
...             # apply grad reducer on grads
...             grads = self.grad_reducer(grads)
...         return ops.Depend(loss, self.optimizer(grads))
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
>>>
>>> size, in_features, out_features = 16, 16, 10
>>> network = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> net_with_loss = nn.WithLossCell(network, loss)
>>> optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> train_cell = TrainingWrapper(net_with_loss, optimizer)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> grads = train_cell(inputs, label)
>>> print(grads)
256.0