Browse Source

workspace of comm op can be reused

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
tags/v1.2.0-rc1
zhoufeng 4 years ago
parent
commit
b7e5f956e5
7 changed files with 62 additions and 24 deletions
  1. +1
    -1
      graphengine
  2. +24
    -7
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc
  3. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h
  4. +1
    -0
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  5. +26
    -3
      mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc
  6. +7
    -10
      mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h
  7. +1
    -1
      tests/ut/cpp/stub/ge/ge_task_launch_stub.cc

+ 1
- 1
graphengine

@@ -1 +1 @@
Subproject commit f65be61197ed36dfc9dc10b91b58bf93835fa27b
Subproject commit 40e5c42a12c4daa1530e8db9d006d5b3be5b378f

+ 24
- 7
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc View File

@@ -46,7 +46,7 @@ std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {


namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
void HcclKernelFactory::Register(const std::string &name, HcclKernelCreater &&fun) {
hcclKernelMap_.emplace(name, std::move(fun)); hcclKernelMap_.emplace(name, std::move(fun));
} }


@@ -99,7 +99,7 @@ bool HcclKernel::Init(const AnfNodePtr &anf_node) {
if (op_name_ == kReceive) { if (op_name_ == kReceive) {
auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_); auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(receive_type_);
if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
MS_LOG(ERROR) << "HcomDataType cann't support Current Ascend Data Type : " << receive_type_;
MS_LOG(ERROR) << "HcomDataType cannot support Current Ascend Data Type : " << receive_type_;
return false; return false;
} }
hccl_data_type_list_.emplace_back(iter->second); hccl_data_type_list_.emplace_back(iter->second);
@@ -180,9 +180,17 @@ const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
return output_size_list_; return output_size_list_;
} }


const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
// Returns the workspace size list of this HCCL kernel.
//
// The list is filled on demand: when it is still empty and at least one HCCL
// data type is known, a single entry is computed through
// hccl::CalcWorkspaceSize() using the first data type, and stored in
// workspace_size_list_ so that later calls return the stored value.
// If no data type is available the (empty) list is returned untouched.
//
// NOTE(review): this is a const method that appends to workspace_size_list_,
// so the member is presumably declared `mutable` in the header — confirm.
const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const {
  const bool need_calc = workspace_size_list_.empty() && !hccl_data_type_list_.empty();
  if (need_calc) {
    // Only the first data type is used to size the workspace.
    workspace_size_list_.emplace_back(hccl::CalcWorkspaceSize(anf_node_, hccl_data_type_list_[0]));
  }
  return workspace_size_list_;
}


std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) { const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
if (hccl_type == kReceive) { if (hccl_type == kReceive) {
@@ -221,10 +229,19 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret; MS_LOG(EXCEPTION) << "Set data memcpy_s failed, ret = " << sec_ret;
} }


void *workspace_addr = nullptr;
if (task.workspace_size != 0) {
if (workspace.empty()) {
MS_LOG(EXCEPTION) << "Workspace size list of " << anf_node_->DebugString() << " is empty";
}
MS_EXCEPTION_IF_NULL(workspace.at(0));
workspace_addr = workspace.at(0)->addr;
}

results.emplace_back(std::make_shared<HcclTaskInfo>( results.emplace_back(std::make_shared<HcclTaskInfo>(
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, task.workspace_size,
task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_, op_type_, data_type,
group_, NeedDump()));
kernel_name_, stream_id, hccl::GetHcclType(anf_node_), input_data_addr, output_data_addr, workspace_addr,
task.workspace_size, task.stream_num, private_def, hccl::GetHcclOpsKernelInfoStore(), hccl_count_, root_id_,
op_type_, data_type, group_, NeedDump()));
} }


return results; return results;


+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h View File

@@ -68,7 +68,7 @@ class HcclKernelFactory {


public: public:
static HcclKernelFactory &Get(); static HcclKernelFactory &Get();
void Registe(const string &name, HcclKernelCreater &&fun);
void Register(const string &name, HcclKernelCreater &&fun);
static std::shared_ptr<HcclKernel> Get(const string &name); static std::shared_ptr<HcclKernel> Get(const string &name);


private: private:
@@ -78,7 +78,7 @@ class HcclKernelFactory {
class _HcclKernelRegister { class _HcclKernelRegister {
public: public:
_HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) {
HcclKernelFactory::Get().Registe(name, std::move(fun));
HcclKernelFactory::Get().Register(name, std::move(fun));
} }
~_HcclKernelRegister() = default; ~_HcclKernelRegister() = default;
}; };


+ 1
- 0
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -433,6 +433,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) { void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
AssignCommunicationNodeInputMem(type, node); AssignCommunicationNodeInputMem(type, node);
AssignCommunicationNodeOutputMem(type, node); AssignCommunicationNodeOutputMem(type, node);
AssignWorkSpaceMem(type, node);
} }


