|
|
|
@@ -276,6 +276,31 @@ bool ExecuteAction(const ResourcePtr& res) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// The parallel primitive related valuenode might be partitioned so that its value changes by device, |
|
|
|
// that will result in a syncronization error due to different executing order. |
|
|
|
// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, |
|
|
|
// the final solution will be proposed later as a parallel feature. |
|
|
|
bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& res) { |
|
|
|
auto& node_users = res->manager()->node_users(); |
|
|
|
auto& users = node_users[value_node]; |
|
|
|
auto used_by_keep_value_prim = |
|
|
|
std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int>& user) -> bool { |
|
|
|
MS_EXCEPTION_IF_NULL(user.first); |
|
|
|
auto cnode = user.first->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto prim_node = cnode->input(0); |
|
|
|
if (IsValueNode<Primitive>(prim_node)) { |
|
|
|
auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value()); |
|
|
|
// value_node is referenced by some parallel primitive |
|
|
|
return prim->HasAttr("keep_value_node_input"); |
|
|
|
} |
|
|
|
return false; |
|
|
|
}); |
|
|
|
return used_by_keep_value_prim; |
|
|
|
} |
|
|
|
|
|
|
|
bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { |
|
|
|
if (res->func_graph() == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Remove value node duplications error."; |
|
|
|
@@ -287,6 +312,9 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { |
|
|
|
HashCache hash_cache; |
|
|
|
HashValue hashes; |
|
|
|
for (const auto& value_pair : value_nodes) { |
|
|
|
if (KeepValueNodeDuplication(value_pair.first, res)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes); |
|
|
|
} |
|
|
|
return true; |
|
|
|
|