/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ #define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ #include #include #include "ps/core/communicator/communicator_base.h" #include "ps/server/common.h" #include "ps/server/iteration_timer.h" #include "ps/server/distributed_count_service.h" #include "ps/server/kernel/round/round_kernel.h" namespace mindspore { namespace ps { namespace server { // Round helps server to handle network round messages and launch round kernels. One iteration in server consists of // multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting // and timing. So Round helps register counter and timer so that the round kernels only need to focus on the logic. class Round { public: explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000, bool check_count = false, size_t threshold_count = 8); ~Round() = default; void Initialize(const std::shared_ptr &communicator, TimeOutCb timeout_cb, FinishIterCb finish_iteration_cb); // Bind a round kernel to this Round. This method should be called after Initialize. void BindRoundKernel(const std::shared_ptr &kernel); // This method is the callback which will be set to the communicator and called after the corresponding round message // is sent to the server. void LaunchRoundKernel(const std::shared_ptr &message); // Round needs to be reset after each iteration is finished or its timer expires. void Reset(); const std::string &name() const; size_t threshold_count() const; size_t time_window() const; private: // The callbacks which will be set to DistributedCounterService. void OnFirstCountEvent(const std::shared_ptr &message); void OnLastCountEvent(const std::shared_ptr &message); std::string name_; // Whether this round needs to use timer. Most rounds in federated learning with mobile devices scenario need to set // check_timeout_ to true. bool check_timeout_; // The time window duration for this round when check_timeout_ is set to true. size_t time_window_; // If check_count_ is true, it means the round has to do counting for every round message and the first/last count // event will be triggered. bool check_count_; // The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether // the round message count has reached threshold_count_. size_t threshold_count_; std::shared_ptr communicator_; // The round kernel for this Round. std::shared_ptr kernel_; // Some rounds may need timer to eliminate the long tail effect. std::shared_ptr iter_timer_; // The callbacks which will be set to the round kernel. StopTimerCb stop_timer_cb_; FinishIterCb finish_iteration_cb_; FinalizeCb finalize_cb_; }; } // namespace server } // namespace ps } // namespace mindspore #endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_