|
|
|
@@ -24,27 +24,62 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
template <typename T> |
|
|
|
void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient, |
|
|
|
size_t size) { |
|
|
|
std::function<void(size_t, size_t)> task; |
|
|
|
if (dtype_ == kNumberTypeFloat32) { |
|
|
|
task = [&](size_t start, size_t end) { |
|
|
|
AdamFp32(var, m, v, lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_); |
|
|
|
}; |
|
|
|
} else { |
|
|
|
task = [&](size_t start, size_t end) { |
|
|
|
for (size_t i = start; i < end; i++) { |
|
|
|
m[i] += (gradient[i] - m[i]) * (1 - beta1); |
|
|
|
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2); |
|
|
|
if (use_nesterov_) { |
|
|
|
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon); |
|
|
|
} else { |
|
|
|
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); |
|
|
|
} |
|
|
|
void AdamCPUKernel::LaunchAdam(const std::vector<kernel::AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
T *var = reinterpret_cast<T *>(inputs[0]->addr); |
|
|
|
T *m = reinterpret_cast<T *>(inputs[1]->addr); |
|
|
|
T *v = reinterpret_cast<T *>(inputs[2]->addr); |
|
|
|
float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0]; |
|
|
|
float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0]; |
|
|
|
float lr = reinterpret_cast<float *>(inputs[5]->addr)[0]; |
|
|
|
T beta1 = static_cast<T>(reinterpret_cast<float *>(inputs[6]->addr)[0]); |
|
|
|
T beta2 = static_cast<T>(reinterpret_cast<float *>(inputs[7]->addr)[0]); |
|
|
|
T epsilon = static_cast<T>(reinterpret_cast<float *>(inputs[8]->addr)[0]); |
|
|
|
T *gradient = reinterpret_cast<T *>(inputs[9]->addr); |
|
|
|
if (beta1_power - 1.0 == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1."; |
|
|
|
} |
|
|
|
T new_lr = static_cast<T>(lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power)); |
|
|
|
T one = static_cast<T>(1.0); |
|
|
|
// multithreading |
|
|
|
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(T)) : 1; |
|
|
|
auto task = [&](size_t start, size_t end) { |
|
|
|
for (size_t i = start; i < end; i++) { |
|
|
|
m[i] += (gradient[i] - m[i]) * (one - beta1); |
|
|
|
v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2); |
|
|
|
T sqrt_v = static_cast<T>(std::sqrt(static_cast<double>(v[i]))); |
|
|
|
if (use_nesterov_) { |
|
|
|
var[i] -= new_lr * (m[i] * beta1 + (one - beta1) * gradient[i]) / (sqrt_v + epsilon); |
|
|
|
} else { |
|
|
|
var[i] -= new_lr * m[i] / (sqrt_v + epsilon); |
|
|
|
} |
|
|
|
}; |
|
|
|
} |
|
|
|
}; |
|
|
|
CPUKernelUtils::ParallelFor(task, lens); |
|
|
|
} |
|
|
|
|
|
|
|
// float32 Adam step that forwards each thread's index range to AdamFp32.
//
// Input layout as read below: [0] var, [1] m, [2] v, [3] beta1_power, [4] beta2_power,
// [5] lr, [6] beta1, [7] beta2, [8] epsilon, [9] gradient. Indices 0-2 and 9 are
// interpreted as float buffers; indices 3-8 are read as a single float each.
// `outputs` is not used by this function.
//
// Reports through MS_LOG(EXCEPTION) when beta1_power equals 1, because the corrected
// learning rate divides by (1 - beta1_power).
void AdamCPUKernel::LaunchAdamNnacl(const std::vector<kernel::AddressPtr> &inputs,
                                    const std::vector<kernel::AddressPtr> &outputs) {
  float *var = reinterpret_cast<float *>(inputs[0]->addr);
  float *m = reinterpret_cast<float *>(inputs[1]->addr);
  float *v = reinterpret_cast<float *>(inputs[2]->addr);
  float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
  float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
  float lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
  float beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0];
  float beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
  float epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
  float *gradient = reinterpret_cast<float *>(inputs[9]->addr);
  if (beta1_power - 1.0 == 0) {
    MS_LOG(EXCEPTION) << "The beta1_power can't be set 1.";
  }
  // Learning rate scaled by sqrt(1 - beta2_power) / (1 - beta1_power), computed once up front.
  float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power);
  // multithreading
  // An input reporting size 0 is processed as a single element.
  // NOTE(review): confirm the buffer really holds one element in that case.
  size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
  auto task = [&](size_t start, size_t end) {
    AdamFp32(var, m, v, new_lr, beta1, beta2, epsilon, gradient, start, end, use_nesterov_);
  };
  CPUKernelUtils::ParallelFor(task, lens);
}
|
|
|
|
|
|
|
void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
@@ -77,23 +112,15 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const |
|
|
|
inputs[6]->size != f_size || inputs[7]->size != f_size || inputs[8]->size != f_size) { |
|
|
|
MS_LOG(EXCEPTION) << "The attribute beta_power, beta, lr and epsilon must be float!"; |
|
|
|
} |
|
|
|
auto var = reinterpret_cast<float *>(inputs[0]->addr); |
|
|
|
auto m = reinterpret_cast<float *>(inputs[1]->addr); |
|
|
|
auto v = reinterpret_cast<float *>(inputs[2]->addr); |
|
|
|
float beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0]; |
|
|
|
float beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0]; |
|
|
|
float lr = reinterpret_cast<float *>(inputs[5]->addr)[0]; |
|
|
|
float beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0]; |
|
|
|
float beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0]; |
|
|
|
float epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0]; |
|
|
|
auto gradient = reinterpret_cast<float *>(inputs[9]->addr); |
|
|
|
if (beta1_power == 1) { |
|
|
|
MS_LOG(EXCEPTION) << "The beta1_power can't be set 1."; |
|
|
|
|
|
|
|
if (dtype_ == kNumberTypeFloat32) { |
|
|
|
LaunchAdamNnacl(inputs, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat16) { |
|
|
|
LaunchAdam<float16>(inputs, outputs); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Adam not support " << dtype_; |
|
|
|
return false; |
|
|
|
} |
|
|
|
float new_lr = lr * std::sqrt(1.0 - beta2_power) / (1 - beta1_power); |
|
|
|
// multithreading |
|
|
|
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1; |
|
|
|
LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens); |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
|