| @@ -151,6 +151,84 @@ class One(Initializer): | |||||
| _assignment(arr, 1) | _assignment(arr, 1) | ||||
| def _calculate_fan_in_and_fan_out(shape): | |||||
| """ | |||||
| calculate fan_in and fan_out | |||||
| Args: | |||||
| shape (tuple): input shape. | |||||
| Returns: | |||||
| Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`. | |||||
| """ | |||||
| dimensions = len(shape) | |||||
| if dimensions < 2: | |||||
| raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") | |||||
| if dimensions == 2: # Linear | |||||
| fan_in = shape[1] | |||||
| fan_out = shape[0] | |||||
| else: | |||||
| num_input_fmaps = shape[1] | |||||
| num_output_fmaps = shape[0] | |||||
| receptive_field_size = 1 | |||||
| if dimensions > 2: | |||||
| receptive_field_size = shape[2] * shape[3] | |||||
| fan_in = num_input_fmaps * receptive_field_size | |||||
| fan_out = num_output_fmaps * receptive_field_size | |||||
| return fan_in, fan_out | |||||
| def _calculate_correct_fan(shape, mode): | |||||
| """ | |||||
| Calculate fan. | |||||
| Args: | |||||
| shape (tuple): input shape. | |||||
| mode (str): only support fan_in and fan_out. | |||||
| Returns: | |||||
| fan_in or fan_out. | |||||
| """ | |||||
| mode = mode.lower() | |||||
| valid_modes = ['fan_in', 'fan_out'] | |||||
| if mode not in valid_modes: | |||||
| raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) | |||||
| fan_in, fan_out = _calculate_fan_in_and_fan_out(shape) | |||||
| return fan_in if mode == 'fan_in' else fan_out | |||||
| def _calculate_gain(nonlinearity, param=None): | |||||
| """ | |||||
| Calculate gain. | |||||
| Args: | |||||
| nonlinearity (str): nonlinearity function. | |||||
| param (str): used to calculate negative_slope. | |||||
| Returns: | |||||
| number. | |||||
| """ | |||||
| linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] | |||||
| if nonlinearity in linear_fns or nonlinearity == 'sigmoid': | |||||
| res = 1 | |||||
| elif nonlinearity == 'tanh': | |||||
| res = 5.0 / 3 | |||||
| elif nonlinearity == 'relu': | |||||
| res = math.sqrt(2.0) | |||||
| elif nonlinearity == 'leaky_relu': | |||||
| if param is None: | |||||
| negative_slope = 0.01 | |||||
| elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): | |||||
| # True/False are instances of int, hence check above | |||||
| negative_slope = param | |||||
| else: | |||||
| raise ValueError("negative_slope {} not a valid number".format(param)) | |||||
| res = math.sqrt(2.0 / (1 + negative_slope ** 2)) | |||||
| else: | |||||
| raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) | |||||
| return res | |||||
| def _calculate_in_and_out(arr): | def _calculate_in_and_out(arr): | ||||
| """ | """ | ||||
| Calculate n_in and n_out. | Calculate n_in and n_out. | ||||
| @@ -223,6 +301,35 @@ class HeUniform(Initializer): | |||||
| _assignment(arr, data) | _assignment(arr, data) | ||||
@_register('he_normal')
class HeNormal(Initializer):
    r"""
    Initialize the array with the He (Kaiming) normal algorithm: samples are drawn from a normal
    distribution with mean 0 and standard deviation `gain / sqrt(fan)`.

    Args:
        negative_slope (int, float): Slope passed to the gain calculation, used when nonlinearity is
            'leaky_relu'. Default: 0.
        mode (str): 'fan_in' or 'fan_out', selects which fan is used. Default: fan_in.
        nonlinearity (str): Name of the nonlinearity function. Default: leaky_relu.
    """

    def __init__(self, negative_slope=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(HeNormal, self).__init__(negative_slope=negative_slope, mode=mode, nonlinearity=nonlinearity)
        self.negative_slope = negative_slope
        self.mode = mode
        self.nonlinearity = nonlinearity

    def _initialize(self, arr):
        # The fan is chosen by `self.mode`; the gain by the nonlinearity and slope.
        fan_size = _calculate_correct_fan(arr.shape, self.mode)
        gain_value = _calculate_gain(self.nonlinearity, self.negative_slope)
        sigma = gain_value / math.sqrt(fan_size)
        # Draw one sample per element of `arr` and hand them to `_assignment`.
        samples = np.random.normal(loc=0, scale=sigma, size=arr.shape)
        _assignment(arr, samples)
| class Constant(Initializer): | class Constant(Initializer): | ||||
| """ | """ | ||||
| Initialize a constant. | Initialize a constant. | ||||
| @@ -372,6 +479,7 @@ __all__ = [ | |||||
| 'Normal', | 'Normal', | ||||
| 'Uniform', | 'Uniform', | ||||
| 'HeUniform', | 'HeUniform', | ||||
| 'HeNormal', | |||||
| 'XavierUniform', | 'XavierUniform', | ||||
| 'One', | 'One', | ||||
| 'Zero', | 'Zero', | ||||