
add global norm for fused adam

tags/v1.6.0
kswang 4 years ago
parent commit 427610d3a9
5 changed files with 40 additions and 19 deletions:
  1. +24 -10  mindspore/ccsrc/backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.cc
  2. +4  -2   mindspore/ccsrc/backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.h
  3. +7  -3   mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c
  4. +3  -2   mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.h
  5. +2  -2   mindspore/ops/operations/inner_ops.py

+24 -10  mindspore/ccsrc/backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.cc

@@ -22,9 +22,10 @@

namespace mindspore {
namespace kernel {
+ static constexpr size_t BATCH_SIZE = 10000;
+ static constexpr float MIN_GLOBAL_NORM = 1e-10;
void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
- auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
@@ -33,6 +34,12 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vecto
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
+ auto var = reinterpret_cast<float *>(inputs[VAR]->addr);
+ auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
+ if (global_norm < MIN_GLOBAL_NORM) {
+   global_norm = 1.0f;
+ }
+ auto global_norm_reciprocal = 1.0f / global_norm;
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;
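Aside: the guard added above keeps a near-zero norm from blowing up the scaled gradients. A minimal standalone sketch of the same clamp-and-invert logic (the helper name is illustrative, not from the commit):

```cpp
// A norm below the threshold is treated as 1.0f, so gradients pass
// through unscaled and the division below is always safe.
static constexpr float MIN_GLOBAL_NORM = 1e-10f;

inline float GlobalNormReciprocal(float global_norm) {  // illustrative helper
  if (global_norm < MIN_GLOBAL_NORM) {
    global_norm = 1.0f;
  }
  return 1.0f / global_norm;
}
```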

@@ -42,10 +49,10 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vecto

task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp32(var, m, v, lr, beta1, beta2, epsilon, decay, reinterpret_cast<int16_t *>(gradient16),
- start, end);
+ global_norm_reciprocal, start, end);
// remaining
for (; i < end; i++) {
- auto temp = static_cast<float>(gradient16[i]);
+ auto temp = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp - m[i]) * beta1_minus;
v[i] += (temp * temp - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
@@ -53,12 +60,11 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp32(const std::vecto
var[i] -= lr * update;
}
};
- ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
+ CPUKernelUtils::ParallelFor(task, lens, BATCH_SIZE);
}
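For readers following the math: the remainder loop above is the scalar form of the fused step. Here is a self-contained sketch of the same update (the function name and exact placement of the decoupled weight-decay term are illustrative):

```cpp
#include <cmath>
#include <cstddef>

// Scalar AdamWeightDecay step with clip-by-global-norm folded in: each
// gradient is rescaled by the precomputed reciprocal before the moment
// updates, so no separate clipping pass over the gradients is needed.
void AdamWeightDecayStep(float *var, float *m, float *v, const float *grad,
                         float lr, float beta1, float beta2, float epsilon,
                         float decay, float global_norm_reciprocal, size_t n) {
  const float beta1_minus = 1.0f - beta1;
  const float beta2_minus = 1.0f - beta2;
  for (size_t i = 0; i < n; ++i) {
    float g = grad[i] * global_norm_reciprocal;        // global-norm scaling
    m[i] += (g - m[i]) * beta1_minus;                  // first moment
    v[i] += (g * g - v[i]) * beta2_minus;              // second moment
    float update = m[i] / (std::sqrt(v[i]) + epsilon);
    update += decay * var[i];                          // decoupled weight decay
    var[i] -= lr * update;
  }
}
```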

void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &) {
- auto var16 = reinterpret_cast<float16 *>(inputs[VAR]->addr);
auto m = reinterpret_cast<float *>(inputs[M]->addr);
auto v = reinterpret_cast<float *>(inputs[V]->addr);
auto lr = reinterpret_cast<float *>(inputs[LR]->addr)[kScalarIndex];
@@ -67,6 +73,12 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vecto
auto epsilon = reinterpret_cast<float *>(inputs[EPSILON]->addr)[kScalarIndex];
auto decay = reinterpret_cast<float *>(inputs[DECAY]->addr)[kScalarIndex];
auto gradient16 = reinterpret_cast<float16 *>(inputs[GRAD]->addr);
+ auto var16 = reinterpret_cast<float16 *>(inputs[VAR]->addr);
+ auto global_norm = reinterpret_cast<float *>(inputs[GLOBAL_NORM]->addr)[kScalarIndex];
+ if (global_norm < MIN_GLOBAL_NORM) {
+   global_norm = 1.0f;
+ }
+ auto global_norm_reciprocal = 1.0f / global_norm;
const auto beta1_minus = 1 - beta1;
const auto beta2_minus = 1 - beta2;

@@ -76,11 +88,11 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vecto

task = [&](size_t start, size_t end) {
size_t i = FusedCastAdamFp16(reinterpret_cast<int16_t *>(var16), m, v, lr, beta1, beta2, epsilon, decay,
- reinterpret_cast<int16_t *>(gradient16), start, end);
+ reinterpret_cast<int16_t *>(gradient16), global_norm_reciprocal, start, end);
// remaining
for (; i < end; i++) {
auto temp_var = static_cast<float>(var16[i]);
- auto temp_grad = static_cast<float>(gradient16[i]);
+ auto temp_grad = static_cast<float>(gradient16[i]) * global_norm_reciprocal;
m[i] += (temp_grad - m[i]) * beta1_minus;
v[i] += (temp_grad * temp_grad - v[i]) * beta2_minus;
auto update = m[i] / (std::sqrt(v[i]) + epsilon);
@@ -89,7 +101,7 @@ void FusedCastAdamWeightDecayCPUKernel::LaunchFusedCastAdamFp16(const std::vecto
var16[i] = static_cast<float16>(temp_var);
}
};
- ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
+ CPUKernelUtils::ParallelFor(task, lens, BATCH_SIZE);
}

void FusedCastAdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -123,10 +135,12 @@ void FusedCastAdamWeightDecayCPUKernel::InitKernel(const CNodePtr &kernel_node)
void FusedCastAdamWeightDecayCPUKernel::CheckParam(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) const {
if (inputs.size() != kFusedCastAdamWeightDecayInputNum) {
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs 9 inputs.";
MS_LOG(EXCEPTION) << "Input number is " << inputs.size() << ", but AdamWeightDecay needs "
<< kFusedCastAdamWeightDecayInputNum << " inputs.";
}
if (outputs.size() != kFusedCastAdamWeightDecayOutputNum) {
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs 3 outputs.";
MS_LOG(EXCEPTION) << "Output number is " << outputs.size() << ", but AdamWeightDecay needs "
<< kFusedCastAdamWeightDecayOutputNum << " outputs.";
}
size_t elem_size_fp32 = elem_num_ * kSizeFloat32;
size_t elem_size_fp16 = elem_num_ * kSizeFloat16;
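The reworked messages above derive the expected counts from the same constants used in the comparison, so the text can no longer drift from the check (the hardcoded "9 inputs" would have been wrong after this commit). A sketch of the pattern outside MS_LOG (names are illustrative):

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

constexpr size_t kExpectedInputNum = 10;  // stands in for kFusedCastAdamWeightDecayInputNum

// Building the message from the constant keeps it in sync with the check.
void CheckInputCount(size_t actual) {
  if (actual != kExpectedInputNum) {
    throw std::runtime_error("Input number is " + std::to_string(actual) +
                             ", but the kernel needs " +
                             std::to_string(kExpectedInputNum) + " inputs.");
  }
}
```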


+4 -2  mindspore/ccsrc/backend/kernel_compiler/cpu/fused_cast_adam_weight_decay_cpu_kernel.h

@@ -26,7 +26,7 @@ namespace kernel {
constexpr size_t kSizeFloat32 = sizeof(float);
constexpr size_t kSizeFloat16 = sizeof(float16);
constexpr size_t kScalarIndex = 0;
- constexpr size_t kFusedCastAdamWeightDecayInputNum = 9;
+ constexpr size_t kFusedCastAdamWeightDecayInputNum = 10;
constexpr size_t kFusedCastAdamWeightDecayOutputNum = 3;

class FusedCastAdamWeightDecayCPUKernel : public CPUKernel {
@@ -45,7 +45,7 @@ class FusedCastAdamWeightDecayCPUKernel : public CPUKernel {
size_t elem_num_{0};
TypeId var_dtype_{kTypeUnknown};
TypeId gradient_dtype_{kTypeUnknown};
- enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
+ enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD, GLOBAL_NORM };
};

MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
@@ -59,6 +59,7 @@ MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
+ .AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
@@ -75,6 +76,7 @@ MS_REG_CPU_KERNEL(FusedCastAdamWeightDecay,
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat16)
+ .AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
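Note how the edits in this header stay consistent: appending GLOBAL_NORM to input_list_ assigns it index 9, kFusedCastAdamWeightDecayInputNum grows to 10 to match, and each registration above gains a tenth float32 input. A compile-time check of that relationship (standalone sketch, names illustrative):

```cpp
// Enumerators take consecutive values from 0, so the new last entry
// indexes the new tenth input.
enum InputList { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD, GLOBAL_NORM };

static_assert(GLOBAL_NORM == 9, "GLOBAL_NORM is read from inputs[9]");
static_assert(GLOBAL_NORM + 1 == 10, "matches kFusedCastAdamWeightDecayInputNum");
```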


+7 -3  mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c

@@ -207,7 +207,7 @@ int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, f
}

size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const int16_t *gradient16, size_t start, size_t end) {
const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX512
__m512 beta1_r = _mm512_set1_ps(beta1);
@@ -217,6 +217,7 @@ size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1,
__m512 lr_neg_r = _mm512_set1_ps(-lr);
__m512 epsilon_r = _mm512_set1_ps(epsilon);
__m512 decay_r = _mm512_set1_ps(decay);
+ __m512 global_norm_reciprocal_r = _mm512_set1_ps(global_norm_reciprocal);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

const int16_t *gradient16_ptr = gradient16 + start;
@@ -230,6 +231,7 @@ size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1,
__m512 v_r = _mm512_loadu_ps(v_ptr);
__m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));

+ g_r = _mm512_mul_ps(g_r, global_norm_reciprocal_r);
m_r = _mm512_mul_ps(m_r, beta1_r);
v_r = _mm512_mul_ps(v_r, beta2_r);
__m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
@@ -253,7 +255,8 @@ size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1,
}

size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
- float decay, const int16_t *gradient16, size_t start, size_t end) {
+ float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
+ size_t end) {
size_t c1 = start;
#ifdef ENABLE_AVX512
__m512 beta1_r = _mm512_set1_ps(beta1);
@@ -263,6 +266,7 @@ size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float bet
__m512 lr_neg_r = _mm512_set1_ps(-lr);
__m512 epsilon_r = _mm512_set1_ps(epsilon);
__m512 decay_r = _mm512_set1_ps(decay);
+ __m512 global_norm_reciprocal_r = _mm512_set1_ps(global_norm_reciprocal);
size_t c16 = ((end - start) / C16NUM) * C16NUM + start;

const int16_t *gradient16_ptr = gradient16 + start;
@@ -275,7 +279,7 @@ size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float bet
__m512 m_r = _mm512_loadu_ps(m_ptr);
__m512 v_r = _mm512_loadu_ps(v_ptr);
__m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(gradient16_ptr)));
+ g_r = _mm512_mul_ps(g_r, global_norm_reciprocal_r);
m_r = _mm512_mul_ps(m_r, beta1_r);
v_r = _mm512_mul_ps(v_r, beta2_r);
__m512 avx_r0 = _mm512_mul_ps(g_r, g_r);
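The vector path mirrors the scalar one: the reciprocal is broadcast once per call with _mm512_set1_ps, and each 16-wide block of just-converted gradients is multiplied through before the moment updates. A minimal sketch of that step in isolation (guard macro and helper name assumed for illustration):

```cpp
#include <immintrin.h>
#include <cstdint>

#ifdef __AVX512F__
// Load 16 fp16 gradients, widen them to fp32, and apply the global-norm scale.
static inline __m512 LoadScaledGradBlock(const int16_t *gradient16_ptr,
                                         __m512 global_norm_reciprocal_r) {
  __m512 g_r = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)gradient16_ptr));
  return _mm512_mul_ps(g_r, global_norm_reciprocal_r);
}
#endif
```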


+3 -2  mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.h

@@ -39,9 +39,10 @@ int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float
int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
const float *gradient, size_t start, size_t end);
size_t FusedCastAdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay,
- const int16_t *gradient16, size_t start, size_t end);
+ const int16_t *gradient16, float global_norm_reciprocal, size_t start, size_t end);
size_t FusedCastAdamFp16(int16_t *var16, float *m, float *v, float lr, float beta1, float beta2, float epsilon,
- float decay, const int16_t *gradient16, size_t start, size_t end);
+ float decay, const int16_t *gradient16, float global_norm_reciprocal, size_t start,
+ size_t end);
#ifdef __cplusplus
}
#endif


+2 -2  mindspore/ops/operations/inner_ops.py

@@ -541,14 +541,14 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer):
validator.check_value_type("use_locking", use_locking, [bool], self.name)

def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
- epsilon_shape, decay_shape, grad_shape):
+ epsilon_shape, decay_shape, grad_shape, global_norm):
validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
return var_shape, m_shape, v_shape

def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
- epsilon_dtype, decay_dtype, grad_dtype):
+ epsilon_dtype, decay_dtype, grad_dtype, global_norm):
args = {"m": m_dtype, "v": v_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
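The Python primitive now threads the extra tenth argument through shape and dtype inference. The scalar it carries is, by the usual clip-by-global-norm convention, the L2 norm over all gradient tensors; a sketch of how such a value could be computed (this helper is an assumption for illustration, not part of the commit):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// L2 norm across every element of every gradient tensor: the quantity the
// kernel divides gradients by (after the small-norm clamp shown earlier).
float GlobalNorm(const std::vector<std::vector<float>> &grads) {
  float sum_sq = 0.0f;
  for (const auto &g : grads) {
    for (float x : g) {
      sum_sq += x * x;
    }
  }
  return std::sqrt(sum_sq);
}
```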

