@@ -19,17 +19,17 @@
 #include "device/gpu/cuda_common.h"
 template <typename T>
-__global__ void RmsPropKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
+__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable,
                               T* mean_square, T* moment, T* gradients, const size_t size) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
-    mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i];
-    moment[i] = momentum[0] * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon[0]) * gradients[i];
+    mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i];
+    moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i];
     variable[i] -= moment[i];
   }
 }
 template <typename T>
-void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon,
+void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon,
              T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) {
   RmsPropKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(learning_rate, decay, momentum, epsilon,
                                                                    variable, mean_square, moment, gradients, size);
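
For reference, the update rule this kernel computes element-wise is the standard RMSProp step; a minimal NumPy sketch (the function name is illustrative, not part of the patch):

```python
import numpy as np

def rmsprop_step(var, mean_square, moment, grad, lr, decay, momentum, epsilon):
    """One RMSProp update, mirroring RmsPropKernel element-wise."""
    mean_square = decay * mean_square + (1.0 - decay) * grad * grad
    # rsqrt(ms + eps) * g in the kernel is g / sqrt(ms + eps) here
    moment = momentum * moment + lr * grad / np.sqrt(mean_square + epsilon)
    var = var - moment
    return var, mean_square, moment

v, ms, m = np.zeros(3), np.zeros(3), np.zeros(3)
g = np.array([0.1, 0.2, 0.3])
v, ms, m = rmsprop_step(v, ms, m, g, lr=0.01, decay=0.9, momentum=0.0, epsilon=1e-10)
```

After this change `decay`, `momentum`, and `epsilon` are passed to the kernel by value rather than through device pointers, so each thread no longer pays a device-memory read for them; only `learning_rate` still arrives as a device pointer.
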
@@ -58,7 +58,7 @@ void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, co
 }
 template
-void RmsProp(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon,
+void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon,
              float* variable, float* mean_square, float* moment, float* gradients, const size_t size,
              cudaStream_t cuda_stream);
@@ -19,7 +19,7 @@
 #include "device/gpu/cuda_common.h"
 template <typename T>
-void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, T* mean_square,
+void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square,
              T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream);
 template <typename T>
@@ -25,9 +25,6 @@ MS_REG_GPU_KERNEL_ONE(ApplyRMSProp,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       RMSPropGpuKernel, float)
@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class RMSPropGpuKernel : public GpuKernel {
  public:
-  RMSPropGpuKernel() : size_(1), use_center_(false) {}
+  RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {}
   ~RMSPropGpuKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,13 +40,10 @@ class RMSPropGpuKernel : public GpuKernel {
     T *variable = GetDeviceAddress<T>(inputs, 0);
     T *mean_square = GetDeviceAddress<T>(inputs, 1);
     T *moment = GetDeviceAddress<T>(inputs, 2);
-    T *gradients = GetDeviceAddress<T>(inputs, 3);
-    T *learning_rate = GetDeviceAddress<T>(inputs, 4);
-    T *decay = GetDeviceAddress<T>(inputs, 5);
-    T *momentum = GetDeviceAddress<T>(inputs, 6);
-    T *epsilon = GetDeviceAddress<T>(inputs, 7);
-    RmsProp(learning_rate, decay, momentum, epsilon, variable, mean_square, moment, gradients, size_,
+    T *learning_rate = GetDeviceAddress<T>(inputs, 3);
+    T *gradients = GetDeviceAddress<T>(inputs, 4);
+    RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_,
             reinterpret_cast<cudaStream_t>(stream));
   } else {
     T *variable = GetDeviceAddress<T>(inputs, 0);
@@ -70,6 +67,11 @@ class RMSPropGpuKernel : public GpuKernel {
       use_center_ = true;
     }
+    if (node_name == "ApplyRMSProp") {
+      decay_ = GetAttr<float>(kernel_node, "rho");
+      momentum_ = GetAttr<float>(kernel_node, "momentum");
+      epsilon_ = GetAttr<float>(kernel_node, "epsilon");
+    }
     auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     for (auto &dim : input_shape) {
       size_ *= dim;
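
On the Python side this corresponds to calling `P.ApplyRMSProp` with plain float hyperparameters, which the frontend can attach to the kernel node as the `"rho"`, `"momentum"`, and `"epsilon"` attributes read above. A usage sketch modeled on the test code further down in this patch (the cell name is illustrative):

```python
import mindspore.nn as nn
from mindspore.ops import operations as P

class RMSPropNet(nn.Cell):
    def __init__(self):
        super(RMSPropNet, self).__init__()
        self.rms_opt = P.ApplyRMSProp()

    def construct(self, var, rms, mom, lr, grad):
        # decay / momentum / epsilon are Python floats (compile-time
        # constants), so they can be lowered to node attributes.
        return self.rms_opt(var, rms, mom, lr, grad, 0.9, 0.0, 1e-10)
```
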
@@ -81,24 +83,33 @@ class RMSPropGpuKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     size_t input_size = size_ * sizeof(T);
-    input_size_list_.push_back(input_size);
-    if (use_center_) {
+    if (!use_center_) {
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(input_size);
+      output_size_list_.push_back(input_size);
+    } else {
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      input_size_list_.push_back(sizeof(T));
+      output_size_list_.push_back(input_size);
     }
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(input_size);
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    input_size_list_.push_back(sizeof(T));
-    output_size_list_.push_back(0);
   }

  private:
   size_t size_;
   bool use_center_;
+  float decay_;
+  float momentum_;
+  float epsilon_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
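
For a parameter of `size_` elements, the two branches reserve the following byte counts per input; a small reference calculation that mirrors the new `InitSizeLists` (not code from the patch):

```python
import numpy as np

def expected_input_sizes(size, use_center, dtype=np.float32):
    """Byte sizes InitSizeLists reserves per input after this change."""
    n = size * np.dtype(dtype).itemsize   # full-tensor input
    s = np.dtype(dtype).itemsize          # scalar input
    if not use_center:
        # var, mean_square, moment, learning_rate, gradients
        return [n, n, n, s, n]
    # var, mean_gradients, mean_square, moment, gradients,
    # learning_rate, decay, momentum, epsilon
    return [n, n, n, n, n, s, s, s, s]

print(expected_input_sizes(16, use_center=False))  # [64, 64, 64, 4, 64]
```

The output size also changes from a placeholder `0` to the real tensor size in both branches.
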
@@ -175,7 +175,7 @@ class FakeQuantWithMinMaxAscend(Cell):
         else:
             quant_fun = P.FakeQuantPerLayer
             ema_fun = P.FakeQuantMinMaxPerLayerUpdate
         self.fake_quant = quant_fun(num_bits=self.num_bits,
                                     ema=self.ema,
                                     ema_decay=self.ema_decay,
@@ -272,7 +272,7 @@ class FakeQuantWithMinMaxGPU(Cell):
                                                0, self.out_channels)]).astype(np.float32)
         self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
         if per_channel:
             quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
         else:
@@ -175,8 +175,7 @@ class Adam(Optimizer):
             If True, updates the gradients using NAG.
             If False, updates the gradients without using NAG. Default: False.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
-        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default:
-            1.0.
+        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.

     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -210,7 +209,7 @@ class Adam(Optimizer):
         validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
         validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
         validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
-        validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
+        validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)

         self.beta1 = Tensor(beta1, mstype.float32)
         self.beta2 = Tensor(beta2, mstype.float32)
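
Loss scaling multiplies the loss up and divides the gradients back down by the same factor, so any positive value works and values below 1 are legitimate; note, though, that the check keeps `Rel.INC_LEFT`, so 0.0 itself still passes the range check even though the docstring now asks for a value strictly greater than 0. A toy illustration of why the lower bound moves:

```python
# Toy view of loss scaling: scale the loss up, scale gradients back down.
loss_scale = 0.5           # now legal; only positivity is required
grad = 0.01                # true gradient
scaled_grad = grad * loss_scale
recovered = scaled_grad / loss_scale
assert abs(recovered - grad) < 1e-12
```
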
@@ -122,7 +122,8 @@ class SameTypeShape(PrimitiveWithInfer):
     Checks whether data type and shape of two tensors are the same.

     Raises:
-        TypeError or ValueError: If not the same.
+        TypeError - If the data types of the two tensors are not the same.
+        ValueError - If the shapes of the two tensors are not the same.

     Inputs:
         - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
@@ -1031,7 +1032,7 @@ class InvertPermutation(PrimitiveWithInfer):
         - **input_x** (Union(tuple[int], Tensor[int])) - The input tuple is constructed by multiple
           integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
           The values must include 0. There can be no duplicate values or negative values.
-          If the input is Tensor, it must be 1-d and the dtype is int.
+          If the input is Tensor, it must be 1-d and the dtype is int. Only constant value is allowed.

     Outputs:
@@ -1061,7 +1062,9 @@ class InvertPermutation(PrimitiveWithInfer):
         z = [x_value[i] for i in range(len(x_value))]
         z.sort()

-        validator.check(f'value length', len(x_value), f'unique value length', len(set(x_value)), Rel.EQ, self.name)
+        for i in range(1, len(z)):
+            if z[i-1] == z[i]:
+                raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.")
         validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
         validator.check(f'value max', max(x_value), '', len(x_value)-1, Rel.EQ, self.name)
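
The new loop replaces a length-based uniqueness check with one that names the first duplicated value. A worked reference of the whole contract (input must be a permutation of `0..len-1`):

```python
def invert_permutation(perm):
    """Reference behaviour: out[perm[i]] = i, with the duplicate check."""
    z = sorted(perm)
    for i in range(1, len(z)):
        if z[i - 1] == z[i]:
            raise ValueError(f"{z[i]} is duplicated in the input.")
    out = [0] * len(perm)
    for i, p in enumerate(perm):
        out[p] = i
    return out

print(invert_permutation((3, 4, 0, 1, 2)))  # [2, 3, 4, 0, 1]
# invert_permutation((0, 1, 1)) -> ValueError: 1 is duplicated in the input.
```
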
@@ -258,6 +258,8 @@ class _Reduce(PrimitiveWithInfer):
         args = {'input_x': input_x['dtype']}
         validator.check_tensor_type_same(args, valid_dtype, self.name)

+        if axis_v is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
         input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
         return {'shape': input_shp,
                 'dtype': input_x['dtype'],
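
The guard rejects an `axis` whose value is unknown at compile time: a reduction's output shape is a function of `axis`, so shape inference cannot produce a static shape without it. A minimal illustration of that dependence:

```python
import numpy as np

x = np.ones((2, 3, 4), dtype=np.float32)
# The output shape depends on which axis is reduced, so axis must be
# a compile-time constant for static shape inference to succeed.
print(np.prod(x, axis=0).shape)  # (3, 4)
print(np.prod(x, axis=2).shape)  # (2, 3)
```
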
@@ -445,8 +447,9 @@ class ReduceProd(_Reduce):
                      Default : False, don't keep these reduced dimensions.

     Inputs:
-        - **input_x** (Tensor[Number]) - The input tensor.
-        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
+        - **input_x** (Tensor[Number]) - The input tensor.
+        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
+          Only constant value is allowed.

     Outputs:
         Tensor, has the same dtype as the 'input_x'.
@@ -474,8 +477,9 @@ class CumProd(PrimitiveWithInfer):
         reverse (bool): If True, reverse the result along axis. Default: False

     Inputs:
-        - **input_x** (Tensor[Number]) - The input tensor.
-        - **axis** (int) - The dimensions to compute the cumulative product.
+        - **input_x** (Tensor[Number]) - The input tensor.
+        - **axis** (int) - The dimensions to compute the cumulative product.
+          Only constant value is allowed.

     Outputs:
         Tensor, has the same shape and dtype as the 'input_x'.
@@ -507,6 +511,10 @@ class CumProd(PrimitiveWithInfer):
         validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
         return x_type

+    def infer_value(self, x, axis):
+        if axis is None:
+            raise ValueError(f"For {self.name}, axis must be const.")


 class MatMul(PrimitiveWithInfer):
     """
@@ -669,6 +677,10 @@ class CumSum(PrimitiveWithInfer):
                 'dtype': x['dtype'],
                 'value': None}

+    def infer_value(self, x, axis):
+        if axis is None:
+            raise ValueError(f"For {self.name}, axis must be const.")


 class AddN(PrimitiveWithInfer):
     """
@@ -1707,9 +1707,9 @@ class ApplyRMSProp(PrimitiveWithInfer):
         - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
         - **learning_rate** (Union[Number, Tensor]) - Learning rate.
         - **grad** (Tensor) - Gradients, must have the same type as `var`.
-        - **decay** (float) - Decay rate.
-        - **momentum** (float) - Momentum.
-        - **epsilon** (float) - Ridge term.
+        - **decay** (float) - Decay rate. Only constant value is allowed.
+        - **momentum** (float) - Momentum. Only constant value is allowed.
+        - **epsilon** (float) - Ridge term. Only constant value is allowed.

     Outputs:
         Tensor, parameters to be updated.
@@ -1759,6 +1759,13 @@ class ApplyRMSProp(PrimitiveWithInfer):
             return var_dtype, var_dtype, var_dtype
         return var_dtype

+    def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
+        if decay is None or momentum is None or epsilon is None:
+            raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
+        if not self.is_ge and self.is_d:
+            return None, None, None
+        return None


 class ApplyCenteredRMSProp(PrimitiveWithInfer):
     """
@@ -379,7 +379,7 @@ class ConfusionMatrix(PrimitiveWithInfer):
     Inputs:
         - **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer.
         - **predictions** (Tensor) - the labels from prediction, tensor of 1-D.
-          the shape same as `labels` and the dtype must be non-negative Integer.
+          the shape is the same as `labels` and the dtype must be a non-negative integer.
         - **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`.

     Outputs:
@@ -24,19 +24,25 @@ from mindspore.ops import operations as P

 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

+class NetCenteredRMSProp(nn.Cell):
+    def __init__(self):
+        super(NetCenteredRMSProp, self).__init__()
+        self.rms_opt = P.ApplyCenteredRMSProp()
+
+    def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon):
+        return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
+
+
 class NetRMSProp(nn.Cell):
-    def __init__(self, use_centered):
+    def __init__(self, decay, momentum, epsilon):
         super(NetRMSProp, self).__init__()
-        self.use_centered = use_centered
-        if use_centered:
-            self.rms_opt = P.ApplyCenteredRMSProp()
-        else:
-            self.rms_opt = P.ApplyRMSProp()
+        self.decay = decay
+        self.momentum = momentum
+        self.epsilon = epsilon
+        self.rms_opt = P.ApplyRMSProp()

-    def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon):
-        if self.use_centered:
-            return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon)
-        return self.rms_opt(var, rms, mom, lr, g, decay, momentum, epsilon)
+    def construct(self, var, g, mg, rms, mom, lr):
+        return self.rms_opt(var, rms, mom, lr, g, self.decay, self.momentum, self.epsilon)


 def rmsprop_numpy(variable, gradients, mean_square, moment,
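
With the split, the non-centered net bakes its hyperparameters into the cell at construction, so graph mode sees them as constants; a usage sketch (tensor values are illustrative):

```python
import numpy as np
from mindspore import Tensor

shape = (3, 3)
var = Tensor(np.random.rand(*shape).astype(np.float32))
grad = Tensor(np.random.rand(*shape).astype(np.float32))
mg = Tensor(np.zeros(shape).astype(np.float32))   # accepted but unused here
rms = Tensor(np.ones(shape).astype(np.float32))
mom = Tensor(np.zeros(shape).astype(np.float32))

net = NetRMSProp(decay=0.9, momentum=0.0, epsilon=1e-10)
# construct() now takes only tensors plus the learning rate; the scalar
# hyperparameters are no longer passed per call.
_ = net(var, grad, mg, rms, mom, 0.01)
```
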
@@ -76,13 +82,16 @@ def test_rmsprop():
     if centered:
         rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                             learning_rate, decay, momentum, epsilon)
+        net = NetCenteredRMSProp()
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
+                moment_ms, learning_rate, decay, momentum, epsilon)
     else:
         rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                       learning_rate, decay, momentum, epsilon)
-    net = NetRMSProp(centered)
-    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
-            moment_ms, learning_rate, decay, momentum, epsilon)
+        net = NetRMSProp(decay, momentum, epsilon)
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms,
+                moment_ms, learning_rate)

     error = np.ones(shape=variable_np.shape) * 10e-6
     diff = variable_ms.asnumpy() - variable_np
@@ -126,13 +135,15 @@ def test_rmspropcenter():
     if centered:
         rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np,
                             learning_rate, decay, momentum, epsilon)
+        net = NetCenteredRMSProp()
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
+                learning_rate, decay, momentum, epsilon)
     else:
         rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np,
                       learning_rate, decay, momentum, epsilon)
-    net = NetRMSProp(centered)
-    _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
-            learning_rate, decay, momentum, epsilon)
+        net = NetRMSProp(decay, momentum, epsilon)
+        _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms,
+                learning_rate)

     error = np.ones(shape=variable_np.shape) * 10e-6
     diff = variable_ms.asnumpy() - variable_np
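
For completeness, a NumPy reference for the centered variant the tests compare against; this is a sketch consistent with `rmspropcented_numpy`'s signature (the helper itself lives outside these hunks), using the standard centered-RMSProp update in which the squared-gradient estimate is corrected by the mean gradient:

```python
import numpy as np

def rmsprop_centered_numpy(var, grad, mean_grad, mean_square, moment,
                           lr, decay, momentum, epsilon):
    """Centered RMSProp step, updated in place like the test helpers."""
    mean_grad[:] = decay * mean_grad + (1.0 - decay) * grad
    mean_square[:] = decay * mean_square + (1.0 - decay) * grad * grad
    # variance-corrected denominator: E[g^2] - (E[g])^2 + eps
    denom = mean_square - mean_grad * mean_grad + epsilon
    moment[:] = momentum * moment + lr * grad / np.sqrt(denom)
    var[:] -= moment
```
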