Browse Source

add fp16 support for CPU Adam

tags/v1.3.0
zhaoting 4 years ago
parent
commit
e938898336
3 changed files with 71 additions and 39 deletions
  1. +62
    -35
      mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc
  2. +4
    -4
      mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h
  3. +5
    -0
      mindspore/ops/_op_impl/cpu/adam.py

+ 62
- 35
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc View File

@@ -24,27 +24,62 @@
namespace mindspore {
namespace kernel {
template <typename T>
void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
size_t size) {
std::function<void(size_t, size_t)> task;
if (dtype_ == kNumberTypeFloat32) {
task = [&](size_t start, size_t end) {
AdamFp32(var, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
};
} else {
task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * (1 - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
if (use_nesterov_) {
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
} else {
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
}
void AdamCPUKernel::LaunchAdam(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
T *var = reinterpret_cast<T *>(inputs[0]->addr);
T *m = reinterpret_cast<T *>(inputs[1]->addr);
T *v = reinterpret_cast<T *>(inputs[2]->addr);
float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
float lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
T beta1 = static_cast<T>(reinterpret_cast<float *>(inputs[6]->addr)[0]);
T beta2 = static_cast<T>(reinterpret_cast<float *>(inputs[7]->addr)[0]);
T epsilon = static_cast<T>(reinterpret_cast<float *>(inputs[8]->addr)[0]);
T *gradient = reinterpret_cast<T *>(inputs[9]->addr);
if (beta1_power - 1.0 == 0) {
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
}
T new_lr = static_cast<T>(lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power));
T one = static_cast<T>(1.0);
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(T)) : 1;
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * (one - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2);
T sqrt_v = static_cast<T>(std::sqrt(static_cast<double>(v[i])));
if (use_nesterov_) {
var[i] -= new_lr * (m[i] * beta1 + (one - beta1) * gradient[i]) / (sqrt_v + epsilon);
} else {
var[i] -= new_lr * m[i] / (sqrt_v + epsilon);
}
};
}
};
CPUKernelUtils::ParallelFor(task, lens);
}

void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
float *var = reinterpret_cast<float *>(inputs[0]->addr);
float *m = reinterpret_cast<float *>(inputs[1]->addr);
float *v = reinterpret_cast<float *>(inputs[2]->addr);
float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
float lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
float beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0];
float beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
float epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
float *gradient = reinterpret_cast<float *>(inputs[9]->addr);
if (beta1_power - 1.0 == 0) {
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
}
CPUKernelUtils::ParallelFor(task, size);
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
auto task = [&](size_t start, size_t end) {
AdamFp32(var, m, v, new_lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
};
CPUKernelUtils::ParallelFor(task, lens);
}

void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -77,23 +112,15 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const
inputs[6]->size != f_size || inputs[7]->size != f_size || inputs[8]->size != f_size) {
MS_LOG(EXCEPTION) << "The attribute beta_power, beta, lr and epsilon must be float!";
}
auto var = reinterpret_cast<float *>(inputs[0]->addr);
auto m = reinterpret_cast<float *>(inputs[1]->addr);
auto v = reinterpret_cast<float *>(inputs[2]->addr);
float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
float lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
float beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0];
float beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
float epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
auto gradient = reinterpret_cast<float *>(inputs[9]->addr);
if (beta1_power == 1) {
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";

if (dtype_ == kNumberTypeFloat32) {
LaunchAdamNnacl(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchAdam<float16>(inputs, outputs);
} else {
MS_LOG(EXCEPTION) << "Adam not support " << dtype_;
return false;
}
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens);
return true;
}
} // namespace kernel


+ 4
- 4
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h View File

@@ -28,10 +28,10 @@ class AdamCPUKernel : public CPUKernel {
AdamCPUKernel() = default;
~AdamCPUKernel() override = default;
template <typename T>
void LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
size_t size);
void InitKernel(const CNodePtr &kernel_node) override;
void LaunchAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

void LaunchAdamNnacl(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

@@ -40,7 +40,7 @@ class AdamCPUKernel : public CPUKernel {
TypeId dtype_{kTypeUnknown};
};

MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCPUKernel)
MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCPUKernel);
} // namespace kernel
} // namespace mindspore



+ 5
- 0
mindspore/ops/_op_impl/cpu/adam.py View File

@@ -35,6 +35,11 @@ adam_op_info = CpuRegOp("Adam") \
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default) \
.get_op_info()




Loading…
Cancel
Save