|
|
@@ -126,6 +126,15 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr |
|
|
inputs_.push_back(momentum); |
|
|
inputs_.push_back(momentum); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) { |
|
|
|
|
|
size_t lr_offset = 0; |
|
|
|
|
|
float *lr = values.data() + lr_offset; |
|
|
|
|
|
auto ret = memcpy_s(inputs_[2]->addr, sizeof(float), lr, sizeof(float)); |
|
|
|
|
|
if (ret != 0) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } |
|
|
const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } |
|
|
|
|
|
|
|
|
const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } |
|
|
const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } |
|
|
|