| @@ -45,9 +45,9 @@ namespace mindspore { | |||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| OptimizeIRPassLib::OptimizeIRPassLib() { | OptimizeIRPassLib::OptimizeIRPassLib() { | ||||
| arithmetic_simplify_ = MakeSubstitution( | |||||
| ArithmeticSimplify(), "arithmetic_simplify", | |||||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, prim::kPrimIdentity, prim::kPrimMomentum}); | |||||
| arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", | |||||
| {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, | |||||
| prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); | |||||
| special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", | ||||
| {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, | ||||
| prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); | ||||
| @@ -179,6 +179,55 @@ class OptUpdateZeroTensor : public AnfVisitor { | |||||
| } | } | ||||
| }; | }; | ||||
| // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> | |||||
| // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} | |||||
| class ConstantDuplicateMul : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||||
| Reset(); | |||||
| // {prim::kPrimMul, Tensor1, {...}} | |||||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); | |||||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor1 = vnode_; | |||||
| auto mul = cnode_; | |||||
| Reset(); | |||||
| // {prim::kPrimMul, Tensor2, {...}} | |||||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); | |||||
| if (vnode_ == nullptr || cnode_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor2 = vnode_; | |||||
| auto cnode = cnode_; | |||||
| auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); | |||||
| auto fg = node->func_graph(); | |||||
| auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); | |||||
| return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); | |||||
| } | |||||
| void Visit(const AnfNodePtr &node) override { | |||||
| if (IsValueNode<tensor::Tensor>(node)) { | |||||
| vnode_ = node; | |||||
| } | |||||
| if (IsCNode(node)) { | |||||
| cnode_ = node->cast<CNodePtr>(); | |||||
| } | |||||
| } | |||||
| void Reset() { | |||||
| vnode_ = nullptr; | |||||
| cnode_ = nullptr; | |||||
| } | |||||
| private: | |||||
| AnfNodePtr vnode_; | |||||
| CNodePtr cnode_; | |||||
| }; | |||||
| class ArithmeticSimplify { | class ArithmeticSimplify { | ||||
| public: | public: | ||||
| ArithmeticSimplify() | ArithmeticSimplify() | ||||
| @@ -186,12 +235,14 @@ class ArithmeticSimplify { | |||||
| add_by_zero_(), | add_by_zero_(), | ||||
| tensor_add_by_zero_(), | tensor_add_by_zero_(), | ||||
| identity_(prim::kPrimIdentity), | identity_(prim::kPrimIdentity), | ||||
| opt_update_zero_tensor_() { | |||||
| opt_update_zero_tensor_(), | |||||
| constant_duplicate_mul_() { | |||||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | eliminaters_.emplace_back(multiply_by_zero_or_one_); | ||||
| eliminaters_.emplace_back(add_by_zero_); | eliminaters_.emplace_back(add_by_zero_); | ||||
| eliminaters_.emplace_back(tensor_add_by_zero_); | eliminaters_.emplace_back(tensor_add_by_zero_); | ||||
| eliminaters_.emplace_back(identity_); | eliminaters_.emplace_back(identity_); | ||||
| eliminaters_.emplace_back(opt_update_zero_tensor_); | eliminaters_.emplace_back(opt_update_zero_tensor_); | ||||
| eliminaters_.emplace_back(constant_duplicate_mul_); | |||||
| } | } | ||||
| ~ArithmeticSimplify() = default; | ~ArithmeticSimplify() = default; | ||||
| @@ -212,6 +263,7 @@ class ArithmeticSimplify { | |||||
| TensorAddByZero tensor_add_by_zero_; | TensorAddByZero tensor_add_by_zero_; | ||||
| PrimEliminater identity_; | PrimEliminater identity_; | ||||
| OptUpdateZeroTensor opt_update_zero_tensor_; | OptUpdateZeroTensor opt_update_zero_tensor_; | ||||
| ConstantDuplicateMul constant_duplicate_mul_; | |||||
| std::vector<TransformFuncType> eliminaters_{}; | std::vector<TransformFuncType> eliminaters_{}; | ||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| @@ -400,6 +400,8 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu | |||||
| auto a2 = GetValueNode(node2); | auto a2 = GetValueNode(node2); | ||||
| if (a1->isa<Primitive>() && a2->isa<Primitive>()) { | if (a1->isa<Primitive>() && a2->isa<Primitive>()) { | ||||
| return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name(); | return a1->cast<PrimitivePtr>()->name() == a2->cast<PrimitivePtr>()->name(); | ||||
| } else if (a1->isa<tensor::Tensor>() && a2->isa<tensor::Tensor>()) { | |||||
| return a1->cast<tensor::TensorPtr>()->ValueEqual(*(a2->cast<tensor::TensorPtr>())); | |||||
| } else { | } else { | ||||
| return *a1 == *a2; | return *a1 == *a2; | ||||
| } | } | ||||
| @@ -771,6 +771,14 @@ class Mul(_MathBinaryOp): | |||||
| >>> mul(input_x, input_y) | >>> mul(input_x, input_y) | ||||
| [4, 10, 18] | [4, 10, 18] | ||||
| """ | """ | ||||
| def infer_value(self, x, y): | |||||
| if x is not None and y is not None: | |||||
| x = x.asnumpy() | |||||
| y = y.asnumpy() | |||||
| out = x * y | |||||
| out = np.array(out, x.dtype) | |||||
| return Tensor(out) | |||||
| return None | |||||
| class Square(PrimitiveWithInfer): | class Square(PrimitiveWithInfer): | ||||
| @@ -543,5 +543,18 @@ TEST_F(TestOptLib, test_print_tuple_wrapper) { | |||||
| ASSERT_TRUE(CheckOpt(before2, after2, patterns)); | ASSERT_TRUE(CheckOpt(before2, after2, patterns)); | ||||
| ASSERT_TRUE(CheckOpt(before3, before3, patterns)); | ASSERT_TRUE(CheckOpt(before3, before3, patterns)); | ||||
| } | } | ||||
| TEST_F(TestOptLib, test_constant_duplicate_mul) { | |||||
| FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforell"); | |||||
| FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforelr"); | |||||
| FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerl"); | |||||
| FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "beforerr"); | |||||
| FuncGraphPtr after = getPyFun.CallAndParseRet("test_constant_duplicate_mul", "after"); | |||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_}); | |||||
| ASSERT_TRUE(CheckOpt(beforell, after, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(beforelr, after, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(beforerl, after, patterns)); | |||||
| ASSERT_TRUE(CheckOpt(beforerr, after, patterns)); | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,6 +16,8 @@ | |||||
| from mindspore.ops import Primitive, PrimitiveWithInfer | from mindspore.ops import Primitive, PrimitiveWithInfer | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.operations import _grad_ops as G | from mindspore.ops.operations import _grad_ops as G | ||||
| from mindspore import Tensor | |||||
| import numpy as np | |||||
| # pylint: disable=unused-variable | # pylint: disable=unused-variable | ||||
| @@ -903,3 +905,34 @@ def test_print_tuple_wrapper(tag): | |||||
| return print_(make_tuple(x, y, z)) | return print_(make_tuple(x, y, z)) | ||||
| return fns[tag] | return fns[tag] | ||||
| def test_constant_duplicate_mul(tag): | |||||
| fns = FnDict() | |||||
| Mul = Primitive('Mul'); | |||||
| Sqrt = Primitive('Sqrt'); | |||||
| x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32')) | |||||
| tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')) | |||||
| tensor2 = Tensor(np.array([[2.2, 3.1], [3.2, 4.2]]).astype('float32')) | |||||
| @fns | |||||
| def beforell(): | |||||
| return Mul(tensor1, Mul(tensor2, Sqrt(x))) | |||||
| @fns | |||||
| def beforelr(): | |||||
| return Mul(tensor1, Mul(Sqrt(x), tensor2)) | |||||
| @fns | |||||
| def beforerl(): | |||||
| return Mul(Mul(Sqrt(x), tensor2), tensor1) | |||||
| @fns | |||||
| def beforerr(): | |||||
| return Mul(Mul(Sqrt(x), tensor2), tensor1) | |||||
| @fns | |||||
| def after(): | |||||
| return Mul(Sqrt(x), Mul(tensor1, tensor2)) | |||||
| return fns[tag] | |||||