浏览代码

Fixes #I3AB6T

tags/v1.3.0
wudenggang 5 年前
父节点
当前提交
696d80607d
共有 1 个文件被更改,包括 6 次插入4 次删除
  1. +6
    -4
      mindspore/common/initializer.py

+ 6
- 4
mindspore/common/initializer.py 查看文件

@@ -366,25 +366,27 @@ class Uniform(Initializer):
@_register()
class Normal(Initializer):
"""
Initialize a normal array, and obtain values N(0, sigma) from the uniform distribution
Initialize a normal array, and obtain values N(sigma, mean) from the normal distribution
to fill the input tensor.

Args:
sigma (float): The sigma of the array. Default: 0.01.
mean (float): The mean of the array. Default: 0.0.

Returns:
Array, normal array.
"""
def __init__(self, sigma=0.01):
super(Normal, self).__init__(sigma=sigma)
def __init__(self, sigma=0.01, mean=0.0):
super(Normal, self).__init__(sigma=sigma, mean=mean)
self.sigma = sigma
self.mean = mean

def _initialize(self, arr):
seed, seed2 = self.seed
output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32))
random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor)
output_data = output_tensor.asnumpy()
output_data *= self.sigma
output_data = output_data * self.sigma + self.mean
_assignment(arr, output_data)

@_register()


正在加载...
取消
保存