You cannot select more than 25 topics. A topic must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train_step_wrap.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. train step wrap
  17. """
  18. import mindspore.nn as nn
  19. from mindspore.ops import functional as F
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import operations as P
  22. from mindspore import Parameter, ParameterTuple
  23. run_opt = C.MultitypeFuncGraph("run_opt")
  24. # pylint: disable=unused-argument
  25. @run_opt.register("Function", "Int", "Number", "Number",
  26. "Tensor", "Tensor", "Tensor")
  27. def tensor_run_opt(opt, iterator, learning_rate, momentum,
  28. gradient, variable, moment):
  29. success = True
  30. new_weight = opt(gradient, moment, variable, learning_rate, momentum)
  31. success = F.depend(success, P.Assign()(variable, new_weight))
  32. return success
  33. class OptimizerByMomentum(nn.Cell):
  34. """
  35. OptimizerByMomentum definition
  36. """
  37. # list of tensor
  38. def __init__(self, weights):
  39. super(OptimizerByMomentum, self).__init__()
  40. self.learning_rate = Parameter(0.1, name="learning_rate")
  41. self.momentum = Parameter(0.05, name="momentum")
  42. self.iter = Parameter(0, name="iter")
  43. self.weights = weights
  44. self.moments = weights.clone(prefix="moments", init='zeros')
  45. self.hyper_map = C.HyperMap()
  46. self.opt = P.ApplyMomentum()
  47. def construct(self, grads):
  48. success = True
  49. weights = self.weights
  50. moments = self.moments
  51. success = self.hyper_map(
  52. F.partial(run_opt, self.opt, self.iter,
  53. self.learning_rate, self.momentum), grads, weights, moments)
  54. # self.learning_rate = updata_lr(self.learning_rate, self.momentum)
  55. return success
  56. class TrainStepWrap(nn.Cell):
  57. """
  58. TrainStepWrap definition
  59. """
  60. def __init__(self, network):
  61. super(TrainStepWrap, self).__init__()
  62. self.network = network
  63. self.network.set_train()
  64. self.weights = ParameterTuple(network.trainable_params())
  65. self.optimizer = OptimizerByMomentum(self.weights)
  66. self.hyper_map = C.HyperMap()
  67. self.grad = C.GradOperation('grad', get_by_list=True)
  68. def construct(self, x, label):
  69. weights = self.weights
  70. grads = self.grad(self.network, weights)(x, label)
  71. return self.optimizer(grads)
  72. class NetWithLossClass(nn.Cell):
  73. """
  74. NetWithLossClass definition
  75. """
  76. def __init__(self, network):
  77. super(NetWithLossClass, self).__init__(auto_prefix=False)
  78. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  79. self.network = network
  80. def construct(self, x, label):
  81. predict = self.network(x)
  82. return self.loss(predict, label)
  83. def train_step_with_loss_warp(network):
  84. return TrainStepWrap(NetWithLossClass(network))
  85. class TrainStepWrap2(nn.Cell):
  86. """
  87. TrainStepWrap2 definition
  88. """
  89. def __init__(self, network, sens):
  90. super(TrainStepWrap2, self).__init__()
  91. self.network = network
  92. self.network.set_train()
  93. self.weights = ParameterTuple(network.get_parameters())
  94. self.optimizer = OptimizerByMomentum(self.weights)
  95. self.hyper_map = C.HyperMap()
  96. self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
  97. self.sens = sens
  98. def construct(self, x):
  99. weights = self.weights
  100. grads = self.grad(self.network, weights)(x, self.sens)
  101. return self.optimizer(grads)
  102. def train_step_with_sens(network, sens):
  103. return TrainStepWrap2(network, sens)
  104. class TrainStepWrapWithoutOpt(nn.Cell):
  105. """
  106. TrainStepWrapWithoutOpt definition
  107. """
  108. def __init__(self, network):
  109. super(TrainStepWrapWithoutOpt, self).__init__()
  110. self.network = network
  111. self.weights = ParameterTuple(network.trainable_params())
  112. self.grad = C.GradOperation('grad', get_by_list=True)
  113. def construct(self, x, label):
  114. grads = self.grad(self.network, self.weights)(x, label)
  115. return grads
  116. def train_step_without_opt(network):
  117. return TrainStepWrapWithoutOpt(NetWithLossClass(network))

MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.