|
|
|
@@ -13,14 +13,12 @@ |
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
#include "pre_activate/pass/allreduce_fusion.h"
|
|
|
|
#include "pre_activate/pass/communication_op_fusion.h"
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
#include <string>
|
|
|
|
#include <memory>
|
|
|
|
#include <unordered_map>
|
|
|
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
#include "utils/graph_utils.h"
|
|
|
|
#include "operator/ops.h"
|
|
|
|
#include "device/kernel_info.h"
|
|
|
|
@@ -31,9 +29,12 @@ |
|
|
|
namespace mindspore {
|
|
|
|
namespace opt {
|
|
|
|
namespace {
|
|
|
|
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allreduce_node_info, size_t start_index,
|
|
|
|
constexpr auto kAttrDefaultGroup = "default_group";
|
|
|
|
constexpr auto kAttrDefaultOp = "default_op";
|
|
|
|
|
|
|
|
kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
|
|
|
|
size_t end_index) {
|
|
|
|
if (end_index >= allreduce_node_info.allreduce_node.size()) {
|
|
|
|
if (end_index >= communication_op_info.communication_op_nodes.size()) {
|
|
|
|
MS_LOG(EXCEPTION) << "end index out of vector size";
|
|
|
|
}
|
|
|
|
std::vector<std::string> inputs_device_format;
|
|
|
|
@@ -43,7 +44,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred |
|
|
|
std::vector<std::vector<size_t>> outputs_shape;
|
|
|
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
|
|
|
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
|
|
|
auto cnode = allreduce_node_info.allreduce_node[idx];
|
|
|
|
auto cnode = communication_op_info.communication_op_nodes[idx];
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
|
|
|
|
inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
|
|
|
|
@@ -64,14 +65,38 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allred |
|
|
|
builder.SetOutputsDeviceType(outputs_device_type);
|
|
|
|
return builder.Build();
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
|
|
|
|
if (attr_fusion == nullptr) {
|
|
|
|
return "";
|
|
|
|
}
|
|
|
|
int fusion = GetValue<int>(attr_fusion);
|
|
|
|
if (fusion == 0) {
|
|
|
|
return "";
|
|
|
|
}
|
|
|
|
std::string group = kAttrDefaultGroup;
|
|
|
|
ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
|
|
|
|
if (attr_group != nullptr) {
|
|
|
|
group = GetValue<std::string>(attr_group);
|
|
|
|
}
|
|
|
|
std::string op = kAttrDefaultOp;
|
|
|
|
ValuePtr attr_op = primitive->GetAttr(kAttrOp);
|
|
|
|
if (attr_op != nullptr) {
|
|
|
|
op = GetValue<std::string>(attr_op);
|
|
|
|
}
|
|
|
|
return group + op + std::to_string(fusion);
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
|
|
|
|
std::vector<size_t> *segment_index) const {
|
|
|
|
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
|
|
|
|
std::vector<size_t> *segment_index) const {
|
|
|
|
MS_EXCEPTION_IF_NULL(segment_num);
|
|
|
|
MS_EXCEPTION_IF_NULL(segment_index);
|
|
|
|
size_t allreduce_node_size = allreduce_node_info.allreduce_node.size();
|
|
|
|
MS_LOG(INFO) << "graph all reduce node size " << allreduce_node_size;
|
|
|
|
size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
|
|
|
|
MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;
|
|
|
|
|
|
|
|
auto parallel_context = parallel::ParallelContext::GetInstance();
|
|
|
|
MS_EXCEPTION_IF_NULL(parallel_context);
|
|
|
|
@@ -82,30 +107,31 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf |
|
|
|
uint32_t last_index = 0;
|
|
|
|
for (size_t i = 0; i < split_indices.size(); ++i) {
|
|
|
|
uint32_t index = split_indices[i];
|
|
|
|
if (index <= last_index || index >= allreduce_node_size) {
|
|
|
|
MS_LOG(EXCEPTION) << "invalid allreduce split index " << i << " " << index;
|
|
|
|
if (index <= last_index || index >= communication_op_node_size) {
|
|
|
|
MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
|
|
|
|
}
|
|
|
|
segment_index->push_back(index);
|
|
|
|
last_index = index;
|
|
|
|
segments++;
|
|
|
|
}
|
|
|
|
if (last_index != allreduce_node_size - 1) {
|
|
|
|
segment_index->push_back(allreduce_node_size - 1);
|
|
|
|
if (last_index != communication_op_node_size - 1) {
|
|
|
|
segment_index->push_back(communication_op_node_size - 1);
|
|
|
|
segments++;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
segments = groups_;
|
|
|
|
for (size_t i = 0; i < segments - 1; ++i) {
|
|
|
|
segment_index->push_back((i + 1) * (allreduce_node_size / segments) - 1);
|
|
|
|
segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1);
|
|
|
|
}
|
|
|
|
segment_index->push_back(allreduce_node_size - 1);
|
|
|
|
segment_index->push_back(communication_op_node_size - 1);
|
|
|
|
}
|
|
|
|
|
|
|
|
if (segments >= allreduce_node_size) {
|
|
|
|
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments << ", allreduce_node_size=" << allreduce_node_size;
|
|
|
|
if (segments >= communication_op_node_size) {
|
|
|
|
MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
|
|
|
|
<< ", communication_op_node_size=" << communication_op_node_size;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
if (segment_index->at(segments - 1) != allreduce_node_size - 1) {
|
|
|
|
if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
|
|
|
|
MS_LOG(EXCEPTION) << "the last segment index is invalid.";
|
|
|
|
}
|
|
|
|
for (size_t i = 0; i < segments - 1; ++i) {
|
|
|
|
@@ -118,19 +144,19 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf |
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
|
|
|
|
const AllReduceInfo_t &allreduce_node_info, size_t start_index,
|
|
|
|
size_t end_index) const {
|
|
|
|
AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
|
|
|
|
const CommunicationOpInfo &communication_op_info,
|
|
|
|
size_t start_index, size_t end_index) const {
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
auto prim = std::make_shared<Primitive>(kAllReduceOpName);
|
|
|
|
auto prim = std::make_shared<Primitive>(op_name_);
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
|
|
|
|
// get all inputs of current segment
|
|
|
|
if (end_index >= allreduce_node_info.allreduce_node.size()) {
|
|
|
|
if (end_index >= communication_op_info.communication_op_nodes.size()) {
|
|
|
|
MS_LOG(EXCEPTION) << "end index out of vector size";
|
|
|
|
}
|
|
|
|
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
|
|
|
auto cnode = allreduce_node_info.allreduce_node[idx];
|
|
|
|
auto cnode = communication_op_info.communication_op_nodes[idx];
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
|
|
|
}
|
|
|
|
@@ -141,14 +167,14 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph, |
|
|
|
fused_node->set_kernel_info(kernel_info);
|
|
|
|
AbstractBasePtrList abstract_list;
|
|
|
|
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
|
|
|
auto cnode = allreduce_node_info.allreduce_node[idx];
|
|
|
|
auto cnode = communication_op_info.communication_op_nodes[idx];
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
|
|
|
|
AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
|
|
|
|
AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
|
|
|
|
abstract_list.push_back(cnode->abstract());
|
|
|
|
}
|
|
|
|
auto kernel_build_info = GenerateKernelBuildInfo(allreduce_node_info, start_index, end_index);
|
|
|
|
auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
|
|
|
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
|
|
|
@@ -156,8 +182,8 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph, |
|
|
|
return fused_node;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
|
|
|
|
size_t segment_num, const std::vector<size_t> &segment_index) const {
|
|
|
|
bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
|
|
|
|
size_t segment_num, const std::vector<size_t> &segment_index) const {
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
auto manager = func_graph->manager();
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
@@ -169,12 +195,13 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn |
|
|
|
start_index = end_index + 1;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
AnfNodePtr new_allreduce = CreateFusedAllReduce(func_graph, allreduce_node_info, start_index, end_index);
|
|
|
|
// replace old allreduce with new allreduce
|
|
|
|
AnfNodePtr new_communication_op =
|
|
|
|
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
|
|
|
|
// replace old communication op with new communication op
|
|
|
|
for (auto idx = start_index; idx <= end_index; ++idx) {
|
|
|
|
std::vector<AnfNodePtr> tuple_getitem_input;
|
|
|
|
tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
|
|
|
tuple_getitem_input.push_back(new_allreduce);
|
|
|
|
tuple_getitem_input.push_back(new_communication_op);
|
|
|
|
auto index = NewValueNode(SizeToInt(idx - start_index));
|
|
|
|
MS_EXCEPTION_IF_NULL(index);
|
|
|
|
auto imm = std::make_shared<Int32Imm>(idx - start_index);
|
|
|
|
@@ -185,10 +212,10 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn |
|
|
|
tuple_getitem_input.push_back(index);
|
|
|
|
AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
|
|
|
auto allreduce_node_item = allreduce_node_info.allreduce_node.at(idx);
|
|
|
|
MS_EXCEPTION_IF_NULL(allreduce_node_item);
|
|
|
|
tuple_getitem->set_abstract(allreduce_node_item->abstract());
|
|
|
|
if (!manager->Replace(allreduce_node_item, tuple_getitem)) {
|
|
|
|
auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
|
|
|
|
MS_EXCEPTION_IF_NULL(communication_op_node_item);
|
|
|
|
tuple_getitem->set_abstract(communication_op_node_item->abstract());
|
|
|
|
if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
|
|
|
|
MS_LOG(EXCEPTION) << "manager replace node failed";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@@ -198,29 +225,24 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn |
|
|
|
return changed;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
const float input_grad_size_num = 0.0;
|
|
|
|
const float input_grad_time_num = 0.0;
|
|
|
|
// divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
|
|
|
|
std::unordered_map<std::string, AllReduceInfo_t> candidate_groups;
|
|
|
|
std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
|
|
|
|
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
|
|
|
for (auto &node : node_list) {
|
|
|
|
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
int fusion = GetValue<int>(primitive->GetAttr("fusion"));
|
|
|
|
if (fusion == 0) {
|
|
|
|
if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
|
|
|
|
std::string key = GetFusionGroupKey(node);
|
|
|
|
if (key.empty()) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
std::string group = GetValue<std::string>(primitive->GetAttr("group"));
|
|
|
|
std::string op = GetValue<std::string>(primitive->GetAttr("op"));
|
|
|
|
std::string key = group + op + std::to_string(fusion);
|
|
|
|
if (candidate_groups.find(key) == candidate_groups.end()) {
|
|
|
|
AllReduceInfo_t allreduce_node_info;
|
|
|
|
candidate_groups[key] = allreduce_node_info;
|
|
|
|
CommunicationOpInfo communication_op_info;
|
|
|
|
candidate_groups[key] = communication_op_info;
|
|
|
|
}
|
|
|
|
candidate_groups[key].allreduce_node.push_back(node->cast<CNodePtr>());
|
|
|
|
candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
|
|
|
|
candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
|
|
|
|
candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
|
|
|
|
}
|
|
|
|
@@ -228,7 +250,7 @@ bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) { |
|
|
|
// split candidate group to segments according to _group class member
|
|
|
|
bool changed = false;
|
|
|
|
for (auto &it : candidate_groups) {
|
|
|
|
if (it.second.allreduce_node.size() <= 1) {
|
|
|
|
if (it.second.communication_op_nodes.size() <= 1) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
size_t segment_num = 0;
|