|
|
|
@@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) { |
|
|
|
FuncGraphManagerPtr manager = node->func_graph()->manager(); |
|
|
|
MS_EXCEPTION_IF_NULL(manager); |
|
|
|
AnfNodeIndexSet node_set = manager->node_users()[node]; |
|
|
|
for (auto &node_pair : node_set) { |
|
|
|
CNodePtr use_apply = node_pair.first->cast<CNodePtr>(); |
|
|
|
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(prim_anf_node); |
|
|
|
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(node_prim); |
|
|
|
if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) { |
|
|
|
auto layout = GetInputLayoutFromCNode(node_pair); |
|
|
|
return std::make_shared<TensorLayout>(layout); |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) { |
|
|
|
// Create DataParallel tensor layout for parameter(support WideDeep). |
|
|
|
auto next_layout = FindParameterNextLayout(node); |
|
|
|
if (next_layout != nullptr) { |
|
|
|
return next_layout; |
|
|
|
} |
|
|
|
CheckGlobalDeviceManager(); |
|
|
|
int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); |
|
|
|
TensorLayout input_tensor_layout; |
|
|
|
|