| @@ -45,9 +45,9 @@ namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| arithmetic_simplify_ = MakeSubstitution( | |||
| ArithmeticSimplify(), "arithmetic_simplify", | |||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum}); | |||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | |||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | |||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||
| special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | |||
| {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | |||
| prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | |||
| @@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor { | |||
| } | |||
| }; | |||
| // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> | |||
| // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} | |||
| class ConstantDuplicateMul : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| // {prim::kPrimMul, Tensor1, {...}} | |||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); | |||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto tensor1 = vnode_; | |||
| auto mul = cnode_; | |||
| Reset(); | |||
| // {prim::kPrimMul, Tensor2, {...}} | |||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); | |||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto tensor2 = vnode_; | |||
| auto cnode = cnode_; | |||
| auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); | |||
| auto fg = node->func_graph(); | |||
| auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); | |||
| return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (IsValueNode<tensor::Tensor>(node)) { | |||
| vnode_ = node; | |||
| } | |||
| if (IsCNode(node)) { | |||
| cnode_ = node->cast<CNodePtr>(); | |||
| } | |||
| } | |||
| void Reset() { | |||
| vnode_ = nullptr; | |||
| cnode_ = nullptr; | |||
| } | |||
| private: | |||
| AnfNodePtr vnode_; | |||
| CNodePtr cnode_; | |||
| }; | |||
| class ArithmeticSimplify { | |||
| public: | |||
| ArithmeticSimplify() | |||
| @@ -186,12 +235,14 @@ class ArithmeticSimplify { | |||
| add_by_zero_(), | |||
| tensor_add_by_zero_(), | |||
| identity_(prim::kPrimIdentity), | |||
| opt_update_zero_tensor_() { | |||
| opt_update_zero_tensor_(), | |||
| constant_duplicate_mul_() { | |||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | |||
| eliminaters_.emplace_back(add_by_zero_); | |||
| eliminaters_.emplace_back(tensor_add_by_zero_); | |||
| eliminaters_.emplace_back(identity_); | |||
| eliminaters_.emplace_back(opt_update_zero_tensor_); | |||
| eliminaters_.emplace_back(constant_duplicate_mul_); | |||
| } | |||
| ~ArithmeticSimplify() = default; | |||
| @@ -212,6 +263,7 @@ class ArithmeticSimplify { | |||
| TensorAddByZero tensor_add_by_zero_; | |||
| PrimEliminater identity_; | |||
| OptUpdateZeroTensor opt_update_zero_tensor_; | |||
| ConstantDuplicateMul constant_duplicate_mul_; | |||
| std::vector<TransformFuncType> eliminaters_{}; | |||
| }; | |||
| } // namespace irpass | |||
| @@ -400,6 +400,8 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu | |||
| auto a2 = GetValueNode(node2); | |||
| if (a1->isa<Primitive>() && a2->isa<Primitive>()) { | |||
| return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name(); | |||
| } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) { | |||
| return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>())); | |||
| } else { | |||
| return *a1 == *a2; | |||
| } | |||
| @@ -771,6 +771,14 @@ class Mul(_MathBinaryOp): | |||
| >>> mul(input_x, input_y) | |||
| [4, 10, 18] | |||
| """ | |||
| def infer_value(self, x, y): | |||
| if x is not None and y is not None: | |||
| x = x.asnumpy() | |||
| y = y.asnumpy() | |||
| out = x * y | |||
| out = np.array(out, x.dtype) | |||
| return Tensor(out) | |||
| return None | |||
| class Square(PrimitiveWithInfer): | |||
| @@ -543,5 +543,18 @@ TEST_F(TestOptLib, test_print_tuple_wrapper) { | |||
| ASSERT_TRUE(CheckOpt(before2, after2, patterns)); | |||
| ASSERT_TRUE(CheckOpt(before3, before3, patterns)); | |||
| } | |||
| TEST_F(TestOptLib, test_constant_duplicate_mul) { | |||
| FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforell"); | |||
| FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforelr"); | |||
| FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerl"); | |||
| FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerr"); | |||
| FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "after"); | |||
| auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_}); | |||
| ASSERT_TRUE(CheckOpt(beforell, after, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforelr, after, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); | |||
| ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -16,6 +16,8 @@ | |||
| from mindspore.ops import Primitive, PrimitiveWithInfer | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore import Tensor | |||
| import numpy as np | |||
| # pylint: disable=unused-variable | |||
| @@ -903,3 +905,34 @@ def test_print_tuple_wrapper(tag): | |||
| return print_(make_tuple(x, y, z)) | |||
| return fns[tag] | |||
| def test_constant_duplicate_mul(tag): | |||
| fns = FnDict() | |||
| Mul = Primitive('Mul'); | |||
| Sqrt = Primitive('Sqrt'); | |||
| x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) | |||
| tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) | |||
| tensor2 = Tensor(np.array([[2.2, 3.1], [3.2, 4.2]]).astype('float32')) | |||
| @fns | |||
| def beforell(): | |||
| return Mul(tensor1, Mul(tensor2, Sqrt(x))) | |||
| @fns | |||
| def beforelr(): | |||
| return Mul(tensor1, Mul(Sqrt(x), tensor2)) | |||
| @fns | |||
| def beforerl(): | |||
| return Mul(Mul(Sqrt(x), tensor2), tensor1) | |||
| @fns | |||
| def beforerr(): | |||
| return Mul(Mul(Sqrt(x), tensor2), tensor1) | |||
| @fns | |||
| def after(): | |||
| return Mul(Sqrt(x), Mul(tensor1, tensor2)) | |||
| return fns[tag] | |||