|
|
@@ -1287,7 +1287,8 @@ void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr |
|
|
auto allgather = cnode->input(index)->cast<CNodePtr>(); |
|
|
auto allgather = cnode->input(index)->cast<CNodePtr>(); |
|
|
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); |
|
|
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0)); |
|
|
auto attrs = prim->attrs(); |
|
|
auto attrs = prim->attrs(); |
|
|
attrs["fusion"] = MakeValue(1); |
|
|
|
|
|
|
|
|
// enable fusion flag later when it's supported in backend |
|
|
|
|
|
attrs["fusion"] = MakeValue(0); |
|
|
prim->SetAttrs(attrs); |
|
|
prim->SetAttrs(attrs); |
|
|
MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString(); |
|
|
MS_LOG(INFO) << "Parallel optimizer is applied on " << parameter->ToString(); |
|
|
} else { |
|
|
} else { |
|
|
|