Add class Message, AddrStorage. Add Kernel folder. Add consistent hash ring. Add updater for server. Add iteration timer. Add executor for server. (pull/15916/head)
@@ -39,6 +39,20 @@ if(NOT ENABLE_GPU)
  list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
endif()

if(WIN32 OR NOT ENABLE_CPU)
  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc")
  list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc")
endif()

list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
add_subdirectory(ps_cache)
@@ -0,0 +1,189 @@ mindspore/ccsrc/ps/server/common.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
#define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_

#include <map>
#include <string>
#include <numeric>
#include <climits>
#include <memory>
#include <functional>
#include "proto/ps.pb.h"
#include "ir/anf.h"
#include "utils/utils.h"
#include "ir/dtype/type_id.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/tcp_server.h"

namespace mindspore {
namespace ps {
namespace server {
// Definitions for the server framework.
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagrad, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
using kernel::Address;
using kernel::AddressPtr;
using kernel::CPUKernel;
using TimeOutCb = std::function<void(void)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(void)>;
using FinalizeCb = std::function<void(void)>;

// Information about whether the server kernel will reuse kernel node memory from the front end.
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
// Value refers to the kernel node's parameter index.
using ReuseKernelNodeInfo = std::map<std::string, size_t>;

// UploadData refers to the data which is uploaded by workers.
// Key refers to the data name. For example: "weights", "grad", "learning_rate", etc. This will be set by the worker.
// Value refers to the data of the key.
// We use Address instead of AddressPtr because:
// 1. Address doesn't need to call make_shared<T>, so it has better performance.
// 2. The data uploaded by the worker is normally parsed from FlatBuffers or Protobuf. For example: learning rate, new
// weights, etc. Address is enough to store these data.
// Pay attention that Address only stores the void* pointer to the data, so the data must not be released before the
// related logic is done.
using UploadData = std::map<std::string, Address>;
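// A usage sketch (not part of this header) of how a worker payload might be assembled,
// assuming parsed_lr and parsed_weights point into a decoded FlatBuffers/Protobuf message
// that stays alive until the push is handled; the variable names are hypothetical:
//   UploadData upload_data;
//   upload_data[kLearningRate] = {parsed_lr, sizeof(float)};
//   upload_data[kWeight] = {parsed_weights, weights_byte_size};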
constexpr auto kWeight = "weight";
constexpr auto kAccumulation = "accum";
constexpr auto kLearningRate = "lr";
constexpr auto kGradient = "grad";
constexpr auto kNewGradient = "new_grad";
constexpr auto kMomentum = "momentum";
constexpr auto kIndices = "indices";
constexpr auto kAdamM = "m";
constexpr auto kAdamV = "v";
constexpr auto kAdamBeta1Power = "beta1_power";
constexpr auto kAdamBeta2Power = "beta2_power";
constexpr auto kAdamBeta1 = "beta1";
constexpr auto kAdamBeta2 = "beta2";
constexpr auto kAdamEps = "eps";
constexpr auto kFtrlLinear = "linear";
// OptimParamNameToIndex represents each input/workspace/output parameter's offset when an optimizer kernel is
// launched.
using OptimParamNameToIndex = std::map<std::string, std::map<std::string, size_t>>;

const OptimParamNameToIndex kMomentumNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kLearningRate, 2}, {kGradient, 3}, {kMomentum, 4}}}, {"outputs", {}}};
const OptimParamNameToIndex kAdamNameToIdx = {{"inputs",
                                               {{kWeight, 0},
                                                {kAdamM, 1},
                                                {kAdamV, 2},
                                                {kAdamBeta1Power, 3},
                                                {kAdamBeta2Power, 4},
                                                {kLearningRate, 5},
                                                {kAdamBeta1, 6},
                                                {kAdamBeta2, 7},
                                                {kAdamEps, 8},
                                                {kGradient, 9}}},
                                              {"outputs", {}}};
const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
                                                     {{kWeight, 0},
                                                      {kAdamM, 1},
                                                      {kAdamV, 2},
                                                      {kAdamBeta1Power, 3},
                                                      {kAdamBeta2Power, 4},
                                                      {kLearningRate, 5},
                                                      {kAdamBeta1, 6},
                                                      {kAdamBeta2, 7},
                                                      {kAdamEps, 8},
                                                      {kGradient, 9},
                                                      {kIndices, 10}}},
                                                    {"outputs", {}}};
const OptimParamNameToIndex kSparseFtrlNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
  {kApplyMomentumOpName, kMomentumNameToIdx},
  {kFusedSparseAdamName, kSparseAdamNameToIdx},
  {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
  {kApplyAdamOpName, kAdamNameToIdx},
};
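// A lookup sketch: locating the weight input's offset for the ApplyMomentum optimizer
// kernel (kApplyMomentumOpName comes from utils/utils.h):
//   size_t weight_idx = kNameToIdxMap.at(kApplyMomentumOpName).at("inputs").at(kWeight);  // weight_idx == 0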
constexpr uint32_t kLeaderServerRank = 0;
constexpr size_t kWorkerMgrThreadPoolSize = 32;
constexpr size_t kWorkerMgrMaxTaskNum = 64;
constexpr size_t kCipherMgrThreadPoolSize = 32;
constexpr size_t kCipherMgrMaxTaskNum = 64;
constexpr size_t kExecutorThreadPoolSize = 32;
constexpr size_t kExecutorMaxTaskNum = 32;
constexpr int kHttpSuccess = 200;
constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS";
constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer";
constexpr auto kCtxFuncGraph = "FuncGraph";
constexpr auto kCtxIterNum = "iteration";
constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
// This macro returns the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \
  std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())

#define RETURN_IF_NULL(expr, ret)               \
  if (expr == nullptr) {                        \
    MS_LOG(ERROR) << #expr << " is nullptr.";   \
    return ret;                                 \
  }
// This method returns the size in bytes of the given TypeId.
inline size_t GetTypeIdByte(const TypeId &type) {
  switch (type) {
    case kNumberTypeFloat16:
      return 2;
    case kNumberTypeUInt32:
    case kNumberTypeFloat32:
      return 4;
    case kNumberTypeUInt64:
      return 8;
    default:
      MS_LOG(EXCEPTION) << "TypeId " << type << " not supported.";
      return 0;
  }
}

inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) {
  RETURN_IF_NULL(kernel_node, nullptr);
  auto param_node =
    AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>();
  RETURN_IF_NULL(param_node, nullptr);
  auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>();
  RETURN_IF_NULL(param_tensor, nullptr);
  AddressPtr addr = std::make_shared<kernel::Address>();
  addr->addr = param_tensor->data_c();
  addr->size = param_tensor->data().nbytes();
  return addr;
}

// Definitions for Federated Learning.

// Definitions for Parameter Server.
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
@@ -0,0 +1,58 @@ mindspore/ccsrc/ps/server/consistent_hash_ring.cc
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/server/consistent_hash_ring.h"

namespace mindspore {
namespace ps {
namespace server {
bool ConsistentHashRing::Insert(uint32_t rank) {
  for (uint32_t i = 0; i < virtual_node_num_; i++) {
    // Build one distinct hash key per virtual node: "<rank>#<i>".
    std::string physical_node_hash_key = std::to_string(rank) + "#" + std::to_string(i);
    MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank;
    size_t hash_value = std::hash<std::string>()(physical_node_hash_key);
    if (ring_.count(hash_value) != 0) {
      MS_LOG(WARNING) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
      continue;
    }
    ring_[hash_value] = rank;
  }
  return true;
}

bool ConsistentHashRing::Erase(uint32_t rank) {
  for (auto iterator = ring_.begin(); iterator != ring_.end();) {
    if (iterator->second == rank) {
      iterator = ring_.erase(iterator);
    } else {
      ++iterator;
    }
  }
  return true;
}

uint32_t ConsistentHashRing::Find(const std::string &key) {
  size_t hash_value = std::hash<std::string>()(key);
  auto iterator = ring_.lower_bound(hash_value);
  if (iterator == ring_.end()) {
    // If no virtual node is found clockwise, the key is mapped to the first virtual node on the ring.
    iterator = ring_.begin();
  }
  return iterator->second;
}
} // namespace server
} // namespace ps
} // namespace mindspore
@@ -0,0 +1,64 @@ mindspore/ccsrc/ps/server/consistent_hash_ring.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
#define MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_

#include <map>
#include <string>
#include "utils/log_adapter.h"

namespace mindspore {
namespace ps {
namespace server {
// To support distributed storage and make servers easy to scale out and scale in under a large load of metadata,
// we use class ConsistentHashRing to help servers find out which metadata is stored on which server node.
// Class ConsistentHashRing implements the algorithm described in the paper
// <https://dl.acm.org/doi/pdf/10.1145/258533.258660>.
// This class creates a ring of the hash values of metadata and server nodes. Each server can use this ring to
// retrieve data stored on other servers according to the hash keys. The time complexity of adding/deleting/searching
// in this algorithm is basically O(log n).
class ConsistentHashRing {
 public:
  // The constructor parameter virtual_node_num is the number of virtual nodes to be created for each physical server
  // node. According to the paper, these virtual nodes help spread data across all the servers while ensuring balance
  // at the same time. And when we say "adding/deleting/searching", we are talking about operations on these virtual
  // nodes instead of the physical nodes.
  explicit ConsistentHashRing(uint32_t virtual_node_num = 128) : virtual_node_num_(virtual_node_num) {}
  ~ConsistentHashRing() = default;

  // Insert several virtual nodes for a server into this ring according to its rank id.
  bool Insert(uint32_t rank);

  // Remove the virtual nodes of a server according to its rank id.
  bool Erase(uint32_t rank);

  // Find the physical server node's rank according to the metadata's key.
  uint32_t Find(const std::string &key);

 private:
  uint32_t virtual_node_num_;

  // The hash ring for the server nodes.
  // Key is the hash value of the virtual node.
  // Value is the physical node's rank id.
  std::map<size_t, uint32_t> ring_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
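A usage sketch for the class above (the rank ids and the metadata key are hypothetical):

#include "ps/server/consistent_hash_ring.h"

void RouteMetadataExample() {
  mindspore::ps::server::ConsistentHashRing ring(128);
  ring.Insert(0);  // Adds 128 virtual nodes for server rank 0.
  ring.Insert(1);  // Likewise for server rank 1.
  // The key is hashed and mapped clockwise to the nearest virtual node's server.
  uint32_t owner_rank = ring.Find("worker_3/fc1.weight");
  (void)owner_rank;
  ring.Erase(1);  // Scale-in: keys owned by rank 1 fall through to the next virtual node.
}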
@@ -0,0 +1,315 @@ mindspore/ccsrc/ps/server/executor.cc
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/server/executor.h"
#include <set>
#include <memory>
#include <string>
#include <vector>
#include <chrono>
#include <thread>

namespace mindspore {
namespace ps {
namespace server {
void Executor::Init(const FuncGraphPtr &func_graph, size_t aggregation_count) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (aggregation_count == 0) {
    MS_LOG(EXCEPTION) << "Server aggregation count must be greater than 0";
    return;
  }
  aggregation_count_ = aggregation_count;

  // Initialize each trainable parameter's aggregator, including memory register, aggregation algorithms and
  // optimizers.
  bool ret = InitParamAggregator(func_graph);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Initializing parameter aggregators failed.";
    return;
  }
  initialized_ = true;
  return;
}

bool Executor::initialized() const { return initialized_; }

bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return false;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];

  // The Push operation needs to wait until the pulling process is done.
  while (!param_aggr->IsPullingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }

  // 1. Update data with the uploaded data of the worker.
  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // 2. Launch aggregation for this trainable parameter.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  if (param_aggr->IsAggregationDone()) {
    // 3. After the aggregation is done, optimize the trainable parameter.
    if (!param_aggr->LaunchOptimizers()) {
      MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
      return false;
    }
    // 4. Reset the pulling and aggregation status after optimizing is done.
    param_aggr->ResetPullingStatus();
    param_aggr->ResetAggregationStatus();
  }
  return true;
}
bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
  MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    // The feature map may contain keys for non-trainable parameters like momentum, which we don't treat as invalid.
    // So here we just print a warning log and return true.
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return true;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];

  if (!param_aggr->UpdateData(upload_data)) {
    MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
    return false;
  }
  // Different from Push, UpdateModel doesn't need to check the aggregation status.
  if (!param_aggr->LaunchAggregators()) {
    MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
    return false;
  }
  return true;
}
bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
  std::unique_lock<std::mutex> model_lock(model_mutex_);
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];

    const UploadData &upload_data = trainable_param.second;
    if (!param_aggr->UpdateData(upload_data)) {
      MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
      return false;
    }
    if (!param_aggr->LaunchAggregators()) {
      MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
      return false;
    }
  }
  return true;
}

bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map) {
  for (const auto &trainable_param : feature_map) {
    const std::string &param_name = trainable_param.first;
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(WARNING) << "Weight " << param_name << " is not registered in server.";
      continue;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    auto &param_aggr = param_aggrs_[param_name];

    AddressPtr old_weight = param_aggr->GetWeight();
    if (old_weight == nullptr) {
      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
      return false;
    }
    const Address &new_weight = trainable_param.second;
    if (new_weight.addr == nullptr) {
      MS_LOG(ERROR) << "The new weight is nullptr.";
      return false;
    }
    int ret = memcpy_s(old_weight->addr, old_weight->size, new_weight.addr, new_weight.size);
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errno(" << ret << ")";
      return false;
    }
  }
  return true;
}
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];

  // Pulling must wait until the optimizing process is done.
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset the pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}

std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() {
  std::unique_lock<std::mutex> lock(model_mutex_);
  return GetModel();
}

std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> &param_names) {
  std::map<std::string, AddressPtr> weights;
  for (const auto &param_name : param_names) {
    if (param_aggrs_.count(param_name) == 0) {
      MS_LOG(ERROR) << "Parameter " << param_name << " is not registered in server.";
      return weights;
    }

    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    const auto &param_aggr = param_aggrs_[param_name];

    AddressPtr addr = param_aggr->GetWeight();
    if (addr == nullptr) {
      MS_LOG(ERROR) << "Get weight of " << param_name << " failed: the AddressPtr is nullptr.";
      continue;
    }
    weights[param_name] = addr;
  }
  return weights;
}

bool Executor::IsAllWeightAggregationDone() { return IsWeightAggrDone(param_names_); }

bool Executor::IsWeightAggrDone(const std::vector<std::string> &param_names) {
  for (const auto &name : param_names) {
    if (param_aggrs_.count(name) == 0) {
      MS_LOG(ERROR) << "Weight " << name << " is invalid in server.";
      return false;
    }

    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    if (!param_aggrs_[name]->IsAggregationDone()) {
      MS_LOG(DEBUG) << "Update model for " << name << " is not done yet.";
      return false;
    }
  }
  return true;
}

void Executor::ResetAggregationStatus() {
  for (const auto &param_name : param_names_) {
    std::mutex &mtx = parameter_mutex_[param_name];
    std::unique_lock<std::mutex> lock(mtx);
    param_aggrs_[param_name]->ResetAggregationStatus();
  }
  return;
}

std::map<std::string, AddressPtr> Executor::GetModel() {
  std::map<std::string, AddressPtr> model = {};
  for (const auto &name : param_names_) {
    std::mutex &mtx = parameter_mutex_[name];
    std::unique_lock<std::mutex> lock(mtx);
    AddressPtr addr = param_aggrs_[name]->GetWeight();
    if (addr == nullptr) {
      MS_LOG(WARNING) << "Get weight of " << name << " failed.";
      continue;
    }
    model[name] = addr;
  }
  return model;
}
// bool Executor::Unmask() {
//   auto model = GetModel();
//   return mindarmour::CipherMgr::GetInstance().UnMask(model);
// }

const std::vector<std::string> &Executor::param_names() const { return param_names_; }

std::string Executor::GetTrainableParamName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
  if (kNameToIdxMap.count(cnode_name) == 0) {
    return "";
  }
  const OptimParamNameToIndex &index_info = kNameToIdxMap.at(cnode_name);
  size_t weight_idx = index_info.at("inputs").at(kWeight);
  AnfNodePtr weight_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, weight_idx), 0).first;
  MS_EXCEPTION_IF_NULL(weight_node);
  if (!weight_node->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Input " << weight_idx << " of " << cnode_name << " is not a Parameter.";
  }
  return weight_node->fullname_with_scope();
}

bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    const std::string &param_name = GetTrainableParamName(cnode);
    if (param_name.empty()) {
      continue;
    }
    if (param_aggrs_.count(param_name) != 0) {
      MS_LOG(WARNING) << param_name << " already has its control flow.";
      continue;
    }

    std::shared_ptr<ParameterAggregator> param_aggr = std::make_shared<ParameterAggregator>();
    MS_EXCEPTION_IF_NULL(param_aggr);
    param_names_.push_back(param_name);
    param_aggrs_[param_name] = param_aggr;
    parameter_mutex_[param_name];  // Default-constructs the mutex for this parameter.
    param_aggr->Init(cnode, aggregation_count_);
    MS_LOG(DEBUG) << "Initializing control flow for param_name " << param_name << " success.";
  }
  return true;
}
} // namespace server
} // namespace ps
} // namespace mindspore
@@ -0,0 +1,125 @@ mindspore/ccsrc/ps/server/executor.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_

#include <map>
#include <set>
#include <memory>
#include <string>
#include <vector>
#include <mutex>
#include <condition_variable>
#include "ps/server/common.h"
#include "ps/server/parameter_aggregator.h"

namespace mindspore {
namespace ps {
namespace server {
// Executor is the entrance for the server to handle aggregation, optimizing, model querying, etc. It handles
// the logic related to kernel launching.
class Executor {
 public:
  static Executor &GetInstance() {
    static Executor instance;
    return instance;
  }

  // func_graph is the graph compiled by the front end. aggregation_count is the number of updates each aggregator
  // requires before its aggregation is done.
  // As noted in the header file parameter_aggregator.h, we create aggregators for trainable parameters, which are the
  // optimizer cnodes' inputs. So we need to initialize the server executor using func_graph.
  void Init(const FuncGraphPtr &func_graph, size_t aggregation_count);

  // Called in parameter server training mode to do the Push operation.
  // For the same trainable parameter, HandlePush must be called aggregation_count_ times before the push is
  // considered as completed.
  bool HandlePush(const std::string &param_name, const UploadData &upload_data);

  // Called in parameter server training mode to do the Pull operation.
  // Returns the value of the parameter param_name.
  // HandlePull must be called the same number of times as HandlePush before the pull is considered as completed.
  AddressPtr HandlePull(const std::string &param_name);

  // Called in federated learning training mode. Updates the value of the parameter param_name.
  bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);

  // Called in asynchronous federated learning training mode. Updates the current model with the new feature map
  // asynchronously.
  bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);

  // Called in asynchronous federated learning training mode. Returns the whole model as key-value pairs where the
  // key refers to the parameter name.
  std::map<std::string, AddressPtr> HandleAsyncGetModel();

  // Forcibly overwrite specific weights in the overwriteWeights message.
  bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map);

  // Returns the values of multiple trainable parameters passed by param_names.
  std::map<std::string, AddressPtr> HandleGetWeightsByKey(const std::vector<std::string> &param_names);

  // Reset the aggregation status of all aggregation kernels in the server.
  void ResetAggregationStatus();

  // Judge whether the aggregation processes for all weights/gradients are completed.
  bool IsAllWeightAggregationDone();

  // Judge whether the aggregation processes for the given param_names are completed.
  bool IsWeightAggrDone(const std::vector<std::string> &param_names);

  // Returns the whole model as key-value pairs where the key refers to the parameter name.
  std::map<std::string, AddressPtr> GetModel();

  // Returns whether the executor singleton is already initialized.
  bool initialized() const;

  const std::vector<std::string> &param_names() const;

 private:
  Executor() : initialized_(false), aggregation_count_(0) {}
  ~Executor() = default;
  Executor(const Executor &) = delete;
  Executor &operator=(const Executor &) = delete;

  // Returns the trainable parameter name parsed from this cnode.
  std::string GetTrainableParamName(const CNodePtr &cnode);

  // The server's graph is basically the same as the worker's graph, so we can get all the information we need from
  // func_graph for later computations, including forward and backward propagation, aggregation, optimizing, etc.
  bool InitParamAggregator(const FuncGraphPtr &func_graph);

  bool initialized_;
  size_t aggregation_count_;
  std::vector<std::string> param_names_;

  // The map of trainable parameter names to their ParameterAggregators, as noted in the header file
  // parameter_aggregator.h.
  std::map<std::string, std::shared_ptr<ParameterAggregator>> param_aggrs_;

  // This mutex ensures that operations on the whole model are thread safe.
  // The whole model is constructed from all trainable parameters.
  std::mutex model_mutex_;

  // Because ParameterAggregator is not thread safe, we create a mutex for each ParameterAggregator and acquire the
  // lock before calling its methods.
  std::map<std::string, std::mutex> parameter_mutex_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
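A sketch of the parameter server flow these interfaces imply, assuming func_graph comes from the front end and upload_data was assembled as in the UploadData sketch in common.h (the parameter name and count are hypothetical):

auto &executor = mindspore::ps::server::Executor::GetInstance();
executor.Init(func_graph, 4);  // Each round of aggregation needs 4 pushes.
// A worker's push; the 4th push for this parameter triggers aggregation and optimizing.
(void)executor.HandlePush("fc1.weight", upload_data);
// Pull blocks until optimizing is done, then returns the updated weight.
mindspore::ps::server::AddressPtr new_weight = executor.HandlePull("fc1.weight");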
@@ -0,0 +1,56 @@ mindspore/ccsrc/ps/server/iteration_timer.cc
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/server/iteration_timer.h"

namespace mindspore {
namespace ps {
namespace server {
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
  if (running_.load()) {
    MS_LOG(WARNING) << "The timer has already started.";
    return;
  }
  running_ = true;
  end_time_ = CURRENT_TIME_MILLI + duration;
  monitor_thread_ = std::thread([this]() {
    while (running_.load()) {
      if (CURRENT_TIME_MILLI > end_time_) {
        timeout_callback_();
        running_ = false;
      }
      // The time tick is 1 millisecond.
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
  });
  monitor_thread_.detach();
}

void IterationTimer::Stop() { running_ = false; }

void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
  timeout_callback_ = timeout_cb;
  return;
}

bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) { return timestamp > end_time_; }

bool IterationTimer::IsRunning() { return running_; }
} // namespace server
} // namespace ps
} // namespace mindspore
@@ -0,0 +1,64 @@ mindspore/ccsrc/ps/server/iteration_timer.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_

#include <chrono>
#include <atomic>
#include <thread>
#include <functional>
#include "ps/server/common.h"

namespace mindspore {
namespace ps {
namespace server {
// IterationTimer controls the time window of an iteration, so that an iteration's trailing time can be cut off.
class IterationTimer {
 public:
  IterationTimer() : running_(false), end_time_(0) {}
  ~IterationTimer() = default;

  // Start timing. The timer stops after 'duration' milliseconds.
  void Start(const std::chrono::milliseconds &duration);

  // The caller can use this method to stop timing manually; otherwise the timer keeps timing until it expires.
  void Stop();

  // Set the callback which will be called when the timer expires.
  void SetTimeOutCallBack(const TimeOutCb &timeout_cb);

  // Judge whether the given timestamp is outside the time window opened by the Start function.
  bool IsTimeOut(const std::chrono::milliseconds &timestamp);

  // Judge whether the timer is still timing.
  bool IsRunning();

 private:
  // The running state of the timer.
  std::atomic<bool> running_;

  // The timestamp in milliseconds at which the timer should stop timing.
  std::chrono::milliseconds end_time_;

  // The thread that keeps timing and calls timeout_callback_ when the timer expires.
  std::thread monitor_thread_;

  TimeOutCb timeout_callback_;
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
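A usage sketch, assuming the callback is what the server should run when the iteration's time window closes (the 30 s window is hypothetical):

mindspore::ps::server::IterationTimer timer;
timer.SetTimeOutCallBack([]() { MS_LOG(INFO) << "Iteration time window expired."; });
timer.Start(std::chrono::milliseconds(30000));  // Open a 30 s window for this iteration.
// ... the iteration finishes early ...
if (timer.IsRunning()) {
  timer.Stop();  // Stop manually so the timeout callback never fires.
}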
@@ -0,0 +1,95 @@ mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_

#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "ps/server/common.h"
#include "ps/server/memory_register.h"
#include "ps/server/kernel/params_info.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
// AggregationKernel is the kernel for the aggregation of weights, gradients or other kinds of parameters.
// For example, dense gradient accumulation, federated averaging, etc.
// Normally the aggregation process in AggregationKernel works like a finite-state machine:
// Initial -> Aggregating -> Aggregation done -> Initial.
class AggregationKernel : public CPUKernel {
 public:
  AggregationKernel() : name_(""), done_(false), done_count_(0), accum_count_(0) {}
  virtual ~AggregationKernel() = default;

  // InitKernel and Launch are inherited as pure virtual functions from CPUKernel, so they must have implementations.
  virtual void InitKernel(const CNodePtr &kernel_node) {}
  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                      const std::vector<AddressPtr> &outputs) {
    return true;
  }

  // Server kernel's memory allocation method, which is different from the workflow in
  // Session(GPUSession/CPUSession/AscendSession).
  // virtual void AssignMemory(const CNodePtr &kernel_node, std::shared_ptr<MemoryRegister> memory_register) = 0;

  // Set the cumulative count this aggregation kernel needs before aggregation is done.
  void set_done_count(size_t count) { done_count_ = count; }

  // Reset sets the finite-state machine back to Initial after this round of aggregation is considered done.
  virtual void Reset() = 0;

  virtual bool IsAggregationDone() = 0;

  // Setter and getter for the kernel's parameter information.
  void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; }
  const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }
  const std::vector<std::string> &workspace_names() { return params_info_.workspace_names(); }
  const std::vector<std::string> &output_names() { return params_info_.outputs_names(); }

  // Returns information about whether some inputs should reuse the kernel node's input memory.
  const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info() { return reuse_kernel_node_inputs_info_; }

 protected:
  virtual void GenerateReuseKernelNodeInfo() = 0;

  // The aggregation kernel's name, which is set by the kernel register function.
  std::string name_;

  // Whether the current round of aggregation is done.
  bool done_;

  // The cumulative count this aggregation kernel needs before aggregation is done.
  size_t done_count_;

  // The current cumulative count.
  size_t accum_count_;

  // Parameter information used for kernel registration, memory assignment, etc.
  ParamsInfo params_info_;

  // Information about the server kernel reusing kernel node input memory from the front end.
  // Key refers to the server kernel's input name. Value refers to the kernel node's input index.
  ReuseKernelNodeInfo reuse_kernel_node_inputs_info_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
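A sketch of how the surrounding framework might drive this state machine each round (the aggregation_kernel pointer and the buffers are assumptions; in the server, ParameterAggregator writes each received upload into the inputs before calling Launch):

aggregation_kernel->set_done_count(3);  // Three uploads complete one round.
// Aggregating: one Launch call per received upload.
(void)aggregation_kernel->Launch(inputs, workspace, outputs);
// ... two more uploads arrive and are launched the same way ...
if (aggregation_kernel->IsAggregationDone()) {
  // Aggregation done: the result is consumed, then the FSM returns to Initial.
  aggregation_kernel->Reset();
}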
@@ -0,0 +1,71 @@ mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.cc
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/server/kernel/aggregation_kernel_factory.h"
#include <utility>

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
  std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
  if (kNameToIdxMap.count(cnode_name) == 0) {
    MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
    return false;
  }

  auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs");
  size_t input_num = params_info.inputs_num();
  for (size_t i = 0; i < input_num; i++) {
    auto one_input_name_type = params_info.inputs_name_type(i);
    std::string name = one_input_name_type.first;
    if (input_name_to_idx.count(name) == 0) {
      MS_LOG(DEBUG) << cnode_name << " does not have an input named " << name
                    << ". This is a new input for this aggregation kernel.";
      continue;
    }
    size_t input_idx = input_name_to_idx.at(name);
    TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx);
    TypeId registered_input_type = one_input_name_type.second;
    if (registered_input_type != kernel_node_input_type) {
      return false;
    }
  }

  auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs");
  size_t output_num = params_info.outputs_num();
  for (size_t i = 0; i < output_num; i++) {
    auto one_output_name_type = params_info.outputs_name_type(i);
    std::string name = one_output_name_type.first;
    if (output_name_to_idx.count(name) == 0) {
      MS_LOG(DEBUG) << cnode_name << " does not have an output named " << name
                    << ". This is a new output for this aggregation kernel.";
      continue;
    }
    size_t output_idx = output_name_to_idx.at(name);
    TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx);
    TypeId registered_output_type = one_output_name_type.second;
    if (registered_output_type != kernel_node_output_type) {
      return false;
    }
  }
  return true;
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
@@ -0,0 +1,71 @@ mindspore/ccsrc/ps/server/kernel/aggregation_kernel_factory.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_

#include <memory>
#include <string>
#include <utility>
#include "ps/server/kernel/kernel_factory.h"
#include "ps/server/kernel/aggregation_kernel.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;

class AggregationKernelFactory : public KernelFactory<std::shared_ptr<AggregationKernel>, AggregationKernelCreator> {
 public:
  static AggregationKernelFactory &GetInstance() {
    static AggregationKernelFactory instance;
    return instance;
  }

 private:
  AggregationKernelFactory() = default;
  ~AggregationKernelFactory() override = default;
  AggregationKernelFactory(const AggregationKernelFactory &) = delete;
  AggregationKernelFactory &operator=(const AggregationKernelFactory &) = delete;

  // Judge whether the server aggregation kernel can be created according to registered ParamsInfo.
  bool Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) override;
};

class AggregationKernelRegister {
 public:
  AggregationKernelRegister(const std::string &name, const ParamsInfo &params_info,
                            AggregationKernelCreator &&creator) {
    AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
  }
};

// Register aggregation kernel with one template type T.
#define REG_AGGREGATION_KERNEL(NAME, PARAMS_INFO, CLASS, T)                                                 \
  static_assert(std::is_base_of<AggregationKernel, CLASS<T>>::value, " must be base of AggregationKernel"); \
  static const AggregationKernelRegister g_##NAME##_##T##_aggregation_kernel_reg(                           \
    #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });

// Register aggregation kernel with two template types: T and S.
#define REG_AGGREGATION_KERNEL_TWO(NAME, PARAMS_INFO, CLASS, T, S)                                             \
  static_assert(std::is_base_of<AggregationKernel, CLASS<T, S>>::value, " must be base of AggregationKernel"); \
  static const AggregationKernelRegister g_##NAME##_##T##_##S##_aggregation_kernel_reg(                        \
    #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
@@ -0,0 +1,34 @@ mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/server/kernel/apply_momentum_kernel.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
REG_OPTIMIZER_KERNEL(ApplyMomentum,
                     ParamsInfo()
                       .AddInputNameType(kWeight, kNumberTypeFloat32)
                       .AddInputNameType(kAccumulation, kNumberTypeFloat32)
                       .AddInputNameType(kLearningRate, kNumberTypeFloat32)
                       .AddInputNameType(kGradient, kNumberTypeFloat32)
                       .AddInputNameType(kMomentum, kNumberTypeFloat32),
                     ApplyMomentumKernel, float)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
@@ -0,0 +1,61 @@ mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_

#include <vector>
#include <memory>
#include <utility>
#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h"
#include "ps/server/kernel/optimizer_kernel.h"
#include "ps/server/kernel/optimizer_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
using mindspore::kernel::ApplyMomentumCPUKernel;

template <typename T>
class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKernel {
 public:
  ApplyMomentumKernel() = default;
  ~ApplyMomentumKernel() override = default;

  void InitKernel(const CNodePtr &cnode) override {
    ApplyMomentumCPUKernel::InitKernel(cnode);
    InitServerKernelInputOutputSize(cnode);
    GenerateReuseKernelNodeInfo();
  }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return ApplyMomentumCPUKernel::Launch(inputs, workspace, outputs);
  }

  void GenerateReuseKernelNodeInfo() override {
    // Input 3 (the gradient) is not reused: it is aggregated from the workers' uploads.
    reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
    reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 1));
    reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));
    reuse_kernel_node_inputs_info_.insert(std::make_pair(kMomentum, 4));
    return;
  }
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
@@ -0,0 +1,30 @@ mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.cc
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/server/kernel/dense_grad_accum_kernel.h"

namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
REG_AGGREGATION_KERNEL(
  DenseGradAccum,
  ParamsInfo().AddInputNameType(kGradient, kNumberTypeFloat32).AddInputNameType(kNewGradient, kNumberTypeFloat32),
  DenseGradAccumKernel, float)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
@@ -0,0 +1,95 @@ mindspore/ccsrc/ps/server/kernel/dense_grad_accum_kernel.h
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#include <memory>
#include <numeric>  // for std::accumulate in InitKernel
#include <string>
#include <vector>
#include <functional>
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "ps/server/kernel/aggregation_kernel.h" | |||
| #include "ps/server/kernel/aggregation_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class DenseGradAccumKernel : public AggregationKernel { | |||
| public: | |||
| DenseGradAccumKernel() = default; | |||
| ~DenseGradAccumKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 || | |||
| kNameToIdxMap.at(cnode_name).at("inputs").count("grad") == 0) { | |||
| MS_LOG(EXCEPTION) << "Can't find index info of grad for kernel " << cnode_name; | |||
| return; | |||
| } | |||
| size_t cnode_grad_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("grad"); | |||
| std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_grad_idx); | |||
| size_t grad_size = std::accumulate(grad_shape.begin(), grad_shape.end(), sizeof(T), std::multiplies<size_t>()); | |||
| input_size_list_.push_back(grad_size); | |||
| size_t new_grad_size = grad_size; | |||
| input_size_list_.push_back(new_grad_size); | |||
| GenerateReuseKernelNodeInfo(); | |||
| return; | |||
| } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override { | |||
| if (accum_count_ == 0) { | |||
| int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")"; | |||
| return false; | |||
| } | |||
| } | |||
| T *grad_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| T *new_grad_addr = reinterpret_cast<T *>(inputs[1]->addr); | |||
| for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) { | |||
| grad_addr[i] += new_grad_addr[i]; | |||
| } | |||
| accum_count_++; | |||
| if (accum_count_ > done_count_) { | |||
| MS_LOG(ERROR) << "accum_count_ should not be greater than done_count_ " << done_count_; | |||
| return false; | |||
| } | |||
| if (accum_count_ == done_count_) { | |||
| for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) { | |||
| grad_addr[i] /= done_count_; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void Reset() { accum_count_ = 0; } | |||
| bool IsAggregationDone() { return accum_count_ >= done_count_; } | |||
| void GenerateReuseKernelNodeInfo() override { return; } | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_ | |||
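For reference, a minimal standalone sketch of the accumulate-then-average behavior that Launch implements, using plain vectors in place of the AddressPtr inputs; the element count, push values, and done_count below are made up for illustration:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t done_count = 3;                // pushes expected per iteration
  std::vector<float> grad(4, 0.0f);           // accumulation buffer, plays the role of inputs[0]
  std::vector<std::vector<float>> pushes = {  // incoming gradients, play the role of inputs[1]
      {1.0f, 2.0f, 3.0f, 4.0f}, {1.0f, 2.0f, 3.0f, 4.0f}, {4.0f, 4.0f, 4.0f, 4.0f}};
  size_t accum_count = 0;
  for (const auto &new_grad : pushes) {
    // The first push of an iteration clears the buffer, the same role as the memset_s call.
    if (accum_count == 0) {
      std::fill(grad.begin(), grad.end(), 0.0f);
    }
    for (size_t i = 0; i < grad.size(); i++) {
      grad[i] += new_grad[i];
    }
    accum_count++;
    if (accum_count == done_count) {  // average once the last expected push arrives
      for (size_t i = 0; i < grad.size(); i++) {
        grad[i] /= done_count;
      }
    }
  }
  for (float v : grad) {
    std::cout << v << " ";  // prints: 2 2.66667 3.33333 4
  }
  std::cout << std::endl;
  return 0;
}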
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <unordered_map> | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/kernel/params_info.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| // KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory | |||
| // and AggregationKernelFactory. | |||
| // Unlike normal MindSpore operator kernels, the server defines multiple types of kernels, e.g., Aggregation | |||
| // Kernel, Optimizer Kernel, Forward Kernel, etc. So we define KernelFactory as a template class for the | |||
| // registration of all these kernel types. | |||
| // Because most of the information needed to create a server kernel lives in the func_graph passed by the front end, | |||
| // we create a server kernel based on a cnode. | |||
| // Typename K refers to the shared_ptr of the kernel type. | |||
| // Typename C refers to the creator function of the kernel. | |||
| template <typename K, typename C> | |||
| class KernelFactory { | |||
| public: | |||
| KernelFactory() = default; | |||
| virtual ~KernelFactory() = default; | |||
| static KernelFactory &GetInstance() { | |||
| static KernelFactory instance; | |||
| return instance; | |||
| } | |||
| // Kernels are registered with their parameter information and their creator (constructor). | |||
| void Register(const std::string &name, const ParamsInfo ¶ms_info, C &&creator) { | |||
| name_to_creator_map_[name].push_back(std::make_pair(params_info, creator)); | |||
| } | |||
| // The kernels in server are created from func_graph's kernel_node passed by the front end. | |||
| K Create(const std::string &name, const CNodePtr &kernel_node) { | |||
| if (name_to_creator_map_.count(name) == 0) { | |||
| MS_LOG(ERROR) << "Creating kernel failed: " << name << " is not registered."; | |||
| return nullptr; | |||
| } | |||
| for (const auto &name_type_creator : name_to_creator_map_[name]) { | |||
| const ParamsInfo ¶ms_info = name_type_creator.first; | |||
| const C &creator = name_type_creator.second; | |||
| if (Matched(params_info, kernel_node)) { | |||
| auto kernel = creator(); | |||
| kernel->set_params_info(params_info); | |||
| return kernel; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| private: | |||
| KernelFactory(const KernelFactory &) = delete; | |||
| KernelFactory &operator=(const KernelFactory &) = delete; | |||
| // Judge whether the server kernel can be created according to registered ParamsInfo. | |||
| virtual bool Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) { return true; } | |||
| // Generally, a server kernel can correspond to several ParamsInfo entries, which are registered by the method | |||
| // 'Register' in the server kernel's *.cc files. | |||
| std::unordered_map<std::string, std::vector<std::pair<ParamsInfo, C>>> name_to_creator_map_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_ | |||
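To make the register-then-create flow concrete, here is a minimal standalone mimic of the KernelFactory pattern. MiniFactory and its string-only keys are illustrative assumptions; the real factory additionally filters candidates through Matched(params_info, kernel_node):

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Kernel {
  virtual ~Kernel() = default;
  virtual std::string Name() = 0;
};
struct FedAvg : Kernel {
  std::string Name() override { return "FedAvg"; }
};

using Creator = std::function<std::shared_ptr<Kernel>()>;

class MiniFactory {
 public:
  static MiniFactory &GetInstance() {
    static MiniFactory instance;
    return instance;
  }
  void Register(const std::string &name, Creator &&creator) {
    name_to_creator_map_[name].push_back(std::move(creator));
  }
  std::shared_ptr<Kernel> Create(const std::string &name) {
    if (name_to_creator_map_.count(name) == 0) {
      return nullptr;  // not registered
    }
    // The real factory walks the registered candidates and calls Matched(params_info, kernel_node)
    // to pick the creator whose registered dtypes match the cnode; this mimic takes the first.
    return name_to_creator_map_[name].front()();
  }
 private:
  std::unordered_map<std::string, std::vector<Creator>> name_to_creator_map_;
};

int main() {
  MiniFactory::GetInstance().Register("FedAvg", []() { return std::make_shared<FedAvg>(); });
  auto kernel = MiniFactory::GetInstance().Create("FedAvg");
  std::cout << (kernel ? kernel->Name() : "not found") << std::endl;  // prints: FedAvg
  return 0;
}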
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <functional> | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/memory_register.h" | |||
| #include "ps/server/kernel/params_info.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| using mindspore::kernel::IsSameShape; | |||
| using mindspore::kernel::USE_NESTEROV; | |||
| // OptimizerKernel is the server kernel for optimizing weights. | |||
| // Normally a server optimizer kernel should inherit from the corresponding CPU optimizer kernel to reuse its | |||
| // implementation. | |||
| class OptimizerKernel : public CPUKernel { | |||
| public: | |||
| OptimizerKernel() = default; | |||
| virtual ~OptimizerKernel() = default; | |||
| // InitKernel and Launch override pure virtual methods of CPUKernel, so they must be implemented. | |||
| virtual void InitKernel(const CNodePtr &kernel_node) {} | |||
| virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| return true; | |||
| } | |||
| // Server kernel's memory allocation method, which is different from the workflow in | |||
| // Session(GPUSession/CPUSession/AscendSession). | |||
| // virtual void AssignMemory(const CNodePtr &kernel_node, std::shared_ptr<MemoryRegister> memory_register) = 0; | |||
| // Setter and getters for the kernel's parameter information. | |||
| void set_params_info(const ParamsInfo ¶ms_info) { params_info_ = params_info; } | |||
| const std::vector<std::string> &input_names() { return params_info_.inputs_names(); } | |||
| const std::vector<std::string> &workspace_names() { return params_info_.workspace_names(); } | |||
| const std::vector<std::string> &output_names() { return params_info_.outputs_names(); } | |||
| // Returns information about whether some inputs should reuse kernel node inputs memory. | |||
| const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info() { return reuse_kernel_node_inputs_info_; } | |||
| protected: | |||
| virtual void GenerateReuseKernelNodeInfo() = 0; | |||
| void InitServerKernelInputOutputSize(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| size_t type_size = sizeof(float); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index); | |||
| size_t tensor_size = | |||
| shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | |||
| input_size_list_.emplace_back(tensor_size); | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| for (size_t output_index = 0; output_index < output_num; ++output_index) { | |||
| std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, output_index); | |||
| size_t tensor_size = | |||
| shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | |||
| output_size_list_.emplace_back(tensor_size); | |||
| } | |||
| } | |||
| // Parameters information used for kernel register, memory assignment, etc. | |||
| ParamsInfo params_info_; | |||
| // Information about server kernel reusing kernel node inputs memory from the front end. | |||
| // Key refers to the server kernel's input index. Value refers to the kernel node's input index. | |||
| ReuseKernelNodeInfo reuse_kernel_node_inputs_info_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_ | |||
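The byte-size arithmetic in InitServerKernelInputOutputSize reduces to one std::accumulate per tensor. A standalone sketch with made-up shapes, where the empty shape models a scalar:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const size_t type_size = sizeof(float);  // the server kernels above assume float32
  std::vector<std::vector<size_t>> shapes = {{32, 10}, {10}, {}};  // {} models a scalar
  for (const auto &shape : shapes) {
    // Scalars have an empty shape, so fall back to a single element of type_size bytes.
    size_t tensor_size =
        shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
    std::cout << tensor_size << " bytes" << std::endl;  // prints: 1280, 40, 4
  }
  return 0;
}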
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/kernel/optimizer_kernel_factory.h" | |||
| #include <utility> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) { | |||
| std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kNameToIdxMap.count(cnode_name) == 0) { | |||
| MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name; | |||
| return false; | |||
| } | |||
| auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs"); | |||
| size_t input_num = params_info.inputs_num(); | |||
| for (size_t i = 0; i < input_num; i++) { | |||
| auto one_input_name_type = params_info.inputs_name_type(i); | |||
| std::string name = one_input_name_type.first; | |||
| if (input_name_to_idx.count(name) == 0) { | |||
| MS_LOG(EXCEPTION) << cnode_name << " does not have input named " << name; | |||
| return false; | |||
| } | |||
| size_t input_idx = input_name_to_idx.at(name); | |||
| TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx); | |||
| TypeId registered_input_type = one_input_name_type.second; | |||
| if (registered_input_type != kernel_node_input_type) { | |||
| return false; | |||
| } | |||
| } | |||
| auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs"); | |||
| size_t output_num = params_info.outputs_num(); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| auto one_output_name_type = params_info.outputs_name_type(i); | |||
| std::string name = one_output_name_type.first; | |||
| if (output_name_to_idx.count(name) == 0) { | |||
| MS_LOG(EXCEPTION) << cnode_name << " does not have output named " << name; | |||
| return false; | |||
| } | |||
| size_t output_idx = output_name_to_idx.at(name); | |||
| TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx); | |||
| TypeId registered_output_type = one_output_name_type.second; | |||
| if (registered_output_type != kernel_node_output_type) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "ps/server/kernel/kernel_factory.h" | |||
| #include "ps/server/kernel/optimizer_kernel.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>; | |||
| class OptimizerKernelFactory : public KernelFactory<std::shared_ptr<OptimizerKernel>, OptimizerKernelCreator> { | |||
| public: | |||
| static OptimizerKernelFactory &GetInstance() { | |||
| static OptimizerKernelFactory instance; | |||
| return instance; | |||
| } | |||
| private: | |||
| OptimizerKernelFactory() = default; | |||
| ~OptimizerKernelFactory() override = default; | |||
| OptimizerKernelFactory(const OptimizerKernelFactory &) = delete; | |||
| OptimizerKernelFactory &operator=(const OptimizerKernelFactory &) = delete; | |||
| // Judge whether the server optimizer kernel can be created according to registered ParamsInfo. | |||
| bool Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) override; | |||
| }; | |||
| class OptimizerKernelRegister { | |||
| public: | |||
| OptimizerKernelRegister(const std::string &name, const ParamsInfo ¶ms_info, OptimizerKernelCreator &&creator) { | |||
| OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator)); | |||
| } | |||
| }; | |||
| // Register optimizer kernel with one template type T. | |||
| #define REG_OPTIMIZER_KERNEL(NAME, PARAMS_INFO, CLASS, T) \ | |||
| static_assert(std::is_base_of<OptimizerKernel, CLASS<T>>::value, " must be base of OptimizerKernel"); \ | |||
| static const OptimizerKernelRegister g_##NAME##_##T##_optimizer_kernel_reg( \ | |||
| #NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); }); | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_ | |||
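The REG_OPTIMIZER_KERNEL macro works by defining a file-scope object whose constructor registers the creator before main runs. Below is a minimal standalone mimic of that static-registration trick; Registry, Register, MomentumOpt, and REG_OPT_KERNEL are illustrative stand-ins, not MindSpore APIs:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Opt {
  virtual ~Opt() = default;
};
template <typename T>
struct MomentumOpt : Opt {};  // stand-in for a real optimizer kernel class

using Creator = std::function<std::shared_ptr<Opt>()>;

// Meyers-singleton map so registration works regardless of static-init order across files.
std::map<std::string, Creator> &Registry() {
  static std::map<std::string, Creator> registry;
  return registry;
}
struct Register {
  Register(const std::string &name, Creator &&creator) { Registry()[name] = std::move(creator); }
};

#define REG_OPT_KERNEL(NAME, CLASS, T) \
  static const Register g_##NAME##_##T##_reg(#NAME, []() { return std::make_shared<CLASS<T>>(); });

REG_OPT_KERNEL(ApplyMomentum, MomentumOpt, float)  // runs at static-initialization time

int main() {
  std::cout << "registered: " << Registry().count("ApplyMomentum") << std::endl;  // prints: registered: 1
  return 0;
}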
| @@ -0,0 +1,68 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/kernel/params_info.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) { | |||
| inputs_name_type_.push_back(std::make_pair(name, type)); | |||
| inputs_names_.push_back(name); | |||
| return *this; | |||
| } | |||
| ParamsInfo &ParamsInfo::AddWorkspaceNameType(const std::string &name, TypeId type) { | |||
| workspaces_name_type_.push_back(std::make_pair(name, type)); | |||
| workspace_names_.push_back(name); | |||
| return *this; | |||
| } | |||
| ParamsInfo &ParamsInfo::AddOutputNameType(const std::string &name, TypeId type) { | |||
| outputs_name_type_.push_back(std::make_pair(name, type)); | |||
| outputs_names_.push_back(name); | |||
| return *this; | |||
| } | |||
| size_t ParamsInfo::inputs_num() const { return inputs_name_type_.size(); } | |||
| size_t ParamsInfo::outputs_num() const { return outputs_name_type_.size(); } | |||
| const std::pair<std::string, TypeId> &ParamsInfo::inputs_name_type(size_t index) const { | |||
| if (index >= inputs_name_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "Index " << index << " is out of bound of inputs_name_type_."; | |||
| } | |||
| return inputs_name_type_[index]; | |||
| } | |||
| const std::pair<std::string, TypeId> &ParamsInfo::outputs_name_type(size_t index) const { | |||
| if (index >= outputs_name_type_.size()) { | |||
| MS_LOG(EXCEPTION) << "Index " << index << " is out of bound of outputs_name_type_."; | |||
| } | |||
| return outputs_name_type_[index]; | |||
| } | |||
| const std::vector<std::string> &ParamsInfo::inputs_names() const { return inputs_names_; } | |||
| const std::vector<std::string> &ParamsInfo::workspace_names() const { return workspace_names_; } | |||
| const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; } | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_ | |||
| #include <utility> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/dtype/type_id.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| namespace kernel { | |||
| // ParamsInfo is used to register server computation kernels, e.g., ApplyMomentumKernel, FedAvgKernel, etc. | |||
| // Registering a server kernel requires the name and type of every input/workspace/output parameter. | |||
| // For example: | |||
| // ParamsInfo() | |||
| // .AddInputNameType("input1_name", kNumberTypeFloat32) | |||
| // .AddInputNameType("input2_name", kNumberTypeUInt64) | |||
| // .AddWorkspaceNameType("workspace1_name", kNumberTypeFloat32) | |||
| // .AddOutputNameType("output1_name", kNumberTypeUInt64) | |||
| // This invocation describes a server kernel with the parameters below: | |||
| // An input with name "input1_name" and type float32. | |||
| // An input with name "input2_name" and type uint64. | |||
| // A workspace with name "workspace1_name" and type float32. | |||
| // An output with name "output1_name" and type uint64. | |||
| class ParamsInfo { | |||
| public: | |||
| ParamsInfo() = default; | |||
| ~ParamsInfo() = default; | |||
| ParamsInfo &AddInputNameType(const std::string &name, TypeId type); | |||
| ParamsInfo &AddWorkspaceNameType(const std::string &name, TypeId type); | |||
| ParamsInfo &AddOutputNameType(const std::string &name, TypeId type); | |||
| size_t inputs_num() const; | |||
| size_t outputs_num() const; | |||
| const std::pair<std::string, TypeId> &inputs_name_type(size_t index) const; | |||
| const std::pair<std::string, TypeId> &outputs_name_type(size_t index) const; | |||
| const std::vector<std::string> &inputs_names() const; | |||
| const std::vector<std::string> &workspace_names() const; | |||
| const std::vector<std::string> &outputs_names() const; | |||
| private: | |||
| std::vector<std::pair<std::string, TypeId>> inputs_name_type_; | |||
| std::vector<std::pair<std::string, TypeId>> workspaces_name_type_; | |||
| std::vector<std::pair<std::string, TypeId>> outputs_name_type_; | |||
| std::vector<std::string> inputs_names_; | |||
| std::vector<std::string> workspace_names_; | |||
| std::vector<std::string> outputs_names_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_ | |||
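A standalone mimic of the fluent builder described in the comment above; MiniParamsInfo is an illustrative reduction that shows why each Add*NameType method returns *this:

#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

enum TypeId { kNumberTypeFloat32, kNumberTypeUInt64 };  // reduced stand-in for ir/dtype/type_id.h

class MiniParamsInfo {
 public:
  MiniParamsInfo &AddInputNameType(const std::string &name, TypeId type) {
    inputs_name_type_.push_back(std::make_pair(name, type));  // call order is preserved
    return *this;  // returning *this is what enables the chained-call style
  }
  size_t inputs_num() const { return inputs_name_type_.size(); }
 private:
  std::vector<std::pair<std::string, TypeId>> inputs_name_type_;
};

int main() {
  MiniParamsInfo info = MiniParamsInfo()
                          .AddInputNameType("grad", kNumberTypeFloat32)
                          .AddInputNameType("new_grad", kNumberTypeFloat32);
  std::cout << info.inputs_num() << std::endl;  // prints: 2
  return 0;
}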
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/local_meta_storage.h" | |||
| #include <string> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| void LocalMetaStorage::remove_value(const std::string &name) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| if (key_to_meta_.count(name) != 0) { | |||
| key_to_meta_.erase(key_to_meta_.find(name)); | |||
| } | |||
| } | |||
| bool LocalMetaStorage::has_value(const std::string &name) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| return key_to_meta_.count(name) != 0; | |||
| } | |||
| void LocalMetaStorage::set_curr_iter_num(size_t num) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| curr_iter_num_ = num; | |||
| } | |||
| const size_t LocalMetaStorage::curr_iter_num() { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| return curr_iter_num_; | |||
| } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||
| #include <any> | |||
| #include <mutex> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "ps/server/common.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // LocalMetaStorage class is used for metadata storage of this server process. | |||
| // For example, the current iteration number, time windows for round kernels, etc. | |||
| // LocalMetaStorage is thread-safe. | |||
| class LocalMetaStorage { | |||
| public: | |||
| static LocalMetaStorage &GetInstance() { | |||
| static LocalMetaStorage instance; | |||
| return instance; | |||
| } | |||
| template <typename T> | |||
| void put_value(const std::string &name, const T &value) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| key_to_meta_[name] = value; | |||
| } | |||
| template <typename T> | |||
| const T &value(const std::string &name) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| try { | |||
| const T &value = std::any_cast<const T &>(key_to_meta_[name]); | |||
| return value; | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << "Value of " << name << " is not set."; | |||
| } | |||
| } | |||
| // This method returns a mutable reference so that the caller can change the value without calling put_value. | |||
| template <typename T> | |||
| T &mutable_value(const std::string &name) { | |||
| std::unique_lock<std::mutex> lock(mtx_); | |||
| try { | |||
| return std::any_cast<T &>(key_to_meta_[name]); | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(EXCEPTION) << "Value of " << name << " is not set."; | |||
| } | |||
| } | |||
| void remove_value(const std::string &name); | |||
| bool has_value(const std::string &name); | |||
| void set_curr_iter_num(size_t num); | |||
| const size_t curr_iter_num(); | |||
| private: | |||
| LocalMetaStorage() = default; | |||
| ~LocalMetaStorage() = default; | |||
| LocalMetaStorage(const LocalMetaStorage &) = delete; | |||
| LocalMetaStorage &operator=(const LocalMetaStorage &) = delete; | |||
| // key_to_meta_ stores metadata in key-value format. | |||
| std::unordered_map<std::string, std::any> key_to_meta_; | |||
| // This mutex makes sure that the operations on key_to_meta_ are thread-safe. | |||
| std::mutex mtx_; | |||
| size_t curr_iter_num_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||
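A usage sketch of the std::any-based key-value pattern above, reduced to a standalone MiniMetaStorage; the key names are illustrative, and a type mismatch between put_value and value takes the exception path (here, the exceptions thrown by at and any_cast):

#include <any>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

class MiniMetaStorage {
 public:
  template <typename T>
  void put_value(const std::string &name, const T &value) {
    std::unique_lock<std::mutex> lock(mtx_);
    key_to_meta_[name] = value;  // std::any erases the type; the caller must remember it
  }
  template <typename T>
  T value(const std::string &name) {
    std::unique_lock<std::mutex> lock(mtx_);
    return std::any_cast<T>(key_to_meta_.at(name));  // throws if the key or type is wrong
  }
 private:
  std::unordered_map<std::string, std::any> key_to_meta_;
  std::mutex mtx_;
};

int main() {
  MiniMetaStorage storage;
  storage.put_value("time_window", size_t(3000));           // illustrative metadata keys
  storage.put_value("round_name", std::string("startFLJob"));
  std::cout << storage.value<size_t>("time_window") << std::endl;       // prints: 3000
  std::cout << storage.value<std::string>("round_name") << std::endl;  // prints: startFLJob
  return 0;
}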
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/memory_register.h" | |||
| #include <utility> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) { | |||
| addresses_.try_emplace(name, address); | |||
| } | |||
| void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); } | |||
| void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); } | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <typeinfo> | |||
| #include "ps/server/common.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // Memory allocated in the server normally holds trainable parameters, hyperparameters, gradients, etc. | |||
| // MemoryRegister stores the memory in key-value format, where the key is the address's name ("grad", "weights", | |||
| // etc.) and the value is an AddressPtr. | |||
| class MemoryRegister { | |||
| public: | |||
| MemoryRegister() = default; | |||
| ~MemoryRegister() = default; | |||
| std::map<std::string, AddressPtr> &addresses() { return addresses_; } | |||
| void RegisterAddressPtr(const std::string &name, const AddressPtr &address); | |||
| // In some cases, memory is passed as a unique_ptr allocated by the caller. It needs to be stored as well so that | |||
| // its data is not released prematurely. | |||
| template <typename T> | |||
| void RegisterArray(const std::string &name, std::unique_ptr<T[]> *array, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(array); | |||
| void *data = array->get(); | |||
| AddressPtr addr = std::make_shared<Address>(); | |||
| addr->addr = data; | |||
| addr->size = size; | |||
| if (typeid(T) == typeid(int)) { | |||
| auto int_arr = CastUniquePtr<int, T>(array); | |||
| StoreInt32Array(&int_arr); | |||
| } else if (typeid(T) == typeid(float)) { | |||
| auto float_arr = CastUniquePtr<float, T>(array); | |||
| StoreFloatArray(&float_arr); | |||
| } else if (typeid(T) == typeid(size_t)) { | |||
| auto uint64_arr = CastUniquePtr<size_t, T>(array); | |||
| StoreUint64Array(&uint64_arr); | |||
| } else { | |||
| MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); | |||
| return; | |||
| } | |||
| RegisterAddressPtr(name, addr); | |||
| return; | |||
| } | |||
| private: | |||
| std::map<std::string, AddressPtr> addresses_; | |||
| std::vector<std::unique_ptr<float[]>> float_arrays_; | |||
| std::vector<std::unique_ptr<int[]>> int32_arrays_; | |||
| std::vector<std::unique_ptr<size_t[]>> uint64_arrays_; | |||
| void StoreInt32Array(std::unique_ptr<int[]> *array); | |||
| void StoreFloatArray(std::unique_ptr<float[]> *array); | |||
| void StoreUint64Array(std::unique_ptr<size_t[]> *array); | |||
| template <typename T, typename S> | |||
| std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) { | |||
| return std::unique_ptr<T[]>{reinterpret_cast<T *>(array->release())}; | |||
| } | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_ | |||
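The ownership handoff in RegisterArray is easier to see in a standalone mimic: the caller allocates a unique_ptr, and the register records the raw address under a name while taking ownership of the buffer. MiniRegister and RegisterFloatArray are illustrative stand-ins for MemoryRegister and RegisterArray:

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct MiniAddress {
  void *addr;
  size_t size;
};

class MiniRegister {
 public:
  void RegisterFloatArray(const std::string &name, std::unique_ptr<float[]> *array, size_t size) {
    addresses_[name] = {array->get(), size};     // record raw pointer and byte size by name
    float_arrays_.push_back(std::move(*array));  // take ownership so the data outlives the caller
  }
  std::map<std::string, MiniAddress> &addresses() { return addresses_; }
 private:
  std::map<std::string, MiniAddress> addresses_;
  std::vector<std::unique_ptr<float[]>> float_arrays_;
};

int main() {
  MiniRegister reg;
  auto weights = std::make_unique<float[]>(4);  // caller-side allocation, as in AssignMemory
  weights[0] = 0.5f;
  reg.RegisterFloatArray("weight", &weights, 4 * sizeof(float));
  // The caller's unique_ptr is now empty; the register owns the buffer but the address stays valid.
  std::cout << (weights == nullptr) << " "
            << static_cast<float *>(reg.addresses()["weight"].addr)[0] << std::endl;  // prints: 1 0.5
  return 0;
}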
| @@ -0,0 +1,321 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/server/parameter_aggregator.h" | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <algorithm> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| memory_register_ = std::make_shared<MemoryRegister>(); | |||
| MS_EXCEPTION_IF_NULL(memory_register_); | |||
| required_push_count_ = required_count; | |||
| // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_. | |||
| // required_pull_count_ is normally used in parameter server training mode. | |||
| required_pull_count_ = required_count; | |||
| MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode); | |||
| InitAggregationKernels(cnode); | |||
| InitOptimizerKernels(cnode); | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::UpdateData(const std::map<std::string, Address> &new_data) { | |||
| std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses(); | |||
| for (const auto &data : new_data) { | |||
| const std::string &name = data.first; | |||
| if (name_to_addr.count(name) == 0) { | |||
| continue; | |||
| } | |||
| MS_LOG(DEBUG) << "Update data for " << name << ". Destination size: " << name_to_addr[name]->size | |||
| << ". Source size: " << data.second.size; | |||
| int ret = memcpy_s(name_to_addr[name]->addr, name_to_addr[name]->size, data.second.addr, data.second.size); | |||
| if (ret != 0) { | |||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::LaunchAggregators() { | |||
| for (auto &aggregator_with_params : aggregation_kernel_parameters_) { | |||
| KernelParams ¶ms = aggregator_with_params.second; | |||
| std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first; | |||
| RETURN_IF_NULL(aggr_kernel, false); | |||
| bool ret = aggr_kernel->Launch(params.inputs, params.workspace, params.outputs); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launching aggregation kernel " << typeid(aggr_kernel.get()).name() << " failed."; | |||
| continue; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::LaunchOptimizers() { | |||
| for (auto &optimizer_with_params : optimizer_kernel_parameters_) { | |||
| KernelParams ¶ms = optimizer_with_params.second; | |||
| std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first; | |||
| RETURN_IF_NULL(optimizer_kernel, false); | |||
| bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed."; | |||
| continue; | |||
| } | |||
| } | |||
| // As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done. | |||
| optimizing_done_ = true; | |||
| return true; | |||
| } | |||
| AddressPtr ParameterAggregator::Pull() { | |||
| if (memory_register_ == nullptr) { | |||
| MS_LOG(ERROR) | |||
| << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first."; | |||
| return nullptr; | |||
| } | |||
| current_pull_count_++; | |||
| if (current_pull_count_ == required_pull_count_) { | |||
| pulling_done_ = true; | |||
| } | |||
| MS_LOG(DEBUG) << "The " << current_pull_count_ << " time of Pull. Pulling done status: " << pulling_done_; | |||
| std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses(); | |||
| return name_to_addr["weight"]; | |||
| } | |||
| AddressPtr ParameterAggregator::GetWeight() { | |||
| if (memory_register_ == nullptr) { | |||
| MS_LOG(ERROR) | |||
| << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first."; | |||
| return nullptr; | |||
| } | |||
| std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses(); | |||
| return name_to_addr["weight"]; | |||
| } | |||
| void ParameterAggregator::ResetAggregationStatus() { | |||
| for (auto &aggregator_with_params : aggregation_kernel_parameters_) { | |||
| std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first; | |||
| if (aggr_kernel == nullptr) { | |||
| MS_LOG(ERROR) << "The aggregation kernel is nullptr."; | |||
| continue; | |||
| } | |||
| aggr_kernel->Reset(); | |||
| } | |||
| return; | |||
| } | |||
| void ParameterAggregator::ResetOptimizingStatus() { optimizing_done_ = false; } | |||
| void ParameterAggregator::ResetPullingStatus() { | |||
| pulling_done_ = false; | |||
| current_pull_count_ = 0; | |||
| } | |||
| bool ParameterAggregator::IsAggregationDone() const { | |||
| // Only consider aggregation done after each aggregation kernel is done. | |||
| for (auto &aggregator_with_params : aggregation_kernel_parameters_) { | |||
| std::shared_ptr<kernel::AggregationKernel> aggr_kernel = aggregator_with_params.first; | |||
| RETURN_IF_NULL(aggr_kernel, false); | |||
| if (!aggr_kernel->IsAggregationDone()) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::IsOptimizingDone() const { return optimizing_done_; } | |||
| bool ParameterAggregator::IsPullingDone() const { return pulling_done_; } | |||
| bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode); | |||
| for (const std::string &name : aggr_kernel_names) { | |||
| auto aggr_kernel = kernel::AggregationKernelFactory::GetInstance().Create(name, cnode); | |||
| if (aggr_kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Fail to create aggregation kernel " << name << " for " << AnfAlgo::GetCNodeName(cnode); | |||
| return false; | |||
| } | |||
| // set_done_count must be called before InitKernel because InitKernel may use this count. | |||
| aggr_kernel->set_done_count(required_push_count_); | |||
| aggr_kernel->InitKernel(cnode); | |||
| const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = aggr_kernel->reuse_kernel_node_inputs_info(); | |||
| if (!AssignMemory(aggr_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) { | |||
| MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed."; | |||
| return false; | |||
| } | |||
| if (!GenerateAggregationKernelParams(aggr_kernel, memory_register_)) { | |||
| MS_LOG(EXCEPTION) << "Generating aggregation kernel parameters for " << name << " failed."; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { | |||
| // if (PSContext::instance()->server_mode() == kServerModeFL) { | |||
| // MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; | |||
| // return false; | |||
| // } | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const std::string &name = AnfAlgo::GetCNodeName(cnode); | |||
| auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode); | |||
| if (optimizer_kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name; | |||
| return false; | |||
| } | |||
| optimizer_kernel->InitKernel(cnode); | |||
| const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info(); | |||
| if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) { | |||
| MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed."; | |||
| return false; | |||
| } | |||
| if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) { | |||
| MS_LOG(ERROR) << "Generating optimizer kernel parameters failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| template <typename K> | |||
| bool ParameterAggregator::AssignMemory(K server_kernel, const CNodePtr &cnode, | |||
| const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, | |||
| std::shared_ptr<MemoryRegister> memory_register) { | |||
| MS_EXCEPTION_IF_NULL(server_kernel); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| const std::vector<std::string> &input_names = server_kernel->input_names(); | |||
| const std::vector<size_t> &input_size_list = server_kernel->GetInputSizeList(); | |||
| if (input_names.size() != input_size_list.size()) { | |||
| MS_LOG(EXCEPTION) << "Server kernel " << typeid(server_kernel.get()).name() | |||
| << " input number is not matched: input_names size is " << input_names.size() | |||
| << ", input_size_list size is " << input_size_list.size(); | |||
| return false; | |||
| } | |||
| if (reuse_kernel_node_inputs_info.size() > input_names.size()) { | |||
| MS_LOG(EXCEPTION) << "The reuse kernel node information number is invalid: got " | |||
| << reuse_kernel_node_inputs_info.size() << ", but input_names size is " << input_names.size(); | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < input_names.size(); i++) { | |||
| const std::string &name = input_names[i]; | |||
| if (memory_register->addresses().count(name) != 0) { | |||
| MS_LOG(DEBUG) << "The memory for " << name << " is already assigned."; | |||
| continue; | |||
| } | |||
| if (reuse_kernel_node_inputs_info.count(name) != 0) { | |||
| // Reusing memory of the kernel node means the memory of the input is already assigned by the front end, which | |||
| // is to say, the input node is a parameter node. | |||
| size_t index = reuse_kernel_node_inputs_info.at(name); | |||
| MS_LOG(INFO) << "Try to reuse memory of kernel node " << AnfAlgo::GetCNodeName(cnode) << " for parameter " << name | |||
| << ", kernel node index " << index; | |||
| AddressPtr input_addr = GenerateParameterNodeAddrPtr(cnode, index); | |||
| MS_EXCEPTION_IF_NULL(input_addr); | |||
| memory_register->RegisterAddressPtr(name, input_addr); | |||
| } else { | |||
| MS_LOG(INFO) << "Assign new memory for " << name; | |||
| auto input_addr = std::make_unique<char[]>(input_size_list[i]); | |||
| MS_EXCEPTION_IF_NULL(input_addr); | |||
| memory_register->RegisterArray(name, &input_addr, input_size_list[i]); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel, | |||
| const std::shared_ptr<MemoryRegister> memory_register) { | |||
| RETURN_IF_NULL(aggr_kernel, false); | |||
| RETURN_IF_NULL(memory_register, false); | |||
| KernelParams aggr_params = {}; | |||
| const std::vector<std::string> &input_names = aggr_kernel->input_names(); | |||
| std::transform(input_names.begin(), input_names.end(), std::back_inserter(aggr_params.inputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &workspace_names = aggr_kernel->workspace_names(); | |||
| std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(aggr_params.workspace), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &output_names = aggr_kernel->output_names(); | |||
| std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); | |||
| return true; | |||
| } | |||
| bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel, | |||
| const std::shared_ptr<MemoryRegister> memory_register) { | |||
| RETURN_IF_NULL(optimizer_kernel, false); | |||
| RETURN_IF_NULL(memory_register, false); | |||
| KernelParams optimizer_params = {}; | |||
| const std::vector<std::string> &input_names = optimizer_kernel->input_names(); | |||
| std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names(); | |||
| std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| const std::vector<std::string> &output_names = optimizer_kernel->output_names(); | |||
| std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs), | |||
| [&](const std::string &name) { return memory_register->addresses()[name]; }); | |||
| optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params)); | |||
| return true; | |||
| } | |||
| std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { | |||
| std::vector<std::string> aggregation_algorithm = {}; | |||
| MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; | |||
| return aggregation_algorithm; | |||
| } | |||
| template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::OptimizerKernel> server_kernel, | |||
| const CNodePtr &cnode, | |||
| const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, | |||
| std::shared_ptr<MemoryRegister> memory_register); | |||
| template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::AggregationKernel> server_kernel, | |||
| const CNodePtr &cnode, | |||
| const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, | |||
| std::shared_ptr<MemoryRegister> memory_register); | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,139 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_ | |||
| #define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "ps/server/common.h" | |||
| #include "ps/server/memory_register.h" | |||
| #include "ps/server/kernel/aggregation_kernel_factory.h" | |||
| #include "ps/server/kernel/optimizer_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace server { | |||
| // Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server | |||
| // kernels. | |||
| typedef struct { | |||
| std::vector<AddressPtr> inputs; | |||
| std::vector<AddressPtr> workspace; | |||
| std::vector<AddressPtr> outputs; | |||
| } KernelParams; | |||
| // ParameterAggregator includes methods for aggregating gradients and optimizing weights (launching aggregation and | |||
| // optimizer kernels), getting weights, etc. It's not thread-safe, which means the caller must acquire a lock before | |||
| // calling ParameterAggregator methods concurrently. | |||
| // Each ParameterAggregator corresponds to one weight for now. | |||
| // ParameterAggregator is stateful because the process of aggregation and optimizing could be stateful. | |||
| // For example, the finite-state machine for the ParameterAggregator in parameter server training mode is below: | |||
| // Initial->Aggregating->Aggregation done->Optimizing->Optimizing done->Pulling->Pull done->Initial. | |||
| class ParameterAggregator { | |||
| public: | |||
| ParameterAggregator() | |||
| : server_mode_(ServerMode::PARAMETER_SERVER), | |||
| required_push_count_(0), | |||
| required_pull_count_(0), | |||
| current_pull_count_(0), | |||
| aggregation_done_(false), | |||
| optimizing_done_(false), | |||
| pulling_done_(true), | |||
| memory_register_(nullptr) {} | |||
| ~ParameterAggregator() = default; | |||
| // Initialize ParameterAggregator with a cnode. This cnode is normally an optimizer kernel node for now. | |||
| // The parameter required_count helps ParameterAggregator judge its current status since it's stateful. | |||
| bool Init(const CNodePtr &cnode, size_t required_count = 0); | |||
| // Update old data stored in ParameterAggregator with new data. | |||
| // The data could have many meanings: weights, gradients, learning_rate, momentum, etc. | |||
| bool UpdateData(const std::map<std::string, Address> &new_data); | |||
| // Launch aggregators/optimizers of this ParameterAggregator in order. | |||
| bool LaunchAggregators(); | |||
| bool LaunchOptimizers(); | |||
| // The implementation for primitive Pull in parameter server training mode. | |||
| // Every call of this method will increase the count for pull by 1. | |||
| AddressPtr Pull(); | |||
| // Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing | |||
| // any change of status. | |||
| AddressPtr GetWeight(); | |||
| // After aggregation/optimizing/pulling of one iteration is done, caller must reset the status to ensure the | |||
| // correctness of the aggregation/optimizing/pulling for next iteration. | |||
| void ResetAggregationStatus(); | |||
| void ResetOptimizingStatus(); | |||
| void ResetPullingStatus(); | |||
| // Returns the aggregation/optimizing/pulling status to the caller. | |||
| bool IsAggregationDone() const; | |||
| bool IsOptimizingDone() const; | |||
| bool IsPullingDone() const; | |||
| private: | |||
| // Initialize aggregation/optimizer kernels based on the cnode. The reason for this is described in the file | |||
| // kernel/kernel_factory.h. | |||
| bool InitAggregationKernels(const CNodePtr &cnode); | |||
| bool InitOptimizerKernels(const CNodePtr &cnode); | |||
| // Assign memory for server kernel K(AggregationKernel/OptimizerKernel). | |||
| // The memory assigned can be accessed by MemoryRegister. The memory could be weights, gradients, learning_rate, | |||
| // momentum, etc. | |||
| template <typename K> | |||
| bool AssignMemory(K server_kernel, const CNodePtr &cnode, const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info, | |||
| std::shared_ptr<MemoryRegister> memory_register); | |||
| // Generate kernel parameters for aggregation/optimizer kernels. All the parameters are registered and stored in | |||
| // memory_register. | |||
| bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> aggr_kernel, | |||
| const std::shared_ptr<MemoryRegister> memory_register); | |||
| bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> optim_kernel, | |||
| const std::shared_ptr<MemoryRegister> memory_register); | |||
| // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user | |||
| // configuration, etc. | |||
| std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode); | |||
| ServerMode server_mode_; | |||
| size_t required_push_count_; | |||
| size_t required_pull_count_; | |||
| size_t current_pull_count_; | |||
| // The status of aggregation/optimizing/pulling. | |||
| bool aggregation_done_; | |||
| bool optimizing_done_; | |||
| bool pulling_done_; | |||
| // ParameterAggregator stores all data that it needs for aggregation, optimizing, etc. | |||
| std::shared_ptr<MemoryRegister> memory_register_; | |||
| // An update could involve multiple aggregation and optimizer server kernels. | |||
| // So here we store pairs mapping each server kernel to the parameters of its Launch function. | |||
| std::vector<std::pair<std::shared_ptr<kernel::AggregationKernel>, KernelParams>> aggregation_kernel_parameters_; | |||
| std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_; | |||
| }; | |||
| } // namespace server | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_ | |||
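A standalone mimic of the finite-state machine described in the comment above for parameter server training mode; MiniAggregatorFsm compresses the push/optimize/pull counting into one illustrative class and is not the real interface:

#include <cstddef>
#include <iostream>

// Per-iteration cycle: Aggregating -> Aggregation done -> Optimizing -> Optimizing done
// -> Pulling -> Pull done -> reset for the next iteration.
class MiniAggregatorFsm {
 public:
  explicit MiniAggregatorFsm(size_t required_count)
      : required_count_(required_count), push_count_(0), pull_count_(0), optimizing_done_(false) {}
  void Push() { push_count_++; }  // one worker pushed its gradient
  bool IsAggregationDone() const { return push_count_ >= required_count_; }
  void Optimize() { optimizing_done_ = IsAggregationDone(); }  // only optimize after aggregation
  bool IsOptimizingDone() const { return optimizing_done_; }
  void Pull() { pull_count_++; }  // one worker pulled the updated weight
  bool IsPullingDone() const { return pull_count_ >= required_count_; }
  void Reset() {
    push_count_ = 0;
    pull_count_ = 0;
    optimizing_done_ = false;
  }
 private:
  size_t required_count_, push_count_, pull_count_;
  bool optimizing_done_;
};

int main() {
  MiniAggregatorFsm fsm(2);  // two workers per iteration
  fsm.Push();
  fsm.Push();
  fsm.Optimize();
  fsm.Pull();
  fsm.Pull();
  std::cout << fsm.IsAggregationDone() << fsm.IsOptimizingDone() << fsm.IsPullingDone() << std::endl;  // prints: 111
  fsm.Reset();  // ready for the next iteration
  return 0;
}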
| @@ -168,6 +168,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_serve | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/server/kernel/apply_momentum_kernel.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc") | |||