
utils.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model utils."""
import math
import argparse

import numpy as np


def str2bool(value):
    """Convert a string argument to bool type."""
    if value.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
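
# Illustrative usage (a sketch, not part of the original file): argparse's
# plain bool() treats any non-empty string as True, so "--flag false" would
# be truthy; str2bool makes such flags parse as intended. The flag name
# below is hypothetical.
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--use_amp', type=str2bool, default=False)
#   args = parser.parse_args(['--use_amp', 'no'])   # args.use_amp is False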


def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9,
           warmup_epochs=0., warmup_lr_init=0., global_epoch=0):
    """Get the scheduled per-step learning rate.

    Builds a linear-warmup, stepwise exponential-decay schedule and returns
    one learning-rate value per remaining training step.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    global_steps = steps_per_epoch * global_epoch
    # Per-epoch increment applied during linear warmup.
    self_warmup_delta = ((base_lr - warmup_lr_init) /
                         warmup_epochs) if warmup_epochs > 0 else 0
    # Normalize the decay rate so it is always a multiplicative shrink factor.
    self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate
    for i in range(total_steps):
        epochs = math.floor(i / steps_per_epoch)
        cond = 1 if (epochs < warmup_epochs) else 0
        warmup_lr = warmup_lr_init + epochs * self_warmup_delta
        decay_nums = math.floor(epochs / decay_epochs)
        # Use a fresh name rather than rebinding the decay_rate argument.
        decay_factor = math.pow(self_decay_rate, decay_nums)
        decay_lr = base_lr * decay_factor
        lr = cond * warmup_lr + (1 - cond) * decay_lr
        lr_each_step.append(lr)
    # Drop the steps already covered by previously completed epochs.
    lr_each_step = lr_each_step[global_steps:]
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
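
# Illustrative usage (a sketch, not from the original file): MindSpore
# optimizers accept a 1-D per-step learning-rate array, so the schedule can
# be handed straight to the optimizer. The concrete numbers and the
# Momentum optimizer choice here are assumptions for illustration.
#   lr = get_lr(base_lr=0.1, total_epochs=90, steps_per_epoch=1251,
#               decay_epochs=30, decay_rate=0.1, warmup_epochs=5)
#   # opt = mindspore.nn.Momentum(net.trainable_params(),
#   #                             learning_rate=mindspore.Tensor(lr),
#   #                             momentum=0.9)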


def add_weight_decay(net, weight_decay=1e-5, skip_list=None):
    """Apply weight decay only to conv and dense weights (len(shape) >= 2).

    Args:
        net (mindspore.nn.Cell): MindSpore network instance.
        weight_decay (float): Weight decay to be used.
        skip_list (tuple): Parameter names to exclude from weight decay.

    Returns:
        A list of two parameter groups: one without weight decay and one
        with the given weight decay.
    """
    decay = []
    no_decay = []
    if not skip_list:
        skip_list = ()
    for param in net.trainable_params():
        # 1-D parameters (biases, norm scales/shifts) get no weight decay.
        if (len(param.shape) == 1
                or param.name.endswith(".bias")
                or param.name in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
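
# Illustrative usage (a sketch): MindSpore optimizers accept grouped
# parameters, so the returned list can be passed in place of
# net.trainable_params(). The skip names and the SGD optimizer choice are
# assumptions for illustration.
#   groups = add_weight_decay(net, weight_decay=1e-5,
#                             skip_list=('moving_mean', 'moving_variance'))
#   # opt = mindspore.nn.SGD(groups, learning_rate=0.1)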


def count_params(net):
    """Count the number of trainable parameters in the network.

    Args:
        net (mindspore.nn.Cell): MindSpore network instance.

    Returns:
        total_params (int): Total number of trainable parameters.
    """
    total_params = 0
    for param in net.trainable_params():
        # np.prod returns a NumPy scalar; cast so the total stays a plain int.
        total_params += int(np.prod(param.shape))
    return total_params
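

if __name__ == "__main__":
    # Minimal self-check (illustrative, not part of the original file).
    # Only the MindSpore-free helpers are exercised here; the chosen
    # schedule parameters are arbitrary.
    schedule = get_lr(base_lr=0.1, total_epochs=3, steps_per_epoch=4,
                      decay_epochs=1, decay_rate=0.9,
                      warmup_epochs=1, warmup_lr_init=0.01)
    assert schedule.shape == (12,) and schedule.dtype == np.float32
    assert schedule[0] == np.float32(0.01)  # warmup starts at warmup_lr_init
    assert str2bool('yes') and not str2bool('0')
    print('utils.py self-check passed')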