/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ps/util.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ps/constants.h"
#include "ps/ps_context.h"
#include "utils/ms_utils.h"

// NOTE(review): the template arguments in this file (container element types, cast/isa targets, GetValue result
// types, make_shared targets and the kernel::SparseGradient / ReduceSparseGradientParam index type) are spelled out
// from how each value is used below; confirm them against the declarations in ps/util.h and the kernel headers.

namespace mindspore {
namespace ps {
// Optimizer name -> numeric optimizer id.
std::unordered_map<std::string, int64_t> Util::optimizer_to_ids{
  {kApplyMomentum, 0},
  {kSparseAdam, 1},
  {kSparseLazyAdam, 2},
  {kSparseFtrl, 3},
};

// Numeric optimizer id -> optimizer name (inverse of optimizer_to_ids).
std::unordered_map<int64_t, std::string> Util::id_to_optimizers{
  {0, kApplyMomentum},
  {1, kSparseAdam},
  {2, kSparseLazyAdam},
  {3, kSparseFtrl},
};

// Numeric optimizer id -> name of the operator node implementing that optimizer.
std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
  {0, kApplyMomentumOp},
  {1, kSparseAdamOp},
  {2, kSparseLazyAdamOp},
  {3, kSparseFtrlOp},
};

// Returns whether the PS context reports this process as a server.
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }

// Returns whether the PS context reports this process as the scheduler.
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }

// Returns the id registered for optimizer `name`, or -1 when the name is unknown.
int64_t Util::optimizer_id(const std::string &name) {
  const auto iter = optimizer_to_ids.find(name);
  if (iter == optimizer_to_ids.end()) {
    return -1;
  }
  return iter->second;
}

// Returns the optimizer name registered for `id`, or an empty string when the id is unknown.
std::string Util::optimizer_name(int64_t id) {
  const auto iter = id_to_optimizers.find(id);
  if (iter == id_to_optimizers.end()) {
    return "";
  }
  return iter->second;
}

// Returns the optimizer operator node name registered for `id`, or an empty string when the id is unknown.
std::string Util::optimizer_node_name(int64_t id) {
  const auto iter = id_to_optimizer_nodes.find(id);
  if (iter == id_to_optimizer_nodes.end()) {
    return "";
  }
  return iter->second;
}

// Returns whether `name` is one of the registered optimizer names.
bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }

// Returns how many of the `first_dim` rows are assigned to server `rank_id` when the rows are dealt out
// round-robin over `server_num` servers. Raises via MS_LOG(EXCEPTION) on invalid arguments.
int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  const std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
  const auto iter = shard_dims.find(rank_id);
  if (iter == shard_dims.end()) {
    MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
  }
  return iter->second;
}

// Returns, for every server index in [0, server_num), how many of the `first_dim` rows it receives when row i is
// assigned to server (i % server_num).
// Raises via MS_LOG(EXCEPTION) when first_dim <= 0, server_num <= 0, rank_id < 0 or rank_id >= server_num.
std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
    MS_LOG(EXCEPTION) << "Input values are invalid.";
  }
  if (rank_id >= server_num) {
    MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
  }

  // Round-robin assignment gives every server first_dim / server_num rows, and the first
  // first_dim % server_num servers one extra row each; no need to walk all first_dim rows.
  const int64_t base_rows = first_dim / server_num;
  const int64_t servers_with_extra_row = first_dim % server_num;
  std::map<int64_t, int64_t> shard_dims;
  for (int64_t server_index = 0; server_index < server_num; ++server_index) {
    shard_dims[server_index] = base_rows + (server_index < servers_with_extra_row ? 1 : 0);
  }
  return shard_dims;
}

