You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

weight_init.py 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Weight init utilities."""
  16. import math
  17. import numpy as np
  18. from mindspore.common.tensor import Tensor
  19. def _average_units(shape):
  20. """
  21. Average shape dim.
  22. """
  23. if not shape:
  24. return 1.
  25. if len(shape) == 1:
  26. return float(shape[0])
  27. if len(shape) == 2:
  28. return float(shape[0] + shape[1]) / 2.
  29. raise RuntimeError("not support shape.")
  30. def weight_variable(shape):
  31. scale_shape = shape
  32. avg_units = _average_units(scale_shape)
  33. scale = 1.0 / max(1., avg_units)
  34. limit = math.sqrt(3.0 * scale)
  35. values = np.random.uniform(-limit, limit, shape).astype(np.float32)
  36. return Tensor(values)
  37. def one_weight(shape):
  38. ones = np.ones(shape).astype(np.float32)
  39. return Tensor(ones)
  40. def zero_weight(shape):
  41. zeros = np.zeros(shape).astype(np.float32)
  42. return Tensor(zeros)
  43. def normal_weight(shape, num_units):
  44. norm = np.random.normal(0.0, num_units**-0.5, shape).astype(np.float32)
  45. return Tensor(norm)