From: @zpac
Reviewed-by: @limingqi107, @kisnwang
Signed-off-by: @kisnwang
tags/v1.1.0
@@ -197,18 +197,20 @@ void SparseOptimInfo::ComputeMean(const std::vector<std::vector<size_t>> &shapes
   float *grad_data = reinterpret_cast<float *>(gradient()->addr);
   int *indices_data = reinterpret_cast<int *>(indices()->addr);
-  size_t original_row_count = input_shapes.front();
-  if (original_row_count > 0) {
-    size_t offset = 0;
-    std::map<int64_t, int64_t> rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num);
-    for (size_t i = 0; i < rank_id; i++) {
-      if (rank_dims.count(i) == 0) {
-        MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
-      }
-      offset += rank_dims[i];
-    }
-    for (size_t i = 0; i < indices_size; i++) {
-      indices_data[i] -= offset;
-    }
-  }
+  if (sharded_) {
+    size_t original_row_count = input_shapes.front();
+    if (original_row_count > 0) {
+      size_t offset = 0;
+      std::map<int64_t, int64_t> rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num);
+      for (size_t i = 0; i < rank_id; i++) {
+        if (rank_dims.count(i) == 0) {
+          MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
+        }
+        offset += rank_dims[i];
+      }
+      for (size_t i = 0; i < indices_size; i++) {
+        indices_data[i] -= offset;
+      }
+    }
+  }
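The effect of the hunk above: the shift from global to server-local row indices now only happens when the weight is actually sharded across servers; non-sharded weights keep their global indices. Below is a minimal, self-contained C++ sketch of that gating logic — it is not the MindSpore source. The standalone function name LocalizeIndices and the even-split stub for AllRankLocalShard are assumptions made for illustration; only the offset computation mirrors the diff.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Stub standing in for Util::AllRankLocalShard: split row_count rows as evenly
// as possible across server_num servers (assumed policy, for illustration only).
std::map<int64_t, int64_t> AllRankLocalShard(size_t row_count, size_t /*rank_id*/, size_t server_num) {
  std::map<int64_t, int64_t> dims;
  for (size_t i = 0; i < server_num; ++i) {
    dims[static_cast<int64_t>(i)] =
      static_cast<int64_t>(row_count / server_num + (i < row_count % server_num ? 1 : 0));
  }
  return dims;
}

// Shift global row indices into this server's local range, mirroring the branch
// the diff puts behind `if (sharded_)`. With sharded == false, indices are untouched.
void LocalizeIndices(std::vector<int> *indices, size_t row_count, size_t rank_id,
                     size_t server_num, bool sharded) {
  if (!sharded || row_count == 0) {
    return;
  }
  size_t offset = 0;
  std::map<int64_t, int64_t> rank_dims = AllRankLocalShard(row_count, rank_id, server_num);
  for (size_t i = 0; i < rank_id; ++i) {
    offset += static_cast<size_t>(rank_dims[static_cast<int64_t>(i)]);
  }
  for (auto &idx : *indices) {
    idx -= static_cast<int>(offset);
  }
}

int main() {
  std::vector<int> indices = {10, 11, 12};
  // 30 rows over 3 servers -> 10 rows per rank; rank 1 owns global rows 10..19.
  LocalizeIndices(&indices, 30, 1, 3, true);
  for (int idx : indices) {
    std::cout << idx << ' ';  // prints: 0 1 2
  }
  std::cout << std::endl;
  return 0;
}

Whether a given weight counts as sharded is decided by the caller; the last hunk below passes is_embedding_[key] for that argument.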
@@ -283,7 +285,7 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
                                          const AddressPtr &beta1_power, const AddressPtr &beta2_power,
                                          const AddressPtr &learning_rate, const AddressPtr &beta1,
                                          const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
-                                         const AddressPtr &indices) {
+                                         const AddressPtr &indices, bool sharded) {
   inputs_.push_back(weight);
   inputs_.push_back(m);
   inputs_.push_back(v);
@@ -297,6 +299,7 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(indices);
   grads_offset_ = grad->size / sizeof(float);
   indices_offset_ = indices->size / sizeof(int);
+  sharded_ = sharded;
 }

 void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
@@ -333,7 +336,7 @@ size_t SparseAdamOptimInfo::indices_index() {
 }

 SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
-                                         const AddressPtr &grad, const AddressPtr &indices) {
+                                         const AddressPtr &grad, const AddressPtr &indices, bool sharded) {
   inputs_.push_back(weight);
   inputs_.push_back(accum);
   inputs_.push_back(linear);
@@ -341,6 +344,7 @@ SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(indices);
   grads_offset_ = grad->size / sizeof(float);
   indices_offset_ = indices->size / sizeof(int);
+  sharded_ = sharded;
 }

 const AddressPtr &SparseFtrlOptimInfo::gradient() {
@@ -82,6 +82,7 @@ class SparseOptimInfo : public OptimizerInfo {
  protected:
   size_t grads_offset_{0};
   size_t indices_offset_{0};
+  bool sharded_{true};
 };

 class MomentumOptimInfo : public DenseOptimInfo {
@@ -101,7 +102,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
   SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power,
                       const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1,
                       const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad,
-                      const AddressPtr &indices);
+                      const AddressPtr &indices, bool sharded);
   ~SparseAdamOptimInfo() override = default;

   void Update(const Values &values, const Lengths &lens) override;
@@ -115,7 +116,7 @@ class SparseAdamOptimInfo : public SparseOptimInfo {
 class SparseFtrlOptimInfo : public SparseOptimInfo {
  public:
   SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear,
-                      const AddressPtr &grad, const AddressPtr &indices);
+                      const AddressPtr &grad, const AddressPtr &indices, bool sharded);
   ~SparseFtrlOptimInfo() override = default;

   const AddressPtr &gradient();
@@ -25,10 +25,12 @@ namespace ps {
 using mindspore::kernel::ps::SparseApplyFtrlPSKernel;
 OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr<PServerKernel> &pserver_kernel,
                                            const WeightPtr &weight, const Keys &keys, const Values &values,
-                                           const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) {
+                                           const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
+                                           bool sharded) {
   MS_EXCEPTION_IF_NULL(pserver_kernel);
   MS_EXCEPTION_IF_NULL(inputs_shape);
-  OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel);
+  OptimizerInfo *optim_info =
+    BuildInputs(weight, keys, values, lens, inputs_shape, worker_num, pserver_kernel, sharded);
   MS_EXCEPTION_IF_NULL(optim_info);
   std::vector<size_t> ws_sizes = pserver_kernel->workspace_sizes();
   BuildWorkspaces(optim_info, ws_sizes, worker_num);
@@ -108,7 +110,7 @@ AddressPtr OptimizerInfoBuilder::GenInputAddrPtr(const std::string &optim_type,
 OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                      const Lengths &lens, const InputsShapePtr &inputs_shape,
-                                                     size_t worker_num, const std::shared_ptr<PServerKernel> &) {
+                                                     size_t worker_num, const std::shared_ptr<PServerKernel> &, bool) {
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   MS_EXCEPTION_IF_NULL(weight_addr);
   weight_addr->addr = weight->data();
@@ -135,7 +137,8 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co
 OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                        const Lengths &lens, const InputsShapePtr &inputs_shape,
-                                                       size_t worker_num, const std::shared_ptr<PServerKernel> &) {
+                                                       size_t worker_num, const std::shared_ptr<PServerKernel> &,
+                                                       bool sharded) {
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   MS_EXCEPTION_IF_NULL(weight_addr);
   weight_addr->addr = weight->data();
@@ -178,13 +181,14 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", values.data(), lens, inputs_shape);
   AddressPtr indices = GenInputAddrPtr<float>(kSparseAdam, "indices", values.data(), lens, inputs_shape);
   return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
-                                 grad, indices);
+                                 grad, indices, sharded);
 }

 OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                                        const Lengths &lens, const InputsShapePtr &inputs_shape,
                                                        size_t worker_num,
-                                                       const std::shared_ptr<PServerKernel> &pserver_kernel) {
+                                                       const std::shared_ptr<PServerKernel> &pserver_kernel,
+                                                       bool sharded) {
   MS_EXCEPTION_IF_NULL(inputs_shape);
   AddressPtr weight_addr = std::make_shared<kernel::Address>();
   MS_EXCEPTION_IF_NULL(weight_addr);
@@ -216,7 +220,7 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight,
   AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", values.data(), lens, inputs_shape);
   AddressPtr indices = GenInputAddrPtr<float>(kSparseFtrl, "indices", values.data(), lens, inputs_shape);
-  return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices);
+  return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded);
 }
 } // namespace ps
 } // namespace mindspore
@@ -34,12 +34,12 @@ class OptimizerInfoBuilder {
   virtual ~OptimizerInfoBuilder() = default;

   OptimizerInfo *Build(const std::shared_ptr<PServerKernel> &pserver_kernel, const WeightPtr &weight, const Keys &keys,
-                       const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape,
-                       size_t worker_num);
+                       const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
+                       bool sharded);

   virtual OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
                                      const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num,
-                                     const std::shared_ptr<PServerKernel> &pserver_kernel) = 0;
+                                     const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) = 0;

   virtual void BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes, size_t worker_num);
   virtual void BuildOutputs(OptimizerInfo *info, size_t worker_num) {}
@@ -57,7 +57,7 @@ class MomentumOptimInfoBuilder : public OptimizerInfoBuilder {
   ~MomentumOptimInfoBuilder() = default;

   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
                              const InputsShapePtr &inputs_shape, size_t worker_num,
-                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
+                             const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override;
 };

 class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder {
@@ -66,7 +66,7 @@ class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder {
   ~SparseAdamOptimInfoBuilder() = default;

   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
                              const InputsShapePtr &inputs_shape, size_t worker_num,
-                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
+                             const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override;
 };

 class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder {
@@ -75,7 +75,7 @@ class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder {
   ~SparseFtrlOptimInfoBuilder() = default;

   OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
                              const InputsShapePtr &inputs_shape, size_t worker_num,
-                             const std::shared_ptr<PServerKernel> &pserver_kernel) override;
+                             const std::shared_ptr<PServerKernel> &pserver_kernel, bool sharded) override;
 };
 } // namespace ps
 } // namespace mindspore
@@ -619,8 +619,8 @@ void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const
       MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
     }
     MS_EXCEPTION_IF_NULL(pserver_kernel);
-    OptimizerInfo *optim =
-      builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_);
+    OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
+                                          optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
     optim_info.reset(optim);
     optim_infos_[key] = optim_info;
   } else {
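Taken together, the hunks thread one new boolean from the parameter server down to the optimizer info object. The sketch below is a compact, hypothetical mock of that plumbing, not the MindSpore API: all the address/shape/kernel parameters are omitted and the class bodies are reduced to the flag itself, purely to show the path the value travels (caller flag -> Build -> BuildInputs -> stored member used later by ComputeMean).

#include <iostream>
#include <memory>

// Stand-in for SparseOptimInfo: only the part relevant to the new flag.
class OptimizerInfo {
 public:
  explicit OptimizerInfo(bool sharded) : sharded_(sharded) {}
  bool sharded() const { return sharded_; }

 private:
  bool sharded_{true};  // same default as the new SparseOptimInfo member in the diff
};

// Stand-in for OptimizerInfoBuilder: Build now takes `sharded` and forwards it.
class OptimizerInfoBuilder {
 public:
  virtual ~OptimizerInfoBuilder() = default;
  std::unique_ptr<OptimizerInfo> Build(bool sharded) { return BuildInputs(sharded); }

 protected:
  virtual std::unique_ptr<OptimizerInfo> BuildInputs(bool sharded) = 0;
};

// Stand-in for SparseAdamOptimInfoBuilder / SparseFtrlOptimInfoBuilder.
class SparseOptimInfoBuilder : public OptimizerInfoBuilder {
 protected:
  std::unique_ptr<OptimizerInfo> BuildInputs(bool sharded) override {
    // The real BuildInputs also wires up weight/m/v/grad/indices addresses.
    return std::make_unique<OptimizerInfo>(sharded);
  }
};

int main() {
  SparseOptimInfoBuilder builder;
  bool is_embedding = true;  // plays the role of is_embedding_[key] in AccumGrad
  std::unique_ptr<OptimizerInfo> info = builder.Build(is_embedding);
  std::cout << std::boolalpha << info->sharded() << std::endl;  // prints: true
  return 0;
}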