Browse Source

workspace of comm op can be reused

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.2.0-rc1
zhoufeng 4 years ago
parent
commit
b7e5f956e5
7 changed files with 62 additions and 24 deletions
  1. +1
    -1
      graphengine
  2. +24
    -7
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc
  3. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h
  4. +1
    -0
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  5. +26
    -3
      mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc
  6. +7
    -10
      mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h
  7. +1
    -1
      tests/ut/cpp/stub/ge/ge_task_launch_stub.cc

+ 1
- 1
graphengine

@@ -1 +1 @@
Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b
Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f

+ 24
- 7
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc View File

@@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {

namespace mindspore {
namespace kernel {
void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) {
hcclKernelMap_.emplace(name, std::move(fun));
}

@@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
if (op_name_ == kReceive) {
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_;
MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_;
return false;
}
hccl_data_type_list_.emplace_back(iter->second);
@@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
return output_size_list_;
}

const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
if (!workspace_size_list_.empty() || hccl_data_type_list_.empty()) {
return workspace_size_list_;
}

workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0]));
return workspace_size_list_;
}

std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
if (hccl_type == kReceive) {
@@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
}

void *workspace_addr = nullptr;
if (task.workspace_size != 0) {
if (workspace.empty()) {
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty";
}
MS_EXCEPTION_IF_NULL(workspace.at(0));
workspace_addr = workspace.at(0)->addr;
}

results.emplace_back(std::make_shared<HcclTaskInfo>(
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size,
task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type,
group_, NeedDump()));
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr,
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
op_type_, data_type, group_, NeedDump()));
}

return results;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h View File

@@ -68,7 +68,7 @@ class HcclKernelFactory {

public:
static HcclKernelFactory &Get();
void Registe(const string &name, HcclKernelCreater &&fun);
void Register(const string &name, HcclKernelCreater &&fun);
static std::shared_ptr<HcclKernel> Get(const string &name);

private:
@@ -78,7 +78,7 @@ class HcclKernelFactory {
class _HcclKernelRegister {
public:
_HcclKernelRegister(const string &name, HcclKernelCreater &&fun) {
HcclKernelFactory::Get().Registe(name, std::move(fun));
HcclKernelFactory::Get().Register(name, std::move(fun));
}
~_HcclKernelRegister() = default;
};


+ 1
- 0
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
AssignCommunicationNodeInputMem(type, node);
AssignCommunicationNodeOutputMem(type, node);
AssignWorkSpaceMem(type, node);
}

void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {


+ 26
- 3
mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc View File

@@ -99,7 +99,7 @@ bool FinalizeHccl() {
if (ops_kernel_info_store != nullptr) {
auto ret = ops_kernel_info_store->Finalize();
if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory info store failed, ret = " << ret;
MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret;
return false;
}
}
@@ -107,7 +107,7 @@ bool FinalizeHccl() {
if (ops_kernel_builder != nullptr) {
auto ret = ops_kernel_builder->Finalize();
if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory builder failed, ret = " << ret;
MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret;
return false;
}
}
@@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
return true;
}

bool CalcOpRunningParam(const AnfNodePtr &node) { return true; }
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype;
auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
MS_EXCEPTION_IF_NULL(ge_node);
auto op = ge_node->GetOpDesc();
MS_EXCEPTION_IF_NULL(op);

MS_LOG(INFO) << "Start to call CalcOpRunningParam";
ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
return false;
}

auto workspace_sizes = op->GetWorkspaceBytes();
if (workspace_sizes.size() != 1) {
MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size();
}
int64_t workspace_size = workspace_sizes[0];
MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size;
ge_graph.reset();
return workspace_size;
}

void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); }



+ 7
- 10
mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h View File

@@ -23,21 +23,18 @@
#include "mindspore/core/ir/anf.h"
#include "hccl/hccl_types.h"

#define MS_API __attribute__((visibility("default")))

namespace mindspore::hccl {
struct MS_API HcclTaskInfo {
struct HcclTaskInfo {
std::string private_def;
int64_t workspace_size;
int64_t stream_num;
};

MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
MS_API bool FinalizeHccl();
MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
MS_API bool CalcOpRunningParam(const AnfNodePtr &node);
MS_API void *GetHcclOpsKernelInfoStore();
MS_API std::string GetHcclType(const AnfNodePtr &node);
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
bool FinalizeHccl();
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype);
void *GetHcclOpsKernelInfoStore();
std::string GetHcclType(const AnfNodePtr &node);
} // namespace mindspore::hccl
#undef MS_API
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H

+ 1
- 1
tests/ut/cpp/stub/ge/ge_task_launch_stub.cc View File

@@ -64,7 +64,7 @@ namespace hccl {
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
bool FinalizeHccl() { return true; }
bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; }
bool CalcOpRunningParam(const AnfNodePtr &) { return true; }
int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; }
void *GetHcclOpsKernelInfoStore() { return nullptr; }
std::string GetHcclType(const AnfNodePtr &) { return ""; }
} // namespace hccl


Loading…
Cancel
Save