Browse Source

Eliminate AllReduce when the input is a constant

tags/v0.7.0-beta
BowenK 5 years ago
parent
commit
f3a9fbdd78
4 changed files with 57 additions and 0 deletions
  1. +2
    -0
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +1
    -0
      mindspore/ccsrc/frontend/optimizer/irpass.h
  3. +53
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
  4. +1
    -0
      mindspore/ccsrc/pipeline/jit/pass.cc

+ 2
- 0
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -83,6 +83,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
reset_defer_inline_ = reset_defer_inline_ =
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>); MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend); depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
all_reduce_const_elim_ =
MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);


// Env Item Eliminate // Env Item Eliminate
env_get_item_eliminate_ = env_get_item_eliminate_ =


+ 1
- 0
mindspore/ccsrc/frontend/optimizer/irpass.h View File

@@ -50,6 +50,7 @@ class OptimizeIRPassLib {
SubstitutionPtr check_bprop_eliminate_; SubstitutionPtr check_bprop_eliminate_;
SubstitutionPtr reset_defer_inline_; SubstitutionPtr reset_defer_inline_;
SubstitutionPtr depend_value_elim_; SubstitutionPtr depend_value_elim_;
SubstitutionPtr all_reduce_const_elim_;


// Env Item Eliminate // Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_; SubstitutionPtr env_get_item_eliminate_;


+ 53
- 0
mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h View File

@@ -29,6 +29,8 @@
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h" #include "frontend/optimizer/irpass/prim_eliminate.h"
#include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/optimizer.h"
#include "utils/comm_manager.h"
#include "frontend/parallel/context.h"


namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@@ -203,6 +205,57 @@ class DependValueElim : public OptimizerCaller {
return nullptr; return nullptr;
} }
}; };

// Eliminates an AllReduce whose input is a compile-time constant.
// Only valid under AUTO_PARALLEL / SEMI_AUTO_PARALLEL, where the parallel mode
// guarantees every device holds the same constant, so the collective result can
// be computed locally:
//   - op == "sum": AllReduce(c) == c * rank_size  -> rewrite to tensor_mul(c, rank_size)
//   - any other reduce op over identical values  -> identity, return the constant
class AllReduceConstElim : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x;
    auto pattern = PPrimitive(prim::kPrimAllReduce, x);
    // Fire only if AllReduce takes a constant (value node) as input and values across
    // devices are all the same (ensured by the parallel mode).
    if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) &&
        (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) ||
         pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) {
      auto cur_func_graph = pattern.GetFuncGraph();
      // If the reduce operation is sum, multiply the constant by the number of devices,
      // otherwise just return the constant itself.
      auto prim_cnode = pattern.GetOriginalNode();
      MS_EXCEPTION_IF_NULL(prim_cnode);
      auto primitive = GetCNodePrimitive(prim_cnode);
      // fix: guard the primitive and its attributes before dereferencing —
      // GetCNodePrimitive/GetAttr may return nullptr.
      MS_EXCEPTION_IF_NULL(primitive);
      auto reduce_op = primitive->GetAttr("op");
      MS_EXCEPTION_IF_NULL(reduce_op);
      auto group_attr = primitive->GetAttr("group");
      MS_EXCEPTION_IF_NULL(group_attr);
      auto group = group_attr->ToString();
      // For a sum reduction, multiply the constant tensor by the number of devices.
      if (reduce_op->ToString() == "sum") {
        unsigned int num_of_devices;
        // Query the communication manager for the size of the reduce group.
        if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) {
          MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" + group + "]";
        }
        // Build tensor_mul(constant, num_of_devices) and return it.
        std::vector<AnfNodePtr> mul_inputs;
        auto constant_node = x.GetNode(node);
        MS_EXCEPTION_IF_NULL(constant_node);
        auto constant_value_node = constant_node->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(constant_value_node);
        if (!constant_value_node->value()->isa<tensor::Tensor>()) {
          MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. Got " +
                                 constant_value_node->value()->ToString();
        }
        auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>();
        // Keep the multiplier in the constant's dtype to avoid an implicit cast in the graph.
        auto tensor_dtype = constant_tensor->Dtype();
        auto num_of_device_node =
          NewValueNode(std::make_shared<tensor::Tensor>(static_cast<int64_t>(num_of_devices), tensor_dtype));
        auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional");
        MS_EXCEPTION_IF_NULL(mul_prim);
        mul_inputs.push_back(NewValueNode(mul_prim));
        mul_inputs.push_back(constant_node);
        mul_inputs.push_back(num_of_device_node);
        return cur_func_graph->NewCNode(mul_inputs);
      } else {
        // Non-sum reductions over identical constants are the identity.
        return x.GetNode(node);
      }
    }
    return nullptr;
  }
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore


+ 1
- 0
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -132,6 +132,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_env_getitem_switch_, irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_, irpass.new_env_get_item_,
irpass.depend_value_elim_, irpass.depend_value_elim_,
irpass.all_reduce_const_elim_,
}); });
opt::OptPassConfig a_3 = opt::OptPassConfig({ opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.arithmetic_simplify2_, irpass.arithmetic_simplify2_,


Loading…
Cancel
Save