From: @huaweib
Reviewed-by: @jjfeing, @kisnwang
Signed-off-by: @kisnwang
tags/v1.1.0
| @@ -18,6 +18,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel) | |||
| MS_REG_GPU_KERNEL_REGULAR(StreamRecv, KernelAttr(), RecvGpuKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel) | |||
| MS_REG_GPU_KERNEL_REGULAR(StreamSend, KernelAttr(), SendGpuKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -32,6 +32,8 @@ static std::map<std::string, std::string> kMsOpNameToHcomHcclType = { | |||
| {mindspore::kAllReduceOpName, mindspore::kHcomOpTypeAllReduce}, | |||
| {mindspore::kAllGatherOpName, mindspore::kHcomOpTypeAllGather}, | |||
| {mindspore::kBroadcastOpName, mindspore::kHcomOpTypeBroadcast}, | |||
| {mindspore::kHcomSendOpName, mindspore::kHcomOpTypeSend}, | |||
| {mindspore::kReceiveOpName, mindspore::kHcomOpTypeReceive}, | |||
| {mindspore::kReduceScatterOpName, mindspore::kHcomOpTypeReduceScatter}}; | |||
| std::string MsOpNameToHcomOpType(const std::string &ms_op_type) { | |||
| auto iter = kMsOpNameToHcomHcclType.find(ms_op_type); | |||
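For reference, a minimal standalone sketch of the lookup pattern this hunk extends: the new Send/Receive entries let MsOpNameToHcomOpType resolve the HCCL op type by name. The map entries are copied from the diff; the MindSpore constants and headers are replaced with plain strings here, so this is an illustration rather than the in-tree code.

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Simplified stand-in for kMsOpNameToHcomHcclType from the hunk above.
static const std::map<std::string, std::string> kMsOpNameToHcomHcclType = {
    {"AllReduce", "HcomAllReduce"}, {"AllGather", "HcomAllGather"},
    {"Broadcast", "HcomBroadcast"}, {"Send", "HcomSend"},
    {"Receive", "HcomReceive"},     {"ReduceScatter", "HcomReduceScatter"}};

// Mirrors MsOpNameToHcomOpType: fail loudly when an op is not mapped.
std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
  auto iter = kMsOpNameToHcomHcclType.find(ms_op_type);
  if (iter == kMsOpNameToHcomHcclType.end()) {
    throw std::runtime_error("Unknown MindSpore op type: " + ms_op_type);
  }
  return iter->second;
}

int main() {
  std::cout << MsOpNameToHcomOpType("Receive") << std::endl;  // HcomReceive
  return 0;
}
```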
| @@ -80,7 +82,12 @@ HcclKernel::~HcclKernel() { | |||
| bool HcclKernel::Init(const AnfNodePtr &anf_node) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| op_name_ = AnfAlgo::GetCNodeName(anf_node); | |||
| if (op_name_ == kReceive) { | |||
| if (!HcomUtil::GetHcomReceiveType(anf_node, &receive_type_)) { | |||
| MS_LOG(ERROR) << "GetHcomReceiveType fail!"; | |||
| return false; | |||
| } | |||
| } | |||
| if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) { | |||
| MS_LOG(ERROR) << "GetKernelInputShape fail!"; | |||
| return false; | |||
| @@ -89,13 +96,27 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) { | |||
| MS_LOG(ERROR) << "GetKernelOutputShape fail!"; | |||
| return false; | |||
| } | |||
| if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) { | |||
| if (op_name_ == kReceive) { | |||
| auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); | |||
| if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { | |||
| MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_; | |||
| return false; | |||
| } | |||
| hccl_data_type_list_.emplace_back(iter->second); | |||
| } else if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) { | |||
| MS_LOG(ERROR) << "GetHcomDataType fail!"; | |||
| return false; | |||
| } | |||
| if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) { | |||
| MS_LOG(ERROR) << "GetHcomCount fail!"; | |||
| return false; | |||
| if (op_name_ == kReceive) { | |||
| if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_output_shape_list_, &hccl_count_)) { | |||
| MS_LOG(ERROR) << "GetHcomCount fail!"; | |||
| return false; | |||
| } | |||
| } else { | |||
| if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) { | |||
| MS_LOG(ERROR) << "GetHcomCount fail!"; | |||
| return false; | |||
| } | |||
| } | |||
| if (op_name_ == kAllReduce || op_name_ == kReduceScatter) { | |||
| if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) { | |||
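The Init changes above handle the fact that Receive has no data inputs: the element type comes from the node's dtype attribute via CONST_OP_HCOM_DATA_TYPE_MAP, and the transfer size is computed from the output shapes instead of the input shapes. Below is a hedged, self-contained model of that branching; TypeId, HcclDataType, and the map are simplified placeholders, not the real MindSpore/HCCL definitions.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Placeholders for the MindSpore/HCCL enums used in the hunk above.
enum class TypeId { kFloat16, kFloat32, kInt32 };
enum class HcclDataType { HCCL_DATA_TYPE_FP16, HCCL_DATA_TYPE_FP32, HCCL_DATA_TYPE_INT32 };

static const std::map<TypeId, HcclDataType> kConstOpHcomDataTypeMap = {
    {TypeId::kFloat16, HcclDataType::HCCL_DATA_TYPE_FP16},
    {TypeId::kFloat32, HcclDataType::HCCL_DATA_TYPE_FP32},
    {TypeId::kInt32, HcclDataType::HCCL_DATA_TYPE_INT32}};

// Receive has no inputs: take the dtype from the attribute and size the
// buffer from the output shapes; every other HCCL op keeps using its inputs.
bool ResolveDataTypeAndShapes(const std::string &op_name, TypeId receive_type,
                              const std::vector<std::vector<size_t>> &input_shapes,
                              const std::vector<std::vector<size_t>> &output_shapes,
                              HcclDataType *data_type,
                              const std::vector<std::vector<size_t>> **count_shapes) {
  if (op_name == "Receive") {
    auto iter = kConstOpHcomDataTypeMap.find(receive_type);
    if (iter == kConstOpHcomDataTypeMap.end()) {
      return false;  // unsupported receive dtype
    }
    *data_type = iter->second;
    *count_shapes = &output_shapes;
  } else {
    *data_type = HcclDataType::HCCL_DATA_TYPE_FP32;  // stand-in for GetHcomDataType
    *count_shapes = &input_shapes;
  }
  return true;
}

int main() {
  HcclDataType dt;
  const std::vector<std::vector<size_t>> *shapes = nullptr;
  std::vector<std::vector<size_t>> in = {}, out = {{2, 3}};
  bool ok = ResolveDataTypeAndShapes("Receive", TypeId::kFloat16, in, out, &dt, &shapes);
  std::cout << std::boolalpha << ok << " shape_count=" << shapes->size() << std::endl;
  return 0;
}
```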
| @@ -160,17 +181,24 @@ const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return wor | |||
| std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, uint32_t stream_id) { | |||
| if (inputs.empty() || outputs.empty()) { | |||
| std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); | |||
| if (hccl_type == kReceive) { | |||
| if (outputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Outputs is empty"; | |||
| } | |||
| } else if (inputs.empty() || outputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "Inputs or outputs is empty"; | |||
| } | |||
| stream_id_ = stream_id; | |||
| MS_EXCEPTION_IF_NULL(inputs.at(0)); | |||
| auto input_data_addr = inputs.at(0)->addr; | |||
| void *input_data_addr = nullptr; | |||
| if (hccl_type != kReceive) { | |||
| MS_EXCEPTION_IF_NULL(inputs.at(0)); | |||
| input_data_addr = inputs.at(0)->addr; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(outputs.at(0)); | |||
| auto output_data_addr = outputs.at(0)->addr; | |||
| std::vector<uint8_t> private_def; | |||
| HcclDataType data_type = hccl_data_type_list_[0]; | |||
| std::vector<hccl::HcclTaskInfo> task_info; | |||
| bool ret = hccl::GenTask(anf_node_, data_type, &task_info); | |||
| if (!ret) { | |||
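GenTask previously dereferenced inputs.at(0) unconditionally; for Receive there is no input tensor, so only the output is validated and the input address stays null. A minimal sketch of that guard, using a simplified Address struct in place of MindSpore's AddressPtr and an exception in place of MS_LOG(EXCEPTION):

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for mindspore::kernel::Address / AddressPtr.
struct Address {
  void *addr = nullptr;
  std::size_t size = 0;
};
using AddressPtr = std::shared_ptr<Address>;

void *ResolveInputAddr(const std::string &hccl_type, const std::vector<AddressPtr> &inputs,
                       const std::vector<AddressPtr> &outputs) {
  if (hccl_type == "Receive") {
    if (outputs.empty()) {
      throw std::runtime_error("Outputs is empty");
    }
  } else if (inputs.empty() || outputs.empty()) {
    throw std::runtime_error("Inputs or outputs is empty");
  }
  // Receive launches with no input buffer; leave the address null.
  void *input_data_addr = nullptr;
  if (hccl_type != "Receive") {
    if (inputs.at(0) == nullptr) {
      throw std::runtime_error("Input address is null");
    }
    input_data_addr = inputs.at(0)->addr;
  }
  return input_data_addr;
}

int main() {
  std::vector<AddressPtr> outputs = {std::make_shared<Address>()};
  // Receive: empty inputs are legal, the resolved input address stays null.
  return ResolveInputAddr("Receive", {}, outputs) == nullptr ? 0 : 1;
}
```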
| @@ -51,6 +51,7 @@ class HcclKernel : public AscendKernelMod { | |||
| uint64_t hccl_count_; | |||
| HcclReduceOp op_type_; | |||
| uint32_t root_id_; | |||
| int64_t receive_type_; | |||
| mutable std::vector<size_t> input_size_list_; | |||
| mutable std::vector<size_t> output_size_list_; | |||
| mutable std::vector<size_t> workspace_size_list_; | |||
| @@ -33,6 +33,9 @@ std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { | |||
| if (parallel_context_instance->enable_parallel_optimizer() && op_name == kBroadcast) { | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| if (op_name == kReceive || op_name == kHcomSend) { | |||
| return kOpFormat_DEFAULT; | |||
| } | |||
| auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); | |||
| if (op_name != kReduceScatter && op_name != kAllGatherOpName) { | |||
| return format; | |||
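GetKernelFormat now short-circuits to the default format for Send/Receive, in addition to the existing Broadcast-under-parallel-optimizer case, before falling back to the producer node's output format. A compact sketch of that decision order, with plain strings standing in for MindSpore's format constants:

```cpp
#include <iostream>
#include <string>

// Decision order mirrored from GetKernelFormat: point-to-point ops always use
// the default format; everything else inherits the producer's output format.
std::string SelectKernelFormat(const std::string &op_name, bool enable_parallel_optimizer,
                               const std::string &prev_output_format) {
  const std::string kOpFormatDefault = "DefaultFormat";  // stand-in for kOpFormat_DEFAULT
  if (enable_parallel_optimizer && op_name == "Broadcast") {
    return kOpFormatDefault;
  }
  if (op_name == "Receive" || op_name == "Send") {
    return kOpFormatDefault;
  }
  return prev_output_format;
}

int main() {
  std::cout << SelectKernelFormat("Send", false, "NC1HWC0") << std::endl;  // DefaultFormat
  return 0;
}
```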
| @@ -52,7 +55,8 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string op_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter) { | |||
| if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter && | |||
| op_name != kHcomSend && op_name != kReceive) { | |||
| MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]"; | |||
| return; | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/hccl/hcom_receive.h" | |||
| #include <memory> | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| bool HcomReceiveKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, | |||
| const std::vector<AddressPtr> & /*workspace*/, | |||
| const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) { | |||
| MS_LOG(INFO) << "HcomReceive launch"; | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/hccl/hccl_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class HcomReceiveKernel : public HcclKernel { | |||
| public: | |||
| HcomReceiveKernel() = default; | |||
| ~HcomReceiveKernel() override = default; | |||
| /* Inherit from KernelMod */ | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override; | |||
| private: | |||
| }; | |||
| MS_HCCL_REG_KERNEL(Receive, HcomReceiveKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_RECEIVE_H_ | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/hccl/hcom_send.h" | |||
| #include <memory> | |||
| #include "utils/ms_context.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| bool HcomSendKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, const std::vector<AddressPtr> & /*workspace*/, | |||
| const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) { | |||
| MS_LOG(INFO) << "HcomSend launch"; | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/hccl/hccl_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class HcomSendKernel : public HcclKernel { | |||
| public: | |||
| HcomSendKernel() = default; | |||
| ~HcomSendKernel() override = default; | |||
| /* Inherit from KernelMod */ | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override; | |||
| private: | |||
| }; | |||
| MS_HCCL_REG_KERNEL(Send, HcomSendKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_SEND_H_ | |||
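Both new kernels are exposed through MS_HCCL_REG_KERNEL(Send, ...) and MS_HCCL_REG_KERNEL(Receive, ...). The macro body is not part of this diff; the sketch below only models the usual pattern such registration macros follow (a static registrar inserting a creator into a name-keyed factory) as an illustration, not as MindSpore's actual implementation.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Minimal kernel interface standing in for HcclKernel.
struct KernelBase {
  virtual ~KernelBase() = default;
  virtual bool Launch() = 0;
};

// Name-keyed factory; registration macros typically populate it at static-init time.
class KernelFactory {
 public:
  static KernelFactory &Instance() {
    static KernelFactory factory;
    return factory;
  }
  void Register(const std::string &name, std::function<std::unique_ptr<KernelBase>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::unique_ptr<KernelBase> Create(const std::string &name) const {
    auto iter = creators_.find(name);
    return iter == creators_.end() ? nullptr : iter->second();
  }

 private:
  std::map<std::string, std::function<std::unique_ptr<KernelBase>()>> creators_;
};

// Illustrative analogue of MS_HCCL_REG_KERNEL: one static registrar per kernel class.
#define REG_KERNEL(NAME, CLASS)                                                           \
  static const bool g_##CLASS##_registered = [] {                                         \
    KernelFactory::Instance().Register(#NAME, [] { return std::make_unique<CLASS>(); });  \
    return true;                                                                          \
  }()

struct SendKernelSketch : KernelBase {
  bool Launch() override { std::cout << "Send launch\n"; return true; }
};
REG_KERNEL(Send, SendKernelSketch);

int main() {
  auto kernel = KernelFactory::Instance().Create("Send");
  return kernel && kernel->Launch() ? 0 : 1;
}
```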
| @@ -102,8 +102,11 @@ bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataTyp | |||
| uint64_t block_size; | |||
| size_t input_size; | |||
| uint32_t type_size = 4; | |||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { | |||
| size_t size = AnfAlgo::GetInputTensorNum(anf_node); | |||
| if (AnfAlgo::GetCNodeName(anf_node) == kReceiveOpName) { | |||
| size = AnfAlgo::GetOutputTensorNum(anf_node); | |||
| } | |||
| for (size_t i = 0; i < size; ++i) { | |||
| if (!GetHcomTypeSize(data_type_list[i], &type_size)) { | |||
| return false; | |||
| } | |||
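GetHcomCount multiplies out the shape of each tensor (inputs normally, outputs for Receive as of this hunk) and converts the total byte size back into the element count handed to HCCL. A hedged arithmetic sketch of the core computation; the alignment and fusion handling in the real HcomUtil is omitted.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Element count across all tensors: sum over tensors of product(dims).
// The real GetHcomCount also validates the type size and block alignment;
// this sketch keeps only the core count arithmetic.
bool GetCount(const std::vector<std::vector<size_t>> &shapes, uint32_t type_size,
              uint64_t *total_count) {
  if (type_size == 0) return false;
  uint64_t total_bytes = 0;
  for (const auto &shape : shapes) {
    uint64_t elements = 1;
    for (size_t dim : shape) elements *= dim;
    total_bytes += elements * type_size;
  }
  *total_count = total_bytes / type_size;
  return true;
}

int main() {
  uint64_t count = 0;
  // e.g. a Receive with a single [2, 3] float32 output: 6 elements.
  GetCount({{2, 3}}, sizeof(float), &count);
  std::cout << count << std::endl;  // 6
  return 0;
}
```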
| @@ -188,6 +191,20 @@ bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) { | |||
| return true; | |||
| } | |||
| bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, int64_t *receive_type) { | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| MS_EXCEPTION_IF_NULL(receive_type); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| if (primitive->GetAttr("dtype") != nullptr) { | |||
| *receive_type = (int64_t)(GetValue<NumberPtr>(primitive->GetAttr("dtype"))->type_id()); | |||
| } else { | |||
| MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRTAG_INDEX fail, not support!"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) { | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
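GetHcomReceiveType only succeeds when the Receive node carries a "dtype" attribute; otherwise it reports an error and Init aborts. A stripped-down sketch of that contract, with the primitive's attribute table modeled as a plain map (the real attr holds a Number value whose type_id() is extracted):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Attribute table stand-in: maps attr name -> TypeId stored as int64_t.
using AttrMap = std::map<std::string, int64_t>;

bool GetReceiveType(const AttrMap &attrs, int64_t *receive_type) {
  if (receive_type == nullptr) return false;
  auto iter = attrs.find("dtype");
  if (iter == attrs.end()) {
    std::cerr << "GetHcomReceiveType fail: Receive node has no dtype attr" << std::endl;
    return false;
  }
  *receive_type = iter->second;  // real code: GetValue<NumberPtr>(attr)->type_id()
  return true;
}

int main() {
  int64_t receive_type = 0;
  AttrMap attrs = {{"dtype", 43}};  // 43 is an arbitrary TypeId value for illustration
  std::cout << std::boolalpha << GetReceiveType(attrs, &receive_type) << " " << receive_type << std::endl;
  return 0;
}
```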
| @@ -34,6 +34,8 @@ using std::vector; | |||
| constexpr auto kAllGather = "AllGather"; | |||
| constexpr auto kAllReduce = "AllReduce"; | |||
| constexpr auto kBroadcast = "Broadcast"; | |||
| constexpr auto kHcomSend = "Send"; | |||
| constexpr auto kReceive = "Receive"; | |||
| constexpr auto kReduceScatter = "ReduceScatter"; | |||
| /* Correspondence between data_type and hcom data type in Ascend */ | |||
| @@ -64,6 +66,7 @@ class HcomUtil { | |||
| static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type); | |||
| static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id); | |||
| static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group); | |||
| static bool GetHcomReceiveType(const AnfNodePtr &anf_node, int64_t *receive_type); | |||
| }; | |||
| } // namespace mindspore | |||
| @@ -39,7 +39,7 @@ class RecvKernel : public RtKernel { | |||
| uint32_t event_id_; | |||
| }; | |||
| MS_REG_RTKERNEL(recv, RecvKernel); | |||
| MS_REG_RTKERNEL(streamrecv, RecvKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -37,7 +37,7 @@ class SendKernel : public RtKernel { | |||
| uint32_t event_id_; | |||
| }; | |||
| MS_REG_RTKERNEL(send, SendKernel); | |||
| MS_REG_RTKERNEL(streamsend, SendKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -29,7 +29,7 @@ MemReuseChecker &MemReuseChecker::GetInstance() { | |||
| void MemReuseChecker::CheckSignalOps(const CNodePtr &c_node) { | |||
| std::string node_name = AnfAlgo::GetCNodeName(c_node); | |||
| if (node_name == kSend || node_name == kRecv) { | |||
| if (node_name == kSendOpName || node_name == kRecvOpName) { | |||
| MS_LOG(INFO) << "MemReuseChecker check op_name of Send or Send"; | |||
| // get op's info && check | |||
| MS_LOG(INFO) << "op: " << node_name << " in_num: " << AnfAlgo::GetInputTensorNum(c_node) | |||
| @@ -29,8 +29,6 @@ | |||
| #include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" | |||
| namespace mindspore { | |||
| namespace memreuse { | |||
| constexpr auto kSend = "Send"; | |||
| constexpr auto kRecv = "Recv"; | |||
| constexpr auto kSplitC = '/'; | |||
| class MemReuseChecker { | |||
| public: | |||
| @@ -1132,7 +1132,7 @@ bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { | |||
| } | |||
| auto kernel_name = AnfAlgo::GetCNodeName(node); | |||
| if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName || | |||
| kernel_name == kReduceScatterOpName) { | |||
| kernel_name == kReduceScatterOpName || kernel_name == kHcomSendOpName || kernel_name == kReceiveOpName) { | |||
| return true; | |||
| } | |||
| return false; | |||
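IsCommunicationOp now treats Send and Receive as communication kernels. The chained comparison in the hunk works as-is; the same membership check can also be expressed as a set lookup, sketched below purely as an illustration (the op names are copied from the diff, the set-based form is not the code in this PR):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

bool IsCommunicationOpName(const std::string &kernel_name) {
  static const std::unordered_set<std::string> kCommOps = {
      "AllReduce", "AllGather", "Broadcast", "ReduceScatter", "Send", "Receive"};
  return kCommOps.count(kernel_name) != 0;
}

int main() {
  std::cout << std::boolalpha << IsCommunicationOpName("Receive") << std::endl;  // true
  return 0;
}
```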
| @@ -28,6 +28,8 @@ | |||
| #include "mindspore/core/base/core_ops.h" | |||
| #include "transform/graph_ir/util.h" | |||
| static constexpr char kGeOpNameHcclSend[] = "HcomSend"; | |||
| static constexpr char kGeOpNameHcclReceive[] = "HcomReceive"; | |||
| static constexpr char kGeOpNameHcclAllRudece[] = "HcomAllReduce"; | |||
| static constexpr char kGeOpNameHcclAllGather[] = "HcomAllGather"; | |||
| static constexpr char kGeOpNameHcclBroadcast[] = "HcomBroadcast"; | |||
| @@ -63,6 +65,18 @@ struct IsString<std::string> { | |||
| static constexpr bool value = true; | |||
| }; | |||
| template <class T> | |||
| struct IsVector { | |||
| // cppcheck-suppress unusedStructMember | |||
| static constexpr bool value = false; | |||
| }; | |||
| template <> | |||
| struct IsVector<std::vector<int64_t>> { | |||
| // cppcheck-suppress unusedStructMember | |||
| static constexpr bool value = true; | |||
| }; | |||
| namespace mindspore::hccl { | |||
| template <class T> | |||
| static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const std::string &anf_attr_name, | |||
| @@ -78,6 +92,8 @@ static T ConvertAttr(const CNodePtr &cnode, const ge::OpDescPtr &ge_op, const st | |||
| auto attr = AnfAlgo::GetNodeAttr<T>(cnode, anf_attr_name); | |||
| if constexpr (IsString<T>::value) { | |||
| ret = ge::AttrUtils::SetStr(*ge_op, ge_attr_name, attr); | |||
| } else if constexpr (IsVector<T>::value) { | |||
| ret = ge::AttrUtils::SetListInt(*ge_op, ge_attr_name, attr); | |||
| } else { | |||
| ret = ge::AttrUtils::SetInt(*ge_op, ge_attr_name, attr); | |||
| } | |||
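ConvertAttr picks the GE setter at compile time: SetStr for strings, the new SetListInt branch for std::vector<int64_t> (needed for the Receive shape attribute), and SetInt otherwise. The sketch below reproduces that trait-plus-if-constexpr dispatch against a dummy attribute sink, since ge::AttrUtils itself is not available outside the GE headers:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

template <class T>
struct IsString { static constexpr bool value = false; };
template <>
struct IsString<std::string> { static constexpr bool value = true; };

template <class T>
struct IsVector { static constexpr bool value = false; };
template <>
struct IsVector<std::vector<int64_t>> { static constexpr bool value = true; };

// Dummy sink standing in for ge::AttrUtils::SetStr / SetListInt / SetInt.
struct AttrSink {
  void SetStr(const std::string &name, const std::string &) { std::cout << name << ": str\n"; }
  void SetListInt(const std::string &name, const std::vector<int64_t> &) { std::cout << name << ": list\n"; }
  void SetInt(const std::string &name, int64_t) { std::cout << name << ": int\n"; }
};

// Same compile-time dispatch shape as ConvertAttr in the hunk above.
template <class T>
void SetAttr(AttrSink *sink, const std::string &name, const T &attr) {
  if constexpr (IsString<T>::value) {
    sink->SetStr(name, attr);
  } else if constexpr (IsVector<T>::value) {
    sink->SetListInt(name, attr);
  } else {
    sink->SetInt(name, attr);
  }
}

int main() {
  AttrSink sink;
  SetAttr(&sink, "group", std::string("hccl_world_group"));
  SetAttr(&sink, "shape", std::vector<int64_t>{2, 3});
  SetAttr(&sink, "src_rank", int64_t{0});
  return 0;
}
```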
| @@ -99,6 +115,10 @@ std::string GetGeNodeName(const CNodePtr &cnode) { | |||
| return kGeOpNameHcclBroadcast; | |||
| } else if (IsPrimitiveCNode(cnode, prim::kPrimReduceScatter)) { | |||
| return kGeOpNameHcclReduceScatter; | |||
| } else if (IsPrimitiveCNode(cnode, prim::kPrimSend)) { | |||
| return kGeOpNameHcclSend; | |||
| } else if (IsPrimitiveCNode(cnode, prim::kPrimReceive)) { | |||
| return kGeOpNameHcclReceive; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unknown hccl node type " << cnode->DebugString(); | |||
| @@ -133,6 +153,10 @@ std::tuple<ge::NodePtr, ge::ComputeGraphPtr> GenerateStubGeNode(const AnfNodePtr | |||
| // set node attr | |||
| (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrRankSize, ge::HCOM_ATTR_RANK_SIZE); | |||
| (void)ConvertAttr<std::string>(cnode, op_desc, kAttrGroup, ge::HCOM_ATTR_GROUP); | |||
| (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrSrcRank, ge::HCOM_ATTR_SRC_RANK); | |||
| (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrDestRank, ge::HCOM_ATTR_DEST_RANK); | |||
| (void)ConvertAttr<int64_t>(cnode, op_desc, kAttrSrTag, ge::HCOM_ATTR_SR_TAG); | |||
| (void)ConvertAttr<std::vector<int64_t>>(cnode, op_desc, kAttrShape, ge::HCOM_ATTR_SHAPE); | |||
| ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>(kStubDataStructureName); | |||
| MS_EXCEPTION_IF_NULL(ge_graph); | |||
| @@ -57,6 +57,8 @@ constexpr auto kAllReduceOpName = "AllReduce"; | |||
| constexpr auto kAllGatherOpName = "AllGather"; | |||
| constexpr auto kHostAllGatherOpName = "HostAllGather"; | |||
| constexpr auto kBroadcastOpName = "Broadcast"; | |||
| constexpr auto kReceiveOpName = "Receive"; | |||
| constexpr auto kHcomSendOpName = "Send"; | |||
| constexpr auto kReduceScatterOpName = "ReduceScatter"; | |||
| constexpr auto kHostReduceScatterOpName = "HostReduceScatter"; | |||
| constexpr auto kMemCpyAsyncOpName = "memcpy_async"; | |||
| @@ -142,8 +144,8 @@ constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; | |||
| constexpr auto kStreamSwitchOpName = "StreamSwitch"; | |||
| constexpr auto kStreamActiveOpName = "StreamActive"; | |||
| constexpr auto kAssignAddOpName = "AssignAdd"; | |||
| constexpr auto kSendOpName = "Send"; | |||
| constexpr auto kRecvOpName = "Recv"; | |||
| constexpr auto kSendOpName = "StreamSend"; | |||
| constexpr auto kRecvOpName = "StreamRecv"; | |||
| constexpr auto kReluV2OpName = "ReLUV2"; | |||
| constexpr auto kReluGradV2OpName = "ReluGradV2"; | |||
| constexpr auto kAddNOpName = "AddN"; | |||
| @@ -248,6 +250,8 @@ constexpr auto kBroadcastToOpName = "BroadcastTo"; | |||
| constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; | |||
| constexpr auto kHcomOpTypeAllGather = "HcomAllGather"; | |||
| constexpr auto kHcomOpTypeBroadcast = "HcomBroadcast"; | |||
| constexpr auto kHcomOpTypeSend = "HcomSend"; | |||
| constexpr auto kHcomOpTypeReceive = "HcomReceive"; | |||
| constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter"; | |||
| // attr key name | |||
| @@ -292,6 +296,9 @@ constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active | |||
| constexpr auto kAttrFusion = "fusion"; | |||
| constexpr auto kAttrGroup = "group"; | |||
| constexpr auto kAttrOp = "op"; | |||
| constexpr auto kAttrDestRank = "dest_rank"; | |||
| constexpr auto kAttrSrcRank = "src_rank"; | |||
| constexpr auto kAttrSrTag = "sr_tag"; | |||
| constexpr auto kAttrRootRank = "root_rank"; | |||
| constexpr auto kAttrIsTraining = "is_training"; | |||
| constexpr auto kAttrFusionId = "fusion_id"; | |||
| @@ -193,6 +193,7 @@ inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD"); | |||
| inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); | |||
| inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv"); | |||
| inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset"); | |||
| inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | |||
| inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive"); | |||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); | |||