You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_lr_schedule.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test_lr_schedule """
  16. import numpy as np
  17. from mindspore import Parameter, ParameterTuple, Tensor
  18. from mindspore.nn import Cell
  19. from mindspore.nn.optim import Optimizer
  20. from mindspore.ops.composite import grad_by_list
  21. from mindspore.ops.operations import BiasAdd, MatMul
  22. class Net(Cell):
  23. """ Net definition """
  24. def __init__(self):
  25. super(Net, self).__init__()
  26. self.weight = Parameter(Tensor(np.ones([64, 10])), name="weight")
  27. self.bias = Parameter(Tensor(np.ones([10])), name="bias")
  28. self.matmul = MatMul()
  29. self.biasAdd = BiasAdd()
  30. def construct(self, x):
  31. x = self.biasAdd(self.matmul(x, self.weight), self.bias)
  32. return x
  33. class _TrainOneStepCell(Cell):
  34. """ _TrainOneStepCell definition """
  35. def __init__(self, network, optimizer):
  36. """
  37. Append an optimizer to the training network after that the construct
  38. function can be called to create the backward graph.
  39. Arguments:
  40. network: The training network.
  41. Note that loss function should have been added.
  42. optimizer: optimizer for updating the weights
  43. """
  44. super(_TrainOneStepCell, self).__init__(auto_prefix=False)
  45. self.network = network
  46. self.weights = ParameterTuple(network.get_parameters())
  47. if not isinstance(optimizer, Optimizer):
  48. raise TypeError('{} is not an optimizer'.format(
  49. type(optimizer).__name__))
  50. self.has_lr_schedule = False
  51. self.optimizer = optimizer
  52. def construct(self, data, label, *args):
  53. weights = self.weights
  54. grads = grad_by_list(self.network, weights)(data, label)
  55. if self.lr_schedule:
  56. self.schedule.update_lr(*args)
  57. return self.optimizer(grads)