|
|
|
@@ -84,6 +84,7 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs) |
|
|
|
|
|
|
|
if (batch_size == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Error occur in launch kernel"; |
|
|
|
return; |
|
|
|
} |
|
|
|
while (start < length) { |
|
|
|
size_t end = (start + batch_size) > length ? length : (start + batch_size); |
|
|
|
@@ -98,7 +99,8 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs) |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end) { |
|
|
|
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start, |
|
|
|
size_t end) { |
|
|
|
// DataType can only be float32 or float16, so eps will not be zero. |
|
|
|
using DataType = typename std::iterator_traits<T>::value_type; |
|
|
|
const DataType one = DataType(1); |
|
|
|
|