Merge pull request !3027 from ZPaC/adaptation-for-ps-mode (tags/v0.6.0-beta)
| @@ -55,7 +55,7 @@ class CPUKernel : public kernel::KernelMod { | |||||
| public: | public: | ||||
| CPUKernel() = default; | CPUKernel() = default; | ||||
| ~CPUKernel() override = default; | ~CPUKernel() override = default; | ||||
| void Init(const CNodePtr &kernel_node); | |||||
| virtual void Init(const CNodePtr &kernel_node); | |||||
| virtual void InitKernel(const CNodePtr &kernel_node) = 0; | virtual void InitKernel(const CNodePtr &kernel_node) = 0; | ||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override { | const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override { | ||||
| @@ -62,10 +62,12 @@ class CPUKernelRegistrar { | |||||
| static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ | static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ | ||||
| []() { return std::make_shared<OPCLASS>(); }); | []() { return std::make_shared<OPCLASS>(); }); | ||||
| #define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \ | |||||
| #define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T) | |||||
| #define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) | |||||
| #define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \ | |||||
| static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \ | static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \ | ||||
| static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \ | |||||
| []() { return std::make_shared<OPCLASS<T>>(); }); | |||||
| static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \ | |||||
| #OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); }); | |||||
| #define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ | #define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ | ||||
| static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \ | static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \ | ||||
| @@ -46,24 +46,10 @@ void SparseApplyFtrlPSKernel::InitKernel( | |||||
| if (grad_shape[0] != indices_size_) { | if (grad_shape[0] != indices_size_) { | ||||
| MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; | ||||
| } | } | ||||
| /* | |||||
| lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr"); | |||||
| if (lr_ <= 0) { | |||||
| MS_LOG(EXCEPTION) << "lr should be a positive scalar"; | |||||
| } | |||||
| l1_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l1"); | |||||
| if (l1_ < 0) { | |||||
| MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar"; | |||||
| } | |||||
| l2_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l2"); | |||||
| if (l2_ < 0) { | |||||
| MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar"; | |||||
| } | |||||
| lr_power_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr_power"); | |||||
| if (lr_power_ > 0) { | |||||
| MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; | |||||
| } | |||||
| */ | |||||
| lr_ = 0.01; | |||||
| l1_ = 1e-8; | |||||
| l2_ = 1e-8; | |||||
| lr_power_ = -0.5; | |||||
| workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); | ||||
| workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); | ||||
| } | } | ||||
| @@ -0,0 +1,92 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/pass/replace_node_by_proxy.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "device/kernel_info.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| std::vector<std::string> inputs_device_format; | |||||
| std::vector<std::string> outputs_device_format; | |||||
| std::vector<TypeId> inputs_device_type; | |||||
| std::vector<TypeId> outputs_device_type; | |||||
| std::vector<std::vector<size_t>> outputs_shape; | |||||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||||
| for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { | |||||
| inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); | |||||
| inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); | |||||
| } | |||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||||
| outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); | |||||
| outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); | |||||
| outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); | |||||
| } | |||||
| builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); | |||||
| builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); | |||||
| builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); | |||||
| builder.SetInputsFormat(inputs_device_format); | |||||
| builder.SetOutputsFormat(outputs_device_format); | |||||
| builder.SetInputsDeviceType(inputs_device_type); | |||||
| builder.SetOutputsDeviceType(outputs_device_type); | |||||
| return builder.Build(); | |||||
| } | |||||
| bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||||
| for (auto node : node_list) { | |||||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| auto prim = std::make_shared<Primitive>(kEmbeddingLookupProxyOpName); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| std::vector<AnfNodePtr> proxy_inputs = {NewValueNode(prim)}; | |||||
| proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); | |||||
| AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs); | |||||
| MS_EXCEPTION_IF_NULL(proxy_node); | |||||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| proxy_node->set_kernel_info(kernel_info); | |||||
| AbstractBasePtrList abstract_list; | |||||
| AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node); | |||||
| AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node); | |||||
| AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node); | |||||
| abstract_list.push_back(cnode->abstract()); | |||||
| auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||||
| MS_EXCEPTION_IF_NULL(abstract_tuple); | |||||
| proxy_node->set_abstract(abstract_tuple); | |||||
| auto kernel_build_info = GenerateKernelBuildInfo(cnode); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get()); | |||||
| if (!manager->Replace(cnode, proxy_node)) { | |||||
| MS_LOG(EXCEPTION) << "Replace node by proxy node failed."; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "pre_activate/common/pass.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "ir/anf.h" | |||||
| #include "utils/utils.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class ReplaceNodeByProxy : public Pass { | |||||
| public: | |||||
| explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {} | |||||
| ~ReplaceNodeByProxy() override = default; | |||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| private: | |||||
| kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode); | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ | |||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """comm_helper""" | """comm_helper""" | ||||
| import os | |||||
| from ._hccl_management import load_lib as hccl_load_lib | from ._hccl_management import load_lib as hccl_load_lib | ||||
| _HCCL_AVAILABLE = False | _HCCL_AVAILABLE = False | ||||
| @@ -44,7 +44,7 @@ else: | |||||
| HCCL_WORLD_COMM_GROUP = "hccl_world_group" | HCCL_WORLD_COMM_GROUP = "hccl_world_group" | ||||
| NCCL_WORLD_COMM_GROUP = "nccl_world_group" | NCCL_WORLD_COMM_GROUP = "nccl_world_group" | ||||
| MS_ROLE = os.getenv("MS_ROLE") | |||||
| class Backend: | class Backend: | ||||
| """ | """ | ||||
| @@ -152,6 +152,9 @@ def _get_rank_helper(group, backend): | |||||
| Integer. The local rank id of the calling process. | Integer. The local rank id of the calling process. | ||||
| """ | """ | ||||
| rank_id = None | rank_id = None | ||||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||||
| rank_id = 0 | |||||
| return rank_id | |||||
| if backend == Backend.HCCL: | if backend == Backend.HCCL: | ||||
| if group == HCCL_WORLD_COMM_GROUP: | if group == HCCL_WORLD_COMM_GROUP: | ||||
| rank_id = hccl.get_rank_id() | rank_id = hccl.get_rank_id() | ||||
| @@ -211,6 +214,9 @@ def _get_size_helper(group, backend): | |||||
| Integer. The rank size of specified group. | Integer. The rank size of specified group. | ||||
| """ | """ | ||||
| size = None | size = None | ||||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||||
| size = 1 | |||||
| return size | |||||
| if backend == Backend.HCCL: | if backend == Backend.HCCL: | ||||
| if group == HCCL_WORLD_COMM_GROUP: | if group == HCCL_WORLD_COMM_GROUP: | ||||
| size = hccl.get_rank_size() | size = hccl.get_rank_size() | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Communication management API""" | """Communication management API""" | ||||
| import os | |||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | from mindspore.parallel._auto_parallel_context import auto_parallel_context | ||||
| from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ | from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ | ||||
| _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ | _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ | ||||
| @@ -28,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", | |||||
| DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP | DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP | ||||
| DEFAULT_BACKEND = Backend("hccl") | DEFAULT_BACKEND = Backend("hccl") | ||||
| MS_ROLE = os.getenv("MS_ROLE") | |||||
| def _get_group(group): | def _get_group(group): | ||||
| @@ -58,6 +60,8 @@ def init(backend_name="hccl"): | |||||
| TypeError: If backend name is not a string. | TypeError: If backend name is not a string. | ||||
| RuntimeError: If backend is invalid or distributed init fails. | RuntimeError: If backend is invalid or distributed init fails. | ||||
| """ | """ | ||||
| if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): | |||||
| return | |||||
| if not isinstance(backend_name, str): | if not isinstance(backend_name, str): | ||||
| raise TypeError("Backend name must be a string, but got {}".format(type(backend_name))) | raise TypeError("Backend name must be a string, but got {}".format(type(backend_name))) | ||||