Merge pull request !3027 from ZPaC/adaptation-for-ps-modetags/v0.6.0-beta
| @@ -55,7 +55,7 @@ class CPUKernel : public kernel::KernelMod { | |||
| public: | |||
| CPUKernel() = default; | |||
| ~CPUKernel() override = default; | |||
| void Init(const CNodePtr &kernel_node); | |||
| virtual void Init(const CNodePtr &kernel_node); | |||
| virtual void InitKernel(const CNodePtr &kernel_node) = 0; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override { | |||
| @@ -62,10 +62,12 @@ class CPUKernelRegistrar { | |||
| static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ | |||
| []() { return std::make_shared<OPCLASS>(); }); | |||
| #define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \ | |||
| #define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T) | |||
| #define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) | |||
| #define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \ | |||
| static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \ | |||
| static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \ | |||
| []() { return std::make_shared<OPCLASS<T>>(); }); | |||
| static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \ | |||
| #OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); }); | |||
| #define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ | |||
| static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \ | |||
| @@ -46,24 +46,10 @@ void SparseApplyFtrlPSKernel::InitKernel( | |||
| if (grad_shape[0] != indices_size_) { | |||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | |||
| } | |||
| /* | |||
| lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr"); | |||
| if (lr_ <= 0) { | |||
| MS_LOG(EXCEPTION) << "lr should be a positive scalar"; | |||
| } | |||
| l1_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l1"); | |||
| if (l1_ < 0) { | |||
| MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar"; | |||
| } | |||
| l2_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l2"); | |||
| if (l2_ < 0) { | |||
| MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar"; | |||
| } | |||
| lr_power_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr_power"); | |||
| if (lr_power_ > 0) { | |||
| MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; | |||
| } | |||
| */ | |||
| lr_ = 0.01; | |||
| l1_ = 1e-8; | |||
| l2_ = 1e-8; | |||
| lr_power_ = -0.5; | |||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | |||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | |||
| } | |||
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/pass/replace_node_by_proxy.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "device/kernel_info.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<std::string> inputs_device_format; | |||
| std::vector<std::string> outputs_device_format; | |||
| std::vector<TypeId> inputs_device_type; | |||
| std::vector<TypeId> outputs_device_type; | |||
| std::vector<std::vector<size_t>> outputs_shape; | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||
| inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); | |||
| inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); | |||
| } | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); | |||
| outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); | |||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); | |||
| } | |||
| builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); | |||
| builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); | |||
| builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); | |||
| builder.SetInputsFormat(inputs_device_format); | |||
| builder.SetOutputsFormat(outputs_device_format); | |||
| builder.SetInputsDeviceType(inputs_device_type); | |||
| builder.SetOutputsDeviceType(outputs_device_type); | |||
| return builder.Build(); | |||
| } | |||
| bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||
| for (auto node : node_list) { | |||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)}; | |||
| proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||
| AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs); | |||
| MS_EXCEPTION_IF_NULL(proxy_node); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| proxy_node->set_kernel_info(kernel_info); | |||
| AbstractBasePtrList abstract_list; | |||
| AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node); | |||
| AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node); | |||
| AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node); | |||
| abstract_list.push_back(cnode->abstract()); | |||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | |||
| proxy_node->set_abstract(abstract_tuple); | |||
| auto kernel_build_info = GenerateKernelBuildInfo(cnode); | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get()); | |||
| if (!manager->Replace(cnode, proxy_node)) { | |||
| MS_LOG(EXCEPTION) << "Replace node by proxy node failed."; | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "pre_activate/common/pass.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/anf.h" | |||
| #include "utils/utils.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ReplaceNodeByProxy : public Pass { | |||
| public: | |||
| explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {} | |||
| ~ReplaceNodeByProxy() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode); | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| """comm_helper""" | |||
| import os | |||
| from ._hccl_management import load_lib as hccl_load_lib | |||
| _HCCL_AVAILABLE = False | |||
| @@ -44,7 +44,7 @@ else: | |||
| HCCL_WORLD_COMM_GROUP = "hccl_world_group" | |||
| NCCL_WORLD_COMM_GROUP = "nccl_world_group" | |||
| MS_ROLE = os.getenv("MS_ROLE") | |||
| class Backend: | |||
| """ | |||
| @@ -152,6 +152,9 @@ def _get_rank_helper(group, backend): | |||
| Integer. The local rank id of the calling process. | |||
| """ | |||
| rank_id = None | |||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||
| rank_id = 0 | |||
| return rank_id | |||
| if backend == Backend.HCCL: | |||
| if group == HCCL_WORLD_COMM_GROUP: | |||
| rank_id = hccl.get_rank_id() | |||
| @@ -211,6 +214,9 @@ def _get_size_helper(group, backend): | |||
| Integer. The rank size of specified group. | |||
| """ | |||
| size = None | |||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||
| size = 1 | |||
| return size | |||
| if backend == Backend.HCCL: | |||
| if group == HCCL_WORLD_COMM_GROUP: | |||
| size = hccl.get_rank_size() | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Communication management API""" | |||
| import os | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ | |||
| _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ | |||
| @@ -28,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", | |||
| DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP | |||
| DEFAULT_BACKEND = Backend("hccl") | |||
| MS_ROLE = os.getenv("MS_ROLE") | |||
| def _get_group(group): | |||
| @@ -58,6 +60,8 @@ def init(backend_name="hccl"): | |||
| TypeError: If backend name is not a string. | |||
| RuntimeError: If backend is invalid or distributed init fails. | |||
| """ | |||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||
| return | |||
| if not isinstance(backend_name, str): | |||
| raise TypeError("Backend name must be a string, but got {}".format(type(backend_name))) | |||