@@ -29,6 +29,8 @@
 #include "frontend/optimizer/irpass.h"
 #include "frontend/optimizer/irpass/prim_eliminate.h"
 #include "frontend/optimizer/optimizer.h"
+#include "utils/comm_manager.h"
+#include "frontend/parallel/context.h"
 
 namespace mindspore {
 namespace opt {
@@ -203,6 +205,57 @@ class DependValueElim : public OptimizerCaller {
     return nullptr;
   }
 };
+
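+// Folds AllReduce(constant) under auto/semi-auto parallel: for a "sum" reduce the constant
+// is multiplied by the group size; for other reduce ops the constant is returned unchanged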
+class AllReduceConstElim : public OptimizerCaller {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    PatternNode<AnfNodePtr> x;
+    auto pattern = PPrimitive(prim::kPrimAllReduce, x);
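+    // Match AllReduce(x) and capture its single input x for reuse below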
+    // If AllReduce takes a constant value as input and the values across devices are all
+    // the same (ensured by the parallel mode), the AllReduce can be folded at compile time
+    if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) &&
+        (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) ||
+         pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) {
+      auto cur_func_graph = pattern.GetFuncGraph();
+      // If the reduce operation is sum, multiply the constant by the number of devices;
+      // otherwise just return the constant
+      auto prim_cnode = pattern.GetOriginalNode();
+      MS_EXCEPTION_IF_NULL(prim_cnode);
+      auto primitive = GetCNodePrimitive(prim_cnode);
+      MS_EXCEPTION_IF_NULL(primitive);
+      auto reduce_op = primitive->GetAttr("op");
+      MS_EXCEPTION_IF_NULL(reduce_op);
+      auto group_attr = primitive->GetAttr("group");
+      MS_EXCEPTION_IF_NULL(group_attr);
+      auto group = group_attr->ToString();
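+      // "op" (reduce kind, e.g. "sum") and "group" (communication group name) are
+      // attributes attached to the AllReduce primitive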
+      // For the sum operation, multiply the constant tensor by the number of devices
+      if (reduce_op->ToString() == "sum") {
+        unsigned int num_of_devices;
+        // Get the number of devices in the communication group
+        if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) {
+          MS_LOG(EXCEPTION) << "Failed to get the number of devices for group [" + group + "]";
+        }
+        // Multiply the constant by the number of devices, then return the product
+        std::vector<AnfNodePtr> mul_inputs;
+        auto constant_node = x.GetNode(node);
+        MS_EXCEPTION_IF_NULL(constant_node);
+        auto constant_value_node = constant_node->cast<ValueNodePtr>();
+        MS_EXCEPTION_IF_NULL(constant_value_node);
+        if (!constant_value_node->value()->isa<tensor::Tensor>()) {
+          MS_LOG(EXCEPTION) << "Expect the constant input of AllReduce to be a Tensor, but got " +
+                                 constant_value_node->value()->ToString();
+        }
+        auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>();
+        auto tensor_dtype = constant_tensor->Dtype();
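+        // Build the device count as a scalar tensor with the constant's dtype, keeping
+        // the operand dtypes consistent for the multiply below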
+        auto num_of_device_node =
+          NewValueNode(std::make_shared<tensor::Tensor>(static_cast<int64_t>(num_of_devices), tensor_dtype));
+        // Build the multiply node: tensor_mul(constant, num_of_devices)
+        auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional");
+        MS_EXCEPTION_IF_NULL(mul_prim);
+        mul_inputs.push_back(NewValueNode(mul_prim));
+        mul_inputs.push_back(constant_node);
+        mul_inputs.push_back(num_of_device_node);
+        return cur_func_graph->NewCNode(mul_inputs);
+      } else {
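+        // For non-sum reduce ops (e.g. max or min) every device holds the same constant,
+        // so the reduction result is the constant itself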
+        return x.GetNode(node);
+      }
+    }
+    return nullptr;
+  }
+};
 } // namespace irpass
 } // namespace opt
 } // namespace mindspore