You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """utils script"""
  16. def _load_param_into_net(model, params_dict):
  17. """
  18. load fp32 model parameters to quantization model.
  19. Args:
  20. model: quantization model
  21. params_dict: f32 param
  22. Returns:
  23. None
  24. """
  25. iterable_dict = {
  26. 'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
  27. 'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
  28. 'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
  29. 'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
  30. 'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
  31. 'moving_variance': iter(
  32. [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
  33. 'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
  34. 'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
  35. }
  36. for name, param in model.parameters_and_names():
  37. key_name = name.split(".")[-1]
  38. if key_name not in iterable_dict.keys():
  39. continue
  40. value_param = next(iterable_dict[key_name], None)
  41. if value_param is not None:
  42. param.set_parameter_data(value_param[1].data)
  43. print(f'init model param {name} with checkpoint param {value_param[0]}')