@@ -22,9 +22,10 @@
namespace mindspore {
namespace kernel {
static constexpr size_t BATCH_SIZE = 10000;
static constexpr float MIN_GLOBAL_NORM = 1e-10;
void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vector<AddressPtr> &inputs,
                                                                const std::vector<AddressPtr> &) {
  auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
  auto m = reinterpret_cast<float *>(inputs[M]->addr);
  auto v = reinterpret_cast<float *>(inputs[V]->addr);
  auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
@@ -33,6 +34,12 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vecto
  auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
  auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
  auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
  auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
  auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
  if (global_norm < MIN_GLOBAL_NORM) {
    global_norm = 1.0f;
  }
  auto global_norm_reciprocal = 1.0f / global_norm;
  const auto beta1_minus = 1 - beta1;
  const auto beta2_minus = 1 - beta2;
@@ -42,10 +49,10 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vecto
  task = [&](size_t start, size_t end) {
    size_t i = FusedCastAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16),
                                 start, end);
                                 global_norm_reciprocal, start, end);
    // remaining
    for (; i < end; i++) {
      auto temp = static_cast<float>(gradient16[i]);
      auto temp = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
      m[i] += (temp - m[i]) * beta1_minus;
      v[i] += (temp * temp - v[i]) * beta2_minus;
      auto update = m[i] / (std::sqrt(v[i]) + epsilon);
@@ -53,12 +60,11 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vecto
      var[i] -= lr * update;
    }
  };
  ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
  CPUKernelUtils::ParallelFor(task, lens, BATCH_SIZE);
}
void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vector<AddressPtr> &inputs,
                                                                const std::vector<AddressPtr> &) {
  auto var16 = reinterpret_cast<float16 *>(inputs[VAR]->addr);
  auto m = reinterpret_cast<float *>(inputs[M]->addr);
  auto v = reinterpret_cast<float *>(inputs[V]->addr);
  auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
@@ -67,6 +73,12 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vecto
  auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
  auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
  auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
  auto var16 = reinterpret_cast<float16 *>(inputs[VAR]->addr);
  auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
  if (global_norm < MIN_GLOBAL_NORM) {
    global_norm = 1.0f;
  }
  auto global_norm_reciprocal = 1.0f / global_norm;
  const auto beta1_minus = 1 - beta1;
  const auto beta2_minus = 1 - beta2;
@@ -76,11 +88,11 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vecto
  task = [&](size_t start, size_t end) {
    size_t i = FusedCastAdamFp16(reinterpret_cast<int16_t *>(var16), m, v, lr, beta1, beta2, epsilon, decay,
                                 reinterpret_cast<int16_t *>(gradient16), start, end);
                                 reinterpret_cast<int16_t *>(gradient16), global_norm_reciprocal, start, end);
    // remaining
    for (; i < end; i++) {
      auto temp_var = static_cast<float>(var16[i]);
      auto temp_grad = static_cast<float>(gradient16[i]);
      auto temp_grad = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
      m[i] += (temp_grad - m[i]) * beta1_minus;
      v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
      auto update = m[i] / (std::sqrt(v[i]) + epsilon);
@@ -89,7 +101,7 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vecto
      var16[i] = static_cast<float16>(temp_var);
    }
  };
  ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
  CPUKernelUtils::ParallelFor(task, lens, BATCH_SIZE);
}
void FusedCastAdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -123,10 +135,12 @@ void FusedCastAdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node)
void FusedCastAdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
                                                   const std::vector<kernel::AddressPtr> &outputs) const {
  if (inputs.size() != kFusedCastAdamWeightDecayInputNum) {
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
    MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs "
                      << kFusedCastAdamWeightDecayInputNum << " inputs.";
  }
  if (outputs.size() != kFusedCastAdamWeightDecayOutputNum) {
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
    MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs "
                      << kFusedCastAdamWeightDecayOutputNum << " outputs.";
  }
  size_t elem_size_fp32 = elem_num_ * kSizeFloat32;
  size_t elem_size_fp16 = elem_num_ * kSizeFloat16;
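Summary of the kernel change above: both launchers now read a tenth input, global_norm, clamp it to 1.0 when it falls below MIN_GLOBAL_NORM, and multiply every gradient element by its reciprocal before the AdamWeightDecay update. A minimal scalar sketch of that update follows; it takes fp32 gradients for brevity (the kernel casts each float16 gradient first), and the decay term sits between the hunks shown above, so its `update += decay * var[i]` form is an assumption based on the standard decoupled AdamWeightDecay update.

#include <cmath>
#include <cstddef>

// Minimal sketch of one fused update with global-norm scaling (not the kernel itself).
void FusedCastAdamSketch(float *var, float *m, float *v, const float *grad, size_t n, float lr, float beta1,
                         float beta2, float epsilon, float decay, float global_norm) {
  if (global_norm < 1e-10f) {  // MIN_GLOBAL_NORM guard: a vanishing norm means "do not scale"
    global_norm = 1.0f;
  }
  const float global_norm_reciprocal = 1.0f / global_norm;
  const float beta1_minus = 1 - beta1;
  const float beta2_minus = 1 - beta2;
  for (size_t i = 0; i < n; i++) {
    float g = grad[i] * global_norm_reciprocal;  // scale gradient by 1 / global_norm
    m[i] += (g - m[i]) * beta1_minus;            // m <- beta1 * m + (1 - beta1) * g
    v[i] += (g * g - v[i]) * beta2_minus;        // v <- beta2 * v + (1 - beta2) * g^2
    float update = m[i] / (std::sqrt(v[i]) + epsilon);
    update += decay * var[i];                    // decoupled weight decay (assumed; line not visible in the hunks)
    var[i] -= lr * update;
  }
}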
@@ -26,7 +26,7 @@ namespace kernel {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kScalarIndex = 0;
constexpr size_t kFusedCastAdamWeightDecayInputNum = 9;
constexpr size_t kFusedCastAdamWeightDecayInputNum = 10;
constexpr size_t kFusedCastAdamWeightDecayOutputNum = 3;
class FusedCastAdamWeightDecayCPUKernel : public CPUKernel {
@@ -45,7 +45,7 @@ class FusedCastAdamWeightDecayCPUKernel : public CPUKernel {
  size_t elem_num_{0};
  TypeId var_dtype_{kTypeUnknown};
  TypeId gradient_dtype_{kTypeUnknown};
  enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
  enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD, GLOBAL_NORM };
};
MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
@@ -59,6 +59,7 @@ MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
@@ -75,6 +76,7 @@ MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
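For reference, the input layout implied by enum input_list_ and the two registrations above. Only the tail of each KernelAttr is visible in these hunks, so the dtypes marked "assumed" are inferred from the scalar hyper-parameters and the first output attr rather than shown in the diff:

// Index  Name         Dtype (fp32-var kernel / fp16-var kernel)
//   0    VAR          float32 / float16  (assumed from the first output attr)
//   1    M            float32            (assumed)
//   2    V            float32            (assumed)
//   3    LR           float32            (assumed)
//   4    BETA1        float32            (assumed)
//   5    BETA2        float32            (assumed)
//   6    EPSILON      float32
//   7    DECAY        float32
//   8    GRAD         float16
//   9    GLOBAL_NORM  float32            (new input added by this change)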
@@ -207,7 +207,7 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
}
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                         const int16_t *gradient16, size_t start, size_t end) {
                         const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
@@ -217,6 +217,7 @@ size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1,
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  __m512 global_norm_reciprocal_r = _mm512_set1_ps(global_norm_reciprocal);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
  const int16_t *gradient16_ptr = gradient16 + start;
@@ -230,6 +231,7 @@ size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1,
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
    g_r = _mm512_mul_ps(g_r, global_norm_reciprocal_r);
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
@@ -253,7 +255,8 @@ size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1,
}
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                         float decay, const int16_t *gradient16, size_t start, size_t end) {
                         float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
                         size_t end) {
  size_t c1 = start;
#ifdef ENABLE_AVX512
  __m512 beta1_r = _mm512_set1_ps(beta1);
@@ -263,6 +266,7 @@ size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float bet
  __m512 lr_neg_r = _mm512_set1_ps(-lr);
  __m512 epsilon_r = _mm512_set1_ps(epsilon);
  __m512 decay_r = _mm512_set1_ps(decay);
  __m512 global_norm_reciprocal_r = _mm512_set1_ps(global_norm_reciprocal);
  size_t c16 = ((end - start) / C16NUM) * C16NUM + start;
  const int16_t *gradient16_ptr = gradient16 + start;
@@ -275,7 +279,7 @@ size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float bet
    __m512 m_r = _mm512_loadu_ps(m_ptr);
    __m512 v_r = _mm512_loadu_ps(v_ptr);
    __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
    g_r = _mm512_mul_ps(g_r, global_norm_reciprocal_r);
    m_r = _mm512_mul_ps(m_r, beta1_r);
    v_r = _mm512_mul_ps(v_r, beta2_r);
    __m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
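In the AVX-512 path the scaling costs one extra multiply per 16-element block: the reciprocal is broadcast once with _mm512_set1_ps and applied right after the fp16-to-fp32 conversion. A self-contained sketch of that step, assuming AVX-512F support as in the ENABLE_AVX512 path above; LoadScaledGradBlock is a hypothetical helper name, and the kernel hoists the broadcast out of the loop rather than redoing it per call:

#include <immintrin.h>
#include <stdint.h>

#ifdef __AVX512F__
// Load 16 fp16 gradients, widen them to fp32, and scale by 1 / global_norm,
// mirroring the g_r = _mm512_mul_ps(g_r, global_norm_reciprocal_r) step above.
static inline __m512 LoadScaledGradBlock(const int16_t *gradient16_ptr, float global_norm_reciprocal) {
  __m512 global_norm_reciprocal_r = _mm512_set1_ps(global_norm_reciprocal);
  __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)gradient16_ptr));
  return _mm512_mul_ps(g_r, global_norm_reciprocal_r);
}
#endif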
@@ -39,9 +39,10 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                        const float *gradient, size_t start, size_t end);
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
                         const int16_t *gradient16, size_t start, size_t end);
                         const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end);
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
                         float decay, const int16_t *gradient16, size_t start, size_t end);
                         float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
                         size_t end);
#ifdef __cplusplus
}
#endif
@@ -541,14 +541,14 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
        validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
                    epsilon_shape, decay_shape, grad_shape):
                    epsilon_shape, decay_shape, grad_shape, global_norm):
        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape, m_shape, v_shape

    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                    epsilon_dtype, decay_dtype, grad_dtype, global_norm):
        args = {"m": m_dtype, "v": v_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)