|
|
|
@@ -139,76 +139,8 @@ class CheckTensorConstant { |
|
|
|
int check_value_; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} |
|
|
|
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} |
|
|
|
class TensorMultiplyByZeroOrOne : public AnfVisitor { |
|
|
|
public: |
|
|
|
TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {} |
|
|
|
~TensorMultiplyByZeroOrOne() override = default; |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
AnfVisitor::Match(prim::kPrimMul)(node); |
|
|
|
|
|
|
|
if (is_zero_) { |
|
|
|
if (x_->func_graph() != node->func_graph()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return NewTensorFilledWithData(node); |
|
|
|
} |
|
|
|
if (is_one_) { |
|
|
|
return NewTensorFilledWithData(node, x_); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (is_zero_ || is_one_) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsParam(node)) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsCNode(node)) { |
|
|
|
CNodePtr cnode = node->cast<CNodePtr>(); |
|
|
|
if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { |
|
|
|
is_zero_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto value = node->cast<ValueNodePtr>()->value(); |
|
|
|
if (CheckTensorConstant(0).IsTensorConstant(value)) { |
|
|
|
is_zero_ = true; |
|
|
|
return; |
|
|
|
} else if (CheckTensorConstant(1).IsTensorConstant(value)) { |
|
|
|
is_one_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = node; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
auto value = vnode->value(); |
|
|
|
if (CheckTensorConstant(0).IsTensorConstant(value)) { |
|
|
|
is_zero_ = true; |
|
|
|
return; |
|
|
|
} else if (CheckTensorConstant(1).IsTensorConstant(value)) { |
|
|
|
is_one_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = vnode; |
|
|
|
} |
|
|
|
void Reset() { |
|
|
|
x_ = nullptr; |
|
|
|
is_one_ = false; |
|
|
|
is_zero_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
class TensorMultiplyBase : public AnfVisitor { |
|
|
|
protected: |
|
|
|
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) { |
|
|
|
if (!node->isa<ValueNode>()) { |
|
|
|
return nullptr; |
|
|
|
@@ -287,10 +219,122 @@ class TensorMultiplyByZeroOrOne : public AnfVisitor { |
|
|
|
return new_vnode; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr x_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} |
|
|
|
class TensorMultiplyByZero : public TensorMultiplyBase { |
|
|
|
public: |
|
|
|
TensorMultiplyByZero() : zero_(MakeValue(0)) {} |
|
|
|
~TensorMultiplyByZero() override = default; |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
AnfVisitor::Match(prim::kPrimMul)(node); |
|
|
|
|
|
|
|
if (is_zero_) { |
|
|
|
if (x_->func_graph() != node->func_graph()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return NewTensorFilledWithData(node); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (is_zero_) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsParam(node)) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsCNode(node)) { |
|
|
|
CNodePtr cnode = node->cast<CNodePtr>(); |
|
|
|
if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { |
|
|
|
is_zero_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto value = node->cast<ValueNodePtr>()->value(); |
|
|
|
if (CheckTensorConstant(0).IsTensorConstant(value)) { |
|
|
|
is_zero_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = node; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
auto value = vnode->value(); |
|
|
|
if (CheckTensorConstant(0).IsTensorConstant(value)) { |
|
|
|
is_zero_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = vnode; |
|
|
|
} |
|
|
|
void Reset() { |
|
|
|
x_ = nullptr; |
|
|
|
is_zero_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
bool is_zero_{false}, is_one_{false}; |
|
|
|
bool is_zero_{false}; |
|
|
|
ValuePtr zero_; |
|
|
|
AnfNodePtr x_{nullptr}; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} |
|
|
|
class TensorMultiplyByOne : public TensorMultiplyBase { |
|
|
|
public: |
|
|
|
TensorMultiplyByOne() {} |
|
|
|
~TensorMultiplyByOne() override = default; |
|
|
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { |
|
|
|
Reset(); |
|
|
|
AnfVisitor::Match(prim::kPrimMul)(node); |
|
|
|
|
|
|
|
if (is_one_) { |
|
|
|
return NewTensorFilledWithData(node, x_); |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const AnfNodePtr &node) override { |
|
|
|
if (is_one_) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsParam(node) || IsCNode(node)) { |
|
|
|
x_ = node; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
auto value = node->cast<ValueNodePtr>()->value(); |
|
|
|
if (CheckTensorConstant(1).IsTensorConstant(value)) { |
|
|
|
is_one_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = node; |
|
|
|
} |
|
|
|
|
|
|
|
void Visit(const ValueNodePtr &vnode) override { |
|
|
|
auto value = vnode->value(); |
|
|
|
if (CheckTensorConstant(1).IsTensorConstant(value)) { |
|
|
|
is_one_ = true; |
|
|
|
return; |
|
|
|
} |
|
|
|
x_ = vnode; |
|
|
|
} |
|
|
|
void Reset() { |
|
|
|
x_ = nullptr; |
|
|
|
is_one_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
bool is_one_{false}; |
|
|
|
}; |
|
|
|
|
|
|
|
// {prim::kPrimScalarAdd, X, 0} |
|
|
|
@@ -699,7 +743,7 @@ class ArithmeticSimplify { |
|
|
|
public: |
|
|
|
ArithmeticSimplify() |
|
|
|
: multiply_by_zero_or_one_(), |
|
|
|
tensor_multiply_by_zero_or_one_(), |
|
|
|
tensor_multiply_by_one_(), |
|
|
|
add_by_zero_(), |
|
|
|
tensor_add_by_zero_(), |
|
|
|
identity_(prim::kPrimIdentity), |
|
|
|
@@ -707,7 +751,7 @@ class ArithmeticSimplify { |
|
|
|
constant_duplicate_mul_(), |
|
|
|
power_one_() { |
|
|
|
eliminaters_.emplace_back(multiply_by_zero_or_one_); |
|
|
|
eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_); |
|
|
|
eliminaters_.emplace_back(tensor_multiply_by_one_); |
|
|
|
eliminaters_.emplace_back(add_by_zero_); |
|
|
|
eliminaters_.emplace_back(tensor_add_by_zero_); |
|
|
|
eliminaters_.emplace_back(identity_); |
|
|
|
@@ -730,7 +774,7 @@ class ArithmeticSimplify { |
|
|
|
|
|
|
|
private: |
|
|
|
MultiplyByZeroOrOne multiply_by_zero_or_one_; |
|
|
|
TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_; |
|
|
|
TensorMultiplyByOne tensor_multiply_by_one_; |
|
|
|
AddByZero add_by_zero_; |
|
|
|
TensorAddByZero tensor_add_by_zero_; |
|
|
|
PrimEliminater identity_; |
|
|
|
@@ -739,6 +783,32 @@ class ArithmeticSimplify { |
|
|
|
PowerOneEliminate power_one_; |
|
|
|
std::vector<TransformFuncType> eliminaters_{}; |
|
|
|
}; |
|
|
|
|
|
|
|
// Arithmetic Simplifications should be done after step_parallel. |
|
|
|
// eg: Mul(0, weight) where weight is a parameter will be simplified to a constant tensor |
|
|
|
// with shape(weight), but after step_parallel, shape of weight may be changed, so the |
|
|
|
// shape of the constant tensor should also be changed. So this pass is seperated from |
|
|
|
// ArithmeticSimplify and deferred until step_parallel. |
|
|
|
class ArithmeticSimplify2 { |
|
|
|
public: |
|
|
|
ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } |
|
|
|
~ArithmeticSimplify2() = default; |
|
|
|
|
|
|
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { |
|
|
|
AnfNodePtr new_node; |
|
|
|
for (auto &eliminater : eliminaters_) { |
|
|
|
new_node = eliminater(optimizer, node); |
|
|
|
if (new_node != nullptr) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
TensorMultiplyByZero tensor_multiply_by_zero_; |
|
|
|
std::vector<TransformFuncType> eliminaters_{}; |
|
|
|
}; |
|
|
|
} // namespace irpass |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |
|
|
|
|