Browse Source

!27690 Fix multi node issue.

Merge pull request !27690 from ZPaC/dataset-helper-cache-enable
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
23afd952c8
4 changed files with 22 additions and 9 deletions
  1. +1
    -6
      mindspore/ccsrc/ps/core/abstract_node.cc
  2. +2
    -2
      mindspore/ccsrc/ps/parameter_server.cc
  3. +18
    -1
      mindspore/ccsrc/ps/util.cc
  4. +1
    -0
      mindspore/ccsrc/ps/util.h

+ 1
- 6
mindspore/ccsrc/ps/core/abstract_node.cc View File

@@ -1261,12 +1261,7 @@ void AbstractNode::CreateTcpServer() {
MS_EXCEPTION_IF_NULL(config_);
std::string interface;
std::string server_ip;
if (ps::PSContext::instance()->server_mode().empty()) {
// If the server mode is not set, use 127.0.0.1 as server ip address for distributed learning.
server_ip = "127.0.0.1";
} else {
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
}
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
MS_EXCEPTION_IF_NULL(server_);
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,


+ 2
- 2
mindspore/ccsrc/ps/parameter_server.cc View File

@@ -649,7 +649,7 @@ void ParameterServer::GetEmbeddingTableParamPtr() {
Key count = 0;
for (auto cnode : cnodes) {
MS_EXCEPTION_IF_NULL(cnode);
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
std::string cnode_name = Util::GetPrimitiveName(cnode);
if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
@@ -675,7 +675,7 @@ void ParameterServer::CacheEmbeddingTableParamPtr() {
auto cnodes = func_graph_->GetOrderedCnodes();
for (auto cnode : cnodes) {
MS_EXCEPTION_IF_NULL(cnode);
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
std::string cnode_name = Util::GetPrimitiveName(cnode);
if (cnode_name != kGatherV2OpName && cnode_name != kSparseGatherV2OpName) {
continue;
}


+ 18
- 1
mindspore/ccsrc/ps/util.cc View File

@@ -148,6 +148,23 @@ WeightPtr Util::MakeWeightPtr(const std::shared_ptr<std::vector<float>> &data, b
return weight_ptr;
}

std::string Util::GetPrimitiveName(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of node " << cnode->fullname_with_scope() << " is empty.";
return "";
}
auto fn = inputs[0];
if (!IsValueNode<Primitive>(fn)) {
return "";
}

auto node_prim = GetValueNode<PrimitivePtr>(fn);
MS_EXCEPTION_IF_NULL(node_prim);
return node_prim->name();
}

void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
const std::string &fused_cnode_name) {
MS_EXCEPTION_IF_NULL(func_graph);
@@ -158,7 +175,7 @@ void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_nam
std::vector<int64_t> indices;
for (const AnfNodePtr &node : node_list) {
if (node != nullptr && node->isa<CNode>()) {
if (AnfAlgo::GetCNodeName(node) == cnode_name) {
if (GetPrimitiveName(node->cast<CNodePtr>()) == cnode_name) {
single_nodes.push_back(node);

auto weight_name_value_node =


+ 1
- 0
mindspore/ccsrc/ps/util.h View File

@@ -59,6 +59,7 @@ class Util {
static bool FuseServerCommOps(const pipeline::ResourcePtr &res);
static WeightPtr MakeWeightPtr(const std::shared_ptr<std::vector<float>> &data, bool enable_recovery,
const std::shared_ptr<std::vector<int>> &shape = nullptr);
static std::string GetPrimitiveName(const CNodePtr &cnode);

private:
static void DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,


Loading…
Cancel
Save