
!9934 Supports embedding table with broadcast slicer

From: @zpac
Reviewed-by: @limingqi107,@kisnwang
Signed-off-by: @kisnwang
tags/v1.1.0
mindspore-ci-bot (Gitee) committed 5 years ago
commit d5a2489acc
5 changed files with 39 additions and 30 deletions
  1. mindspore/ccsrc/ps/optimizer_info.cc  (+17, -13)
  2. mindspore/ccsrc/ps/optimizer_info.h  (+3, -2)
  3. mindspore/ccsrc/ps/optimizer_info_builder.cc  (+11, -7)
  4. mindspore/ccsrc/ps/optimizer_info_builder.h  (+6, -6)
  5. mindspore/ccsrc/ps/parameter_server.h  (+2, -2)
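
Summary of the change, with an illustrative sketch: before this commit, SparseOptimInfo::ComputeMean always subtracted the rank-local shard offset from the accumulated gradient's indices, i.e. it assumed the embedding table was row-sharded across the parameter servers. The commit threads a new `sharded` flag from ParameterServer<T>::AccumGrad (where the value passed is `is_embedding_[key]`) through OptimizerInfoBuilder::Build and BuildInputs into the SparseAdam/SparseFtrl optimizer-info constructors, and ComputeMean now performs the offset adjustment only when `sharded_` is true, so a table that is replicated on every server (e.g. placed by a broadcast slicer) presumably keeps its global indices. The standalone C++ sketch below mirrors that gated adjustment; it is not MindSpore code, and `AdjustIndices` plus the two-argument `AllRankLocalShard` stand-in are hypothetical simplifications (the real Util::AllRankLocalShard has a different signature).

// Minimal standalone sketch (not MindSpore code) of the index adjustment gated by the
// new sharded_ flag. AllRankLocalShard below is a hypothetical stand-in.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Rows held by each server rank under an even row-wise split (illustrative policy).
std::map<int64_t, int64_t> AllRankLocalShard(size_t total_rows, size_t server_num) {
  std::map<int64_t, int64_t> dims;
  for (size_t i = 0; i < server_num; ++i) {
    dims[i] = total_rows / server_num + (i < total_rows % server_num ? 1 : 0);
  }
  return dims;
}

// Mirrors the gated logic in SparseOptimInfo::ComputeMean: only a sharded table needs
// its global row indices shifted down to this server's local range.
void AdjustIndices(std::vector<int> *indices, size_t total_rows, size_t rank_id,
                   size_t server_num, bool sharded) {
  if (!sharded || total_rows == 0) {
    return;  // Replicated (broadcast) table: global indices are already valid locally.
  }
  std::map<int64_t, int64_t> rank_dims = AllRankLocalShard(total_rows, server_num);
  size_t offset = 0;
  for (size_t i = 0; i < rank_id; ++i) {
    offset += rank_dims[i];  // Rows owned by lower-ranked servers.
  }
  for (int &idx : *indices) {
    idx -= static_cast<int>(offset);  // Global row id -> local row id on this server.
  }
}

int main() {
  std::vector<int> sharded_case{10, 11, 12};
  AdjustIndices(&sharded_case, /*total_rows=*/20, /*rank_id=*/1, /*server_num=*/2, /*sharded=*/true);
  std::vector<int> broadcast_case{10, 11, 12};
  AdjustIndices(&broadcast_case, /*total_rows=*/20, /*rank_id=*/1, /*server_num=*/2, /*sharded=*/false);
  std::cout << sharded_case[0] << ' ' << broadcast_case[0] << '\n';  // prints "0 10"
  return 0;
}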

mindspore/ccsrc/ps/optimizer_info.cc  (+17, -13)

@@ -197,18 +197,20 @@ void SparseOptimInfo::ComputeMean(const std::vector<std::vector<size_t>> &shapes
   float *grad_data = reinterpret_cast<float *>(gradient()->addr);
   int *indices_data = reinterpret_cast<int *>(indices()->addr);
 
-  size_t original_row_count = input_shapes.front();
-  if (original_row_count > 0) {
-    size_t offset = 0;
-    std::map<int64_t, int64_t> rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num);
-    for (size_t i = 0; i < rank_id; i++) {
-      if (rank_dims.count(i) == 0) {
-        MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
-      }
-      offset += rank_dims[i];
-    }
-    for (size_t i = 0; i < indices_size; i++) {
-      indices_data[i] -= offset;
-    }
+  if (sharded_) {
+    size_t original_row_count = input_shapes.front();
+    if (original_row_count > 0) {
+      size_t offset = 0;
+      std::map<int64_t, int64_t> rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num);
+      for (size_t i = 0; i < rank_id; i++) {
+        if (rank_dims.count(i) == 0) {
+          MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
+        }
+        offset += rank_dims[i];
+      }
+      for (size_t i = 0; i < indices_size; i++) {
+        indices_data[i] -= offset;
+      }
+    }
   }

@@ -283,7 +285,7 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
                                          const AddressPtr &beta1_power, const AddressPtr &beta2_power,
                                          const AddressPtr &learning_rate, const AddressPtr &beta1,
                                          const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
-                                         const AddressPtr &indices) {
+                                         const AddressPtr &indices, bool sharded) {
   inputs_.push_back(weight);
   inputs_.push_back(m);
   inputs_.push_back(v);
@@ -297,6 +299,7 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(indices);
   grads_offset_ = grad->size / sizeof(float);
   indices_offset_ = indices->size / sizeof(int);
+  sharded_ = sharded;
 }
 
 void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
@@ -333,7 +336,7 @@ size_t SparseAdamOptimInfo::indices_index() {
 }
 
 SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
-                                         const AddressPtr &grad, const AddressPtr &indices) {
+                                         const AddressPtr &grad, const AddressPtr &indices, bool sharded) {
   inputs_.push_back(weight);
   inputs_.push_back(accum);
   inputs_.push_back(linear);
@@ -341,6 +344,7 @@ SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(indices);
   grads_offset_ = grad->size / sizeof(float);
   indices_offset_ = indices->size / sizeof(int);
+  sharded_ = sharded;
 }
 
 const AddressPtr &SparseFtrlOptimInfo::gradient() {


mindspore/ccsrc/ps/optimizer_info.h  (+3, -2)

@@ -82,6 +82,7 @@ class SparseOptimInfo : public OptimizerInfo {
  protected:
   size_t grads_offset_{0};
   size_t indices_offset_{0};
+  bool sharded_{true};
 };
 
 class MomentumOptimInfo : public DenseOptimInfo {
@@ -101,7 +102,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
   SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power,
                       const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1,
                       const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
-                      const AddressPtr &indices);
+                      const AddressPtr &indices, bool sharded);
   ~SparseAdamOptimInfo() override = default;
 
   void Update(const Values &values, const Lengths &lens) override;
@@ -115,7 +116,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
 class SparseFtrlOptimInfo : public SparseOptimInfo {
  public:
   SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
-                      const AddressPtr &grad, const AddressPtr &indices);
+                      const AddressPtr &grad, const AddressPtr &indices, bool sharded);
   ~SparseFtrlOptimInfo() override = default;
 
   const AddressPtr &gradient();


mindspore/ccsrc/ps/optimizer_info_builder.cc  (+11, -7)

@@ -25,10 +25,12 @@ namespace ps {
 using mindspore::kernel::ps::SparseApplyFtrlPSKernel;
 OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr<PServerKernel> &pserver_kernel,
                                            const WeightPtr &weight, const Keys &keys, const Values &values,
-                                           const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) {
+                                           const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
+                                           bool sharded) {
   MS_EXCEPTION_IF_NULL(pserver_kernel);
   MS_EXCEPTION_IF_NULL(inputs_shape);
-  OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel);
+  OptimizerInfo *optim_info =
+    BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel, sharded);
   MS_EXCEPTION_IF_NULL(optim_info);
   std::vector<size_t> ws_sizes = pserver_kernel->workspace_sizes();
   BuildWorkspaces(optim_info, ws_sizes, worker_num);
@@ -108,7 +110,7 @@ AddressPtr OptimizerInfoBuilder::GenInputAddrPtr(const std::string &optim_type,
 
 OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                      const Lengths &lens, const InputsShapePtr &inputs_shape,
-                                                     size_t worker_num, const std::shared_ptr<PServerKernel> &) {
+                                                     size_t worker_num, const std::shared_ptr<PServerKernel> &, bool) {
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   MS_EXCEPTION_IF_NULL(weight_addr);
   weight_addr->addr = weight->data();
@@ -135,7 +137,8 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
 
 OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                        const Lengths &lens, const InputsShapePtr &inputs_shape,
-                                                       size_t worker_num, const std::shared_ptr<PServerKernel> &) {
+                                                       size_t worker_num, const std::shared_ptr<PServerKernel> &,
+                                                       bool sharded) {
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   MS_EXCEPTION_IF_NULL(weight_addr);
   weight_addr->addr = weight->data();
@@ -178,13 +181,14 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", values.data(), lens, inputs_shape);
   AddressPtr indices = GenInputAddrPtr<float>(kSparseAdam, "indices", values.data(), lens, inputs_shape);
   return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
-                                 grad, indices);
+                                 grad, indices, sharded);
 }
 
 OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                        const Lengths &lens, const InputsShapePtr &inputs_shape,
                                                        size_t worker_num,
-                                                       const std::shared_ptr<PServerKernel> &pserver_kernel) {
+                                                       const std::shared_ptr<PServerKernel> &pserver_kernel,
+                                                       bool sharded) {
   MS_EXCEPTION_IF_NULL(inputs_shape);
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   MS_EXCEPTION_IF_NULL(weight_addr);
@@ -216,7 +220,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
 
   AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", values.data(), lens, inputs_shape);
   AddressPtr indices = GenInputAddrPtr<float>(kSparseFtrl, "indices", values.data(), lens, inputs_shape);
-  return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices);
+  return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded);
 }
 } // namespace ps
 } // namespace mindspore

mindspore/ccsrc/ps/optimizer_info_builder.h  (+6, -6)

@@ -34,12 +34,12 @@ class OptimizerInfoBuilder {
   virtual ~OptimizerInfoBuilder() = default;
 
   OptimizerInfo *Build(const std::shared_ptr<PServerKernel> &pserver_kernel, const WeightPtr &weight, const Keys &keys,
-                       const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape,
-                       size_t worker_num);
+                       const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
+                       bool sharded);
 
   virtual OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                      const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
-                                     const std::shared_ptr<PServerKernel> &pserver_kernel) = 0;
+                                     const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) = 0;
 
   virtual void BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes, size_t worker_num);
   virtual void BuildOutputs(OptimizerInfo *info, size_t worker_num) {}
@@ -57,7 +57,7 @@ class MomentumOptimInfoBuilder : public OptimizerInfoBuilder {
   ~MomentumOptimInfoBuilder() = default;
   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
                              const InputsShapePtr &inputs_shape, size_t worker_num,
-                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
+                             const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override;
 };
 
 class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder {
@@ -66,7 +66,7 @@ class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder {
   ~SparseAdamOptimInfoBuilder() = default;
   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
                              const InputsShapePtr &inputs_shape, size_t worker_num,
-                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
+                             const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override;
 };
 
 class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder {
@@ -75,7 +75,7 @@ class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder {
   ~SparseFtrlOptimInfoBuilder() = default;
   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
                              const InputsShapePtr &inputs_shape, size_t worker_num,
-                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
+                             const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override;
 };
 } // namespace ps
 } // namespace mindspore
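
As a side note on the builder plumbing above: the virtual BuildInputs hook gains a trailing bool, the dense MomentumOptimInfoBuilder leaves the parameter unnamed because it has no use for it, and the sparse builders forward it into the optimizer-info object. A compressed, illustrative sketch of that pattern follows; the *Sketch types are hypothetical stand-ins, not MindSpore classes.

// Illustrative only: same override pattern as OptimizerInfoBuilder::BuildInputs above,
// reduced to the single new parameter.
#include <iostream>
#include <memory>

struct OptimInfoSketch {
  bool sharded = true;  // mirrors SparseOptimInfo::sharded_ and its default
};

struct BuilderSketch {
  virtual ~BuilderSketch() = default;
  virtual OptimInfoSketch BuildInputs(bool sharded) = 0;
};

struct MomentumBuilderSketch : public BuilderSketch {
  // Dense optimizer: the flag is irrelevant, so the parameter stays unnamed.
  OptimInfoSketch BuildInputs(bool) override { return OptimInfoSketch(); }
};

struct SparseAdamBuilderSketch : public BuilderSketch {
  // Sparse optimizer: forward the flag, as the real BuildInputs forwards it to the
  // SparseAdamOptimInfo constructor.
  OptimInfoSketch BuildInputs(bool sharded) override {
    OptimInfoSketch info;
    info.sharded = sharded;
    return info;
  }
};

int main() {
  std::unique_ptr<BuilderSketch> builder = std::make_unique<SparseAdamBuilderSketch>();
  std::cout << std::boolalpha << builder->BuildInputs(false).sharded << std::endl;  // prints "false"
  return 0;
}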


mindspore/ccsrc/ps/parameter_server.h  (+2, -2)

@@ -619,8 +619,8 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
       MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
     }
     MS_EXCEPTION_IF_NULL(pserver_kernel);
-    OptimizerInfo *optim =
-      builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_);
+    OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
+                                          optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
     optim_info.reset(optim);
     optim_infos_[key] = optim_info;
   } else {

