Browse Source

!11272 Fix Apply Adagrad for cpu CodeDex.

From: @yang_chun
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @c_34
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
5e5489d59f
2 changed files with 4 additions and 2 deletions
  1. +3
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.cc
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.h

+ 3
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.cc View File

@@ -84,6 +84,7 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)


if (batch_size == 0) { if (batch_size == 0) {
MS_LOG(EXCEPTION) << "Error occur in launch kernel"; MS_LOG(EXCEPTION) << "Error occur in launch kernel";
return;
} }
while (start < length) { while (start < length) {
size_t end = (start + batch_size) > length ? length : (start + batch_size); size_t end = (start + batch_size) > length ? length : (start + batch_size);
@@ -98,7 +99,8 @@ void ApplyAdagradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs)
} }


template <typename T> template <typename T>
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end) {
void ApplyAdagradCPUKernel::LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start,
size_t end) {
// DataType can only be float32 or float16, so eps will not be zero. // DataType can only be float32 or float16, so eps will not be zero.
using DataType = typename std::iterator_traits<T>::value_type; using DataType = typename std::iterator_traits<T>::value_type;
const DataType one = DataType(1); const DataType one = DataType(1);


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/apply_adagrad_cpu_kernel.h View File

@@ -38,7 +38,7 @@ class ApplyAdagradCPUKernel : public CPUKernel {
template <typename T> template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs); void LaunchKernel(const std::vector<AddressPtr> &inputs);
template <typename T> template <typename T>
void LaunchApplyAdagrad(T var, T accum, T lr, T gradient, size_t start, size_t end);
void LaunchApplyAdagrad(T const var, T const accum, const T lr, const T gradient, size_t start, size_t end);
bool update_slots_{true}; bool update_slots_{true};
TypeId dtype_{kTypeUnknown}; TypeId dtype_{kTypeUnknown};
}; };


Loading…
Cancel
Save