// Wraps the raw gradient/index buffers and temporary workspaces into a ReduceSparseGradientParam and passes it to
// SparseOptimizerCPUKernel::BucketReduceSparseGradient, which writes its result through `unique_sparse_grad`.
//   gradients, indices : input buffers; `indices` holds `indices_size` entries.
//   segment_size       : the gradient workspace is sized indices_size * segment_size floats.
//   first_dim_size     : forwarded as the parameter's max_index_.
//   outer_dim_size     : forwarded as the parameter's value_stride_.
// Raises (MS_EXCEPTION_IF_NULL) when any of the three pointers is null.
void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
                                const size_t first_dim_size, const size_t outer_dim_size,
                                mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
  // Validate the pointers before paying for the workspace allocations.
  MS_EXCEPTION_IF_NULL(gradients);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(unique_sparse_grad);

  const size_t slice_segment_size = indices_size * segment_size;
  std::vector<float> workspace_grad(slice_segment_size);
  std::vector<int> workspace_indices(indices_size);

  mindspore::kernel::SparseGradient<int> workspace_sparse_grad(
    {workspace_grad.data(), workspace_indices.data(), indices_size});
  mindspore::kernel::SparseGradient<int> input_sparse_grad({gradients, indices, indices_size});

  mindspore::kernel::ReduceSparseGradientParam<int> param;
  param.input_grad_ = &input_sparse_grad;
  param.workspace_grad_ = &workspace_sparse_grad;
  param.output_grad_ = unique_sparse_grad;
  param.max_index_ = first_dim_size;
  param.value_stride_ = outer_dim_size;

  mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
}

// Fuses the pull-weight nodes and the push-weight nodes of the resource's function graph, each group into a single
// fused node. Always returns true; failures are raised as exceptions.
bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
  DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
  return true;
}

// Replaces every CNode named `cnode_name` that is reachable from the graph's return node with one new CNode whose
// primitive is `fused_cnode_name`. The fused node takes input 0 of every replaced node as its inputs, carries the
// collected weight names and weight indices as attributes, targets the CPU device, and has a tuple abstract made
// of the replaced nodes' abstracts. Does nothing when no node matches.
void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
                    const std::string &fused_cnode_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());

  // Collect the nodes to fuse together with the weight name / weight index each one carries as a value input.
  std::vector<AnfNodePtr> single_nodes;
  std::vector<std::string> weight_names;
  std::vector<int64_t> indices;
  for (const AnfNodePtr &node : node_list) {
    if (node == nullptr || !node->isa<CNode>() || AnfAlgo::GetCNodeName(node) != cnode_name) {
      continue;
    }
    const auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);

    const auto weight_name_input = AnfAlgo::GetInputNode(cnode, kNodeInputWeightNameOffset);
    MS_EXCEPTION_IF_NULL(weight_name_input);
    const auto weight_name_value_node = weight_name_input->cast<ValueNodePtr>();
    // cast yields null when the input is not a value node; fail loudly instead of dereferencing it.
    MS_EXCEPTION_IF_NULL(weight_name_value_node);

    const auto weight_index_input = AnfAlgo::GetInputNode(cnode, kNodeInputWeightIndexOffset);
    MS_EXCEPTION_IF_NULL(weight_index_input);
    const auto weight_index_value_node = weight_index_input->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(weight_index_value_node);

    single_nodes.push_back(node);
    weight_names.push_back(GetValue<std::string>(weight_name_value_node->value()));
    indices.push_back(GetValue<int64_t>(weight_index_value_node->value()));
  }

  // Nothing to fuse: avoid creating a fused node that would never be connected to the graph.
  if (single_nodes.empty()) {
    return;
  }

  auto prim = std::make_shared<Primitive>(fused_cnode_name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> fused_node_inputs;
  fused_node_inputs.reserve(single_nodes.size() + 1);
  fused_node_inputs.push_back(NewValueNode(prim));
  AbstractBasePtrList abstract_list;
  for (const AnfNodePtr &node : single_nodes) {
    const auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    fused_node_inputs.push_back(AnfAlgo::GetInputNode(cnode, 0));
    abstract_list.push_back(cnode->abstract());
  }

  auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
  MS_EXCEPTION_IF_NULL(fused_cnode);
  AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
  AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
  AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  fused_cnode->set_kernel_info(kernel_info);
  auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());

  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  fused_cnode->set_abstract(abstract_tuple);

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &node : single_nodes) {
    if (!manager->Replace(node, fused_cnode)) {
      MS_LOG(EXCEPTION) << "manager replace node failed";
    }
  }
}

// Builds the kernel build info for a fused node from the nodes it replaces: the input and output formats are
// kOpFormat_DEFAULT for every input/output tensor of every node in `node_list`, and the device types are the
// inferred data types of those tensors, concatenated in node order.
kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  for (const AnfNodePtr &node : node_list) {
    auto cnode = utils::cast<CNodePtr>(node);
    MS_EXCEPTION_IF_NULL(cnode);
    const size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      inputs_device_format.push_back(kOpFormat_DEFAULT);
      inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
    }
    const size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      outputs_device_format.push_back(kOpFormat_DEFAULT);
      outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
    }
  }

  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
}  // namespace ps
}  // namespace mindspore