You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_step_wrap.py 3.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. train step wrap
  17. """
  18. import mindspore.nn as nn
  19. from mindspore.ops import functional as F
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import operations as P
  22. from mindspore import Parameter, ParameterTuple
  23. class TrainStepWrap(nn.Cell):
  24. """
  25. TrainStepWrap definition
  26. """
  27. def __init__(self, network):
  28. super(TrainStepWrap, self).__init__()
  29. self.network = network
  30. self.network.set_train()
  31. self.weights = ParameterTuple(network.trainable_params())
  32. self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
  33. self.hyper_map = C.HyperMap()
  34. self.grad = C.GradOperation('grad', get_by_list=True)
  35. def construct(self, x, label):
  36. weights = self.weights
  37. grads = self.grad(self.network, weights)(x, label)
  38. return self.optimizer(grads)
  39. class NetWithLossClass(nn.Cell):
  40. """
  41. NetWithLossClass definition
  42. """
  43. def __init__(self, network):
  44. super(NetWithLossClass, self).__init__(auto_prefix=False)
  45. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  46. self.network = network
  47. def construct(self, x, label):
  48. predict = self.network(x)
  49. return self.loss(predict, label)
  50. def train_step_with_loss_warp(network):
  51. return TrainStepWrap(NetWithLossClass(network))
  52. class TrainStepWrap2(nn.Cell):
  53. """
  54. TrainStepWrap2 definition
  55. """
  56. def __init__(self, network, sens):
  57. super(TrainStepWrap2, self).__init__()
  58. self.network = network
  59. self.network.set_train()
  60. self.weights = ParameterTuple(network.get_parameters())
  61. self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
  62. self.hyper_map = C.HyperMap()
  63. self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
  64. self.sens = sens
  65. def construct(self, x):
  66. weights = self.weights
  67. grads = self.grad(self.network, weights)(x, self.sens)
  68. return self.optimizer(grads)
  69. def train_step_with_sens(network, sens):
  70. return TrainStepWrap2(network, sens)
  71. class TrainStepWrapWithoutOpt(nn.Cell):
  72. """
  73. TrainStepWrapWithoutOpt definition
  74. """
  75. def __init__(self, network):
  76. super(TrainStepWrapWithoutOpt, self).__init__()
  77. self.network = network
  78. self.weights = ParameterTuple(network.trainable_params())
  79. self.grad = C.GradOperation('grad', get_by_list=True)
  80. def construct(self, x, label):
  81. grads = self.grad(self.network, self.weights)(x, label)
  82. return grads
  83. def train_step_without_opt(network):
  84. return TrainStepWrapWithoutOpt(NetWithLossClass(network))