Browse Source

!4151 Fix ps training precision error

Merge pull request !4151 from ZPaC/master-fix-ps-training-precision-error
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
9efbfb8af1
2 changed files with 10 additions and 0 deletions
  1. +9
    -0
      mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
  2. +1
    -0
      mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h

+ 9
- 0
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc View File

@@ -126,6 +126,15 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr
inputs_.push_back(momentum);
}

void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) {
size_t lr_offset = 0;
float *lr = values.data() + lr_offset;
auto ret = memcpy_s(inputs_[2]->addr, sizeof(float), lr, sizeof(float));
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
}
}

const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; }

const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; }


+ 1
- 0
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h View File

@@ -82,6 +82,7 @@ class MomentumOptimInfo : public DenseOptimInfo {
const AddressPtr &gradient, const AddressPtr &momentum);
~MomentumOptimInfo() override = default;

void Update(const Values &values, const Lengths &lens) override;
const AddressPtr &gradient();
const AddressPtr &indices();
size_t grad_index() override;


Loading…
Cancel
Save