void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) { void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {


+ 26
- 3
mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.cc View File

@@ -99,7 +99,7 @@ bool FinalizeHccl() {
if (ops_kernel_info_store != nullptr) { if (ops_kernel_info_store != nullptr) {
auto ret = ops_kernel_info_store->Finalize(); auto ret = ops_kernel_info_store->Finalize();
if (ret != ge::SUCCESS) { if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory info store failed, ret = " << ret;
MS_LOG(ERROR) << "Destroy info store failed, ret = " << ret;
return false; return false;
} }
} }
@@ -107,7 +107,7 @@ bool FinalizeHccl() {
if (ops_kernel_builder != nullptr) { if (ops_kernel_builder != nullptr) {
auto ret = ops_kernel_builder->Finalize(); auto ret = ops_kernel_builder->Finalize();
if (ret != ge::SUCCESS) { if (ret != ge::SUCCESS) {
MS_LOG(ERROR) << "Destory builder failed, ret = " << ret;
MS_LOG(ERROR) << "Destroy builder failed, ret = " << ret;
return false; return false;
} }
} }
@@ -151,7 +151,30 @@ bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTask
return true; return true;
} }


bool CalcOpRunningParam(const AnfNodePtr &node) { return true; }
// Computes the workspace size (in bytes) required by an HCCL node.
//
// A stub GE node is generated for `node` with the given `datatype`, the HCCL
// ops kernel builder is asked to calculate its running parameters, and the
// single workspace size it reports on the op descriptor is returned.
//
// @param node      The communication node whose workspace size is requested.
// @param datatype  HCCL data type used when generating the stub GE node.
// @return          The workspace size reported by the op descriptor.
// @throws          Raises through MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL when
//                  the builder, node, stub GE node or op descriptor is null,
//                  when CalcOpRunningParam fails, or when the op descriptor
//                  does not report exactly one workspace size.
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype) {
  MS_EXCEPTION_IF_NULL(ops_kernel_builder);
  // `node` is dereferenced for logging below, so validate it up front.
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "Start calc workspace size for hccl node " << node->DebugString() << " ,dtype is " << datatype;
  auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
  MS_EXCEPTION_IF_NULL(ge_node);
  auto op = ge_node->GetOpDesc();
  MS_EXCEPTION_IF_NULL(op);

  MS_LOG(INFO) << "Start to call CalcOpRunningParam";
  ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
  if (ret != ge::SUCCESS) {
    // Previously this path did `return false;`, which silently reported a
    // workspace size of 0. Fail loudly instead, matching the error handling
    // used for the unexpected-workspace-count case below.
    MS_LOG(EXCEPTION) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
  }

  const auto workspace_sizes = op->GetWorkspaceBytes();
  if (workspace_sizes.size() != 1) {
    MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_sizes.size();
  }
  const int64_t workspace_size = workspace_sizes[0];
  MS_LOG(INFO) << "Node " << node->DebugString() << " workspace size is " << workspace_size;
  // Release the stub graph explicitly once the size has been read.
  ge_graph.reset();
  return workspace_size;
}


void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); } void *GetHcclOpsKernelInfoStore() { return ops_kernel_info_store.get(); }




+ 7
- 10
mindspore/ccsrc/runtime/hccl_adapter/hccl_adapter.h View File

@@ -23,21 +23,18 @@
#include "mindspore/core/ir/anf.h" #include "mindspore/core/ir/anf.h"
#include "hccl/hccl_types.h" #include "hccl/hccl_types.h"


#define MS_API __attribute__((visibility("default")))

namespace mindspore::hccl { namespace mindspore::hccl {
struct MS_API HcclTaskInfo {
struct HcclTaskInfo {
std::string private_def; std::string private_def;
int64_t workspace_size; int64_t workspace_size;
int64_t stream_num; int64_t stream_num;
}; };


MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
MS_API bool FinalizeHccl();
MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
MS_API bool CalcOpRunningParam(const AnfNodePtr &node);
MS_API void *GetHcclOpsKernelInfoStore();
MS_API std::string GetHcclType(const AnfNodePtr &node);
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
bool FinalizeHccl();
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
int64_t CalcWorkspaceSize(const AnfNodePtr &node, HcclDataType datatype);
void *GetHcclOpsKernelInfoStore();
std::string GetHcclType(const AnfNodePtr &node);
} // namespace mindspore::hccl } // namespace mindspore::hccl
#undef MS_API
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H #endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H

+ 1
- 1
tests/ut/cpp/stub/ge/ge_task_launch_stub.cc View File

@@ -64,7 +64,7 @@ namespace hccl {
bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; } bool InitHccl(uint32_t, std::string_view, std::string_view) { return true; }
bool FinalizeHccl() { return true; } bool FinalizeHccl() { return true; }
bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; } bool GenTask(const AnfNodePtr &, HcclDataType, std::vector<HcclTaskInfo> *) { return true; }
bool CalcOpRunningParam(const AnfNodePtr &) { return true; }
int64_t CalcWorkspaceSize(const AnfNodePtr &, HcclDataType) { return 0; }
void *GetHcclOpsKernelInfoStore() { return nullptr; } void *GetHcclOpsKernelInfoStore() { return nullptr; }
std::string GetHcclType(const AnfNodePtr &) { return ""; } std::string GetHcclType(const AnfNodePtr &) { return ""; }
} // namespace hccl } // namespace hccl


Loading…
Cancel
Save