Merge pull request !3177 from Giancarlo/pm_arithmetic_simplify (tags/v0.6.0-beta)
| @@ -14,543 +14,74 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <functional> | |||||
| #include "frontend/optimizer/irpass/arithmetic_simplify.h" | #include "frontend/optimizer/irpass/arithmetic_simplify.h" | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "frontend/optimizer/irpass.h" | |||||
| #include "frontend/optimizer/irpass/prim_eliminate.h" | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| // {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} | |||||
| // {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} | |||||
| AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| Reset(); | |||||
| AnfVisitor::Match(prim::kPrimScalarMul)(node); | |||||
| if (is_zero_) { | |||||
| return NewValueNode(zero_); | |||||
| } | |||||
| if (is_one_) { | |||||
| return x_; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { | |||||
| if (is_one_ || node->isa<CNode>()) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| AnfVisitor::Visit(node); | |||||
| if (!is_one_) { | |||||
| x_ = node; | |||||
| } | |||||
| } | |||||
| void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { | |||||
| auto value = vnode->value(); | |||||
| if (*value == *zero_) { | |||||
| is_zero_ = true; | |||||
| } else if (*value == *one_) { | |||||
| is_one_ = true; | |||||
| } | |||||
| } | |||||
| void MultiplyByZeroOrOne::Reset() { | |||||
| x_ = nullptr; | |||||
| is_one_ = false; | |||||
| is_zero_ = false; | |||||
| } | |||||
| // Support class used for checking if all values of a Tensor are equal `check_value_` | |||||
| // Supported data types: double, float/float32, int/int32 | |||||
| bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { | |||||
| if (!value->isa<tensor::Tensor>()) { | |||||
| return false; | |||||
| } | |||||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | |||||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | |||||
| float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||||
| if (fabs(data2[i] - check_value_) > FLT_EPSILON) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } else if (tensor_type == TypeId::kNumberTypeFloat64) { | |||||
| double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c()); | |||||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||||
| if (fabs(data2[i] - check_value_) > DBL_EPSILON) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { | |||||
| int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c()); | |||||
| for (int i = 0; i < tensor_ptr->DataSize(); i++) { | |||||
| if (data2[i] != check_value_) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| // input Data Types is not supported | |||||
| return false; | |||||
| } | |||||
| bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { | |||||
| if (!value->isa<tensor::Tensor>()) { | |||||
| return false; | |||||
| } | |||||
| auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
| if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { | |||||
| return false; | |||||
| } | |||||
| return IsTensorConstant(value); | |||||
| } | |||||
| void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { | |||||
| if (!node->isa<ValueNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto value = node->cast<ValueNodePtr>()->value(); | |||||
| if (!value->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
| return tensor_ptr->data_c(); | |||||
| } | |||||
| // Make a new tensor (when possible) with the same shape as of `node` | |||||
| // If x is nullptr then fill new tensor will "0" | |||||
| // If x is a tensor with empty shape then fill new tensor with the single value of x | |||||
| // If x is a tensor with same shape as `node` then return x as result | |||||
| AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { | |||||
| if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); | |||||
| std::vector<int> tensor_shape = tensor_abstract->shape()->shape(); | |||||
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | |||||
| size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); | |||||
| if (x == nullptr) { | |||||
| memset_s(data, mem_size, 0, mem_size); | |||||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
| return new_vnode; | |||||
| } | |||||
| // x is not nullptr | |||||
| if (x->isa<CNode>()) { | |||||
| if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| std::vector<int> x_shape = x_abstract->shape()->shape(); | |||||
| if (x_shape != tensor_shape) { | |||||
| return nullptr; | |||||
| } | |||||
| return x; | |||||
| } | |||||
| if (!x->isa<ValueNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_value = x->cast<ValueNodePtr>()->value(); | |||||
| if (!x_value->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value); | |||||
| if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { | |||||
| return nullptr; | |||||
| } | |||||
| char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x)); | |||||
| if (x_tensor_ptr->DataSize() == 1) { | |||||
| for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { | |||||
| memcpy_s(data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr), source_data, | |||||
| GetTypeByte(tensor_type_ptr)); | |||||
| } | |||||
| } else { | |||||
| memcpy_s(data, mem_size, source_data, mem_size); | |||||
| } | |||||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
| return new_vnode; | |||||
| } | |||||
| // {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} | |||||
| AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| Reset(); | |||||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||||
| if (is_zero_) { | |||||
| if (x_->func_graph() != node->func_graph()) { | |||||
| return nullptr; | |||||
| } | |||||
| return NewTensorFilledWithData(node); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { | |||||
| if (is_zero_) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| if (IsParam(node)) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| if (IsCNode(node)) { | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| auto value = node->cast<ValueNodePtr>()->value(); | |||||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | |||||
| } | |||||
| void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { | |||||
| auto value = vnode->value(); | |||||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = vnode; | |||||
| } | |||||
| void TensorMultiplyByZero::Reset() { | |||||
| x_ = nullptr; | |||||
| is_zero_ = false; | |||||
| } | |||||
| // {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} | |||||
| AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| Reset(); | |||||
| AnfVisitor::Match(prim::kPrimMul)(node); | |||||
| if (is_one_) { | |||||
| return NewTensorFilledWithData(node, x_); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { | |||||
| if (is_one_) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| if (IsParam(node) || IsCNode(node)) { | |||||
| x_ = node; | |||||
| return; | |||||
| } | |||||
| auto value = node->cast<ValueNodePtr>()->value(); | |||||
| if (CheckTensorConstant(1).IsTensorConstant(value)) { | |||||
| is_one_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | |||||
| } | |||||
| void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { | |||||
| auto value = vnode->value(); | |||||
| if (CheckTensorConstant(1).IsTensorConstant(value)) { | |||||
| is_one_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = vnode; | |||||
| } | |||||
| void TensorMultiplyByOne::Reset() { | |||||
| x_ = nullptr; | |||||
| is_one_ = false; | |||||
| } | |||||
| // {prim::kPrimScalarAdd, X, 0} | |||||
| // {prim::kPrimScalarAdd, 0, X} | |||||
| AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| Reset(); | |||||
| AnfVisitor::Match(prim::kPrimScalarAdd)(node); | |||||
| if (is_zero_) { | |||||
| return x_; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void AddByZero::Visit(const AnfNodePtr &node) { | |||||
| if (node->isa<ValueNode>() && | |||||
| ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | |||||
| } | |||||
| void AddByZero::Reset() { | |||||
| x_ = nullptr; | |||||
| is_zero_ = false; | |||||
| } | |||||
// NOTE(review): this span looks like a rendered diff in which the removed body of
// TensorAddByZero::operator() and the added body of the pattern-matcher based
// ArithmeticSimplify::operator() are interleaved without +/- markers — confirm against the
// real file before editing any code here.
| // {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, | |||||
| // {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} | |||||
| AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| Reset(); | |||||
| AnfVisitor::Match(prim::kPrimTensorAdd)(node); | |||||
// ArithmeticSimplify::operator(): tries a fixed sequence of rewrite rules on `node` and falls
// through to `return nullptr` at the end. Each MATCH_REPLACE* line names a pattern and its
// replacement (presumably the macro returns from the function on a match — the macros are
// defined elsewhere, confirm in ir/pattern_matcher.h).
| AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
// Pattern placeholders and constant matchers. The third PConstant argument is the value to
// match (0 or 1); NOTE(review): the trailing `true` presumably restricts the match to scalars,
// going by the *_scalar_ names — confirm.
| PatternNode x, y, z, xs; | |||||
| PConstant one_(node, false, 1); | |||||
| PConstant one_scalar_(node, false, 1, true); | |||||
| PConstant zero_(node, false, 0); | |||||
| PConstant zero_scalar_(node, false, 0, true); | |||||
| PConstant const_(node); | |||||
| PConstant const_2(node); | |||||
| PConstant any_const(node); | |||||
| MATCH_REPLACE(node, x + zero_, x); // Add by zero | |||||
| MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, zero_scalar_, x), x); // Scalar Add by zero | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, x, zero_scalar_), x); // Scalar Add by zero | |||||
| MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node)); // Multiply by one | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, one_scalar_, x), x); // Scalar Mul by one | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, one_scalar_), x); // Scalar Mul by one | |||||
| // Scalar Mul by zero | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, zero_scalar_, x), zero_scalar_.NewValue()); | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, zero_scalar_), zero_scalar_.NewValue()); | |||||
| // Prim Eliminate (identity) | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); | |||||
| // ConstantDuplicateMul | |||||
// Builds the replacement for const_ * (const_2 * x): the two constants are folded with
// MulByPatternConst when that succeeds; otherwise an explicit node multiplying the two
// constants is created. Either way the result multiplies x by the combined constant, reusing
// the matched node's own first input as the operator.
| auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr { | |||||
| auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node)); | |||||
| auto mul_node = node->cast<CNodePtr>()->inputs()[0]; | |||||
| if (new_mul_tensor == nullptr) { | |||||
| auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph()); | |||||
| return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph()); | |||||
| } | |||||
| return NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph()); | |||||
| }; | |||||
| MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda); | |||||
// The remaining rules are only attempted for nodes that belong to a func graph.
| if (node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| // OptUpdateZeroTensor | |||||
| MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z, xs), | |||||
| PPrimitive(prim::kPrimMakeTuple, z, y)); | |||||
| // PowerOneEliminate | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x, | |||||
| one_scalar_.CheckFunc(IsValueNode<Scalar>, node)); | |||||
// NOTE(review): the is_zero_ / x_ lines below appear to be the removed tail of
// TensorAddByZero::operator(); the final `return nullptr; }` is shared by both versions.
| if (is_zero_) { | |||||
| return x_; | |||||
| } | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| void TensorAddByZero::Visit(const AnfNodePtr &node) { | |||||
| if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| x_ = node; | |||||
| } | |||||
| void TensorAddByZero::Visit(const ValueNodePtr &vnode) { | |||||
| auto value = vnode->value(); | |||||
| if (CheckTensorConstant(0).IsTensorConstant(value)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| } | |||||
| void TensorAddByZero::Reset() { | |||||
| x_ = nullptr; | |||||
| is_zero_ = false; | |||||
| } | |||||
| // {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} | |||||
| AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| // {PrimMomentum, {...}, Y, Z, Xs} | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto y = inputs[2]; | |||||
| auto z = inputs[3]; | |||||
| // {kPrimZerosLike, X} | |||||
| if (inputs[1]->cast<CNodePtr>()->size() != 2) { | |||||
| return nullptr; | |||||
| } | |||||
| // {prim::kPrimMakeTuple, Z, Y} | |||||
| return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); | |||||
| } | |||||
| // {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> | |||||
| // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} | |||||
| // Support function to multiply two constant tensors: partially support broadcasting shapes | |||||
| template <typename T> | |||||
| void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, | |||||
| void **out_data, int out_data_size) { | |||||
| T *data_1 = reinterpret_cast<T *>(in_data_1); | |||||
| T *data_2 = reinterpret_cast<T *>(in_data_2); | |||||
| T *data_out = new T[out_data_size]; | |||||
| if (in_data_1_size == 1) { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] = data_1[0]; | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] = data_1[i]; | |||||
| } | |||||
| } | |||||
| if (in_data_2_size == 1) { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] *= data_2[0]; | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < out_data_size; i++) { | |||||
| data_out[i] *= data_2[i]; | |||||
| } | |||||
| } | |||||
| *out_data = reinterpret_cast<void *>(data_out); | |||||
| return; | |||||
| } | |||||
// Multiplies the two constant tensors held by `vnode_1` and `vnode_2` into a new constant
// value node whose element type and shape come from `node_3`'s abstract. Returns nullptr when
// an operand is not a ValueNode holding a tensor, when an abstract is missing, when the element
// types differ from node_3's, when an operand's element count is neither 1 nor the output
// count, or when the element type is not float32/float, float64 or int32/int.
// NOTE(review): this span looks like a rendered diff — lines of the new
// ArithmeticSimplify2::operator() are interleaved with this removed function (flagged below).
| AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, | |||||
| const AnfNodePtr &node_3) { | |||||
| if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) || | |||||
| (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto value_1 = GetValueNode(vnode_1); | |||||
| auto value_2 = GetValueNode(vnode_2); | |||||
| if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1); | |||||
| auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2); | |||||
// NOTE(review): tensor_2_abstract is read from vnode_1, not vnode_2 — looks like a copy/paste
// slip, which would make the dtype check below test vnode_1 twice. Confirm.
// NOTE(review): the three cast results are dereferenced without a null check.
| auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_2_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
| TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); | |||||
| TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); | |||||
| TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); | |||||
// Both operands must share node_3's element type.
| if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || | |||||
| (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { | |||||
| return nullptr; | |||||
| } | |||||
// NOTE(review): the ArithmeticSimplify2::operator() header, its PatternNode / PConstant
// declarations and the "Multiply by zero" MATCH_REPLACE_IF rules below appear to be added-side
// diff lines; the tensor_out_shape and data_out_size lines belong to MulConstantTensors.
| AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| PatternNode x, y; | |||||
| PConstant zero_(node, false, 0); | |||||
| std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape(); | |||||
| // Multiply by zero | |||||
| MATCH_REPLACE_IF(node, x * zero_, zero_.WithShapeAs(node), | |||||
| !zero_.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph()); | |||||
| auto zero_prim = PPrimitive(prim::kPrimZerosLike, y); | |||||
| MATCH_REPLACE_IF(node, x * zero_prim, zero_.WithShapeAs(node), | |||||
| !zero_prim.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph()); | |||||
// Output element count = product of the output shape's dimensions.
| int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>()); | |||||
// Each operand must be a single element or match the output element count exactly.
| if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { | |||||
| return nullptr; | |||||
| } | |||||
| if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { | |||||
| return nullptr; | |||||
| } | |||||
// data_out receives the buffer that Multiply<T> allocates with new T[].
// NOTE(review): no delete[] of data_out is visible after the memcpy_s below, so the buffer
// appears to be leaked on every successful fold — confirm.
| void *data_out; | |||||
| if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || | |||||
| (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { | |||||
| Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), | |||||
| &data_out, data_out_size); | |||||
| } else { | |||||
| if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { | |||||
| Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
| } else { | |||||
| if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || | |||||
| (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { | |||||
| Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
| tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
| } else { | |||||
| // Un-support data types | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
// Copy the product into a freshly created tensor and wrap it in a value node whose abstract
// is derived from that tensor.
| auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape); | |||||
| size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||||
| char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); | |||||
| memcpy_s(data, mem_size, data_out, mem_size); | |||||
| auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
| new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
| return new_vnode; | |||||
| } | |||||
| AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| Reset(); | |||||
| // {prim::kPrimMul, Tensor1, {...}} | |||||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); | |||||
| if (vnode_ == nullptr || c_p_node_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| if (!IsCNode(c_p_node_)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor1 = vnode_; | |||||
| auto mul = c_p_node_->cast<CNodePtr>(); | |||||
| Reset(); | |||||
| // {prim::kPrimMul, Tensor2, {...}} | |||||
| AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); | |||||
| if (vnode_ == nullptr || c_p_node_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tensor2 = vnode_; | |||||
| auto c_p_node = c_p_node_; | |||||
| auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); | |||||
| auto fg = node->func_graph(); | |||||
| auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); | |||||
| if (new_mul_tensor == nullptr) { | |||||
| auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); | |||||
| return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); | |||||
| } | |||||
| return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); | |||||
| } | |||||
| void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { | |||||
| if (IsValueNode<tensor::Tensor>(node)) { | |||||
| vnode_ = node; | |||||
| } | |||||
| if (IsCNode(node) || IsParam(node)) { | |||||
| c_p_node_ = node; | |||||
| } | |||||
| } | |||||
| void ConstantDuplicateMul::Reset() { | |||||
| vnode_ = nullptr; | |||||
| c_p_node_ = nullptr; | |||||
| } | |||||
| AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| if (!IsValueNode<Scalar>(inputs[2])) { | |||||
| return nullptr; | |||||
| } | |||||
| auto scalar = GetValueNode<ScalarPtr>(inputs[2]); | |||||
| if (scalar->isa<FloatImm>() && GetValue<float>(scalar) == 1.0) { | |||||
| return inputs[1]; | |||||
| } else if (scalar->isa<IntergerImm>() && GetValue<int>(scalar) == 1) { | |||||
| return inputs[1]; | |||||
| } | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -655,27 +186,6 @@ void AdjustAllReduceMulAdd::Reset() { | |||||
| all_reduce_fg_ = nullptr; | all_reduce_fg_ = nullptr; | ||||
| } | } | ||||
| AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr new_node; | |||||
| for (auto &eliminater : eliminaters_) { | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | |||||
| return new_node; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { | |||||
| AnfNodePtr new_node; | |||||
| for (auto &eliminater : eliminaters_) { | |||||
| new_node = (*eliminater)(optimizer, node); | |||||
| if (new_node != nullptr) { | |||||
| return new_node; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,159 +21,15 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/visitor.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "frontend/optimizer/irpass.h" | #include "frontend/optimizer/irpass.h" | ||||
| #include "frontend/optimizer/irpass/prim_eliminate.h" | #include "frontend/optimizer/irpass/prim_eliminate.h" | ||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "ir/optimizer_caller.h" | |||||
| #include "ir/pattern_matcher.h" | |||||
| #include "ir/visitor.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace irpass { | namespace irpass { | ||||
| // {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} | |||||
| // {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} | |||||
| class MultiplyByZeroOrOne : public AnfVisitor { | |||||
| public: | |||||
| MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} | |||||
| ~MultiplyByZeroOrOne() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| void Visit(const AnfNodePtr &node) override; | |||||
| void Visit(const ValueNodePtr &vnode) override; | |||||
| void Reset(); | |||||
| private: | |||||
| bool is_zero_{false}, is_one_{false}; | |||||
| ValuePtr zero_, one_; | |||||
| AnfNodePtr x_{nullptr}; | |||||
| }; | |||||
| // Support class used for checking if all values of a Tensor are equal `check_value_` | |||||
| // Supported data types: double, float/float32, int/int32 | |||||
| class CheckTensorConstant { | |||||
| public: | |||||
| explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} | |||||
| ~CheckTensorConstant() = default; | |||||
| bool IsTensorConstant(const ValuePtr &value); | |||||
| bool IsTensorScalarConstant(const ValuePtr &value); | |||||
| private: | |||||
| int check_value_; | |||||
| }; | |||||
| class TensorMultiplyBase : public AnfVisitor { | |||||
| protected: | |||||
| void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); | |||||
| // Make a new tensor (when possible) with the same shape as of `node` | |||||
| // If x is nullptr then fill new tensor will "0" | |||||
| // If x is a tensor with empty shape then fill new tensor with the single value of x | |||||
| // If x is a tensor with same shape as `node` then return x as result | |||||
| AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); | |||||
| AnfNodePtr x_{nullptr}; | |||||
| }; | |||||
| // {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} | |||||
| class TensorMultiplyByZero : public TensorMultiplyBase { | |||||
| public: | |||||
| TensorMultiplyByZero() : zero_(MakeValue(0)) {} | |||||
| ~TensorMultiplyByZero() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| void Visit(const AnfNodePtr &node) override; | |||||
| void Visit(const ValueNodePtr &vnode) override; | |||||
| void Reset(); | |||||
| private: | |||||
| bool is_zero_{false}; | |||||
| ValuePtr zero_; | |||||
| }; | |||||
| // {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} | |||||
| class TensorMultiplyByOne : public TensorMultiplyBase { | |||||
| public: | |||||
| TensorMultiplyByOne() {} | |||||
| ~TensorMultiplyByOne() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| void Visit(const AnfNodePtr &node) override; | |||||
| void Visit(const ValueNodePtr &vnode) override; | |||||
| void Reset(); | |||||
| private: | |||||
| bool is_one_{false}; | |||||
| }; | |||||
| // {prim::kPrimScalarAdd, X, 0} | |||||
| // {prim::kPrimScalarAdd, 0, X} | |||||
| class AddByZero : public AnfVisitor { | |||||
| public: | |||||
| AddByZero() : zero_(MakeValue(0)) {} | |||||
| ~AddByZero() override = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| void Visit(const AnfNodePtr &node) override; | |||||
| void Reset(); | |||||
| private: | |||||
| bool is_zero_{false}; | |||||
| ValuePtr zero_; | |||||
| AnfNodePtr x_{nullptr}; | |||||
| }; | |||||
| // {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, | |||||
| // {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} | |||||
| class TensorAddByZero : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| void Visit(const AnfNodePtr &node) override; | |||||
| void Visit(const ValueNodePtr &vnode) override; | |||||
| void Reset(); | |||||
| private: | |||||
| bool is_zero_{false}; | |||||
| AnfNodePtr x_{nullptr}; | |||||
| }; | |||||
| // {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} | |||||
| class OptUpdateZeroTensor : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| }; | |||||
| // {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> | |||||
| // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} | |||||
| class ConstantDuplicateMul : public AnfVisitor { | |||||
| public: | |||||
| // Support function to multiply two constant tensors: partially support broadcasting shapes | |||||
| template <typename T> | |||||
| void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, | |||||
| int out_data_size); | |||||
| AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| void Visit(const AnfNodePtr &node) override; | |||||
| void Reset(); | |||||
| private: | |||||
| AnfNodePtr vnode_; | |||||
| AnfNodePtr c_p_node_; | |||||
| }; | |||||
| class PowerOneEliminate : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| }; | |||||
| // grad = AllReduce(grad) / worker_number | // grad = AllReduce(grad) / worker_number | ||||
| // grad = grad + weight * decay | // grad = grad + weight * decay | ||||
| // -> | // -> | ||||
| @@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { | |||||
// Optimizer pass that applies the arithmetic simplification rules to a node.
// NOTE(review): this span looks like a rendered diff — the removed constructor, destructor,
// eliminater members and old operator() declaration are interleaved with the single operator()
// declaration that remains in the new version. Confirm against the real header before editing.
| class ArithmeticSimplify : public OptimizerCaller { | class ArithmeticSimplify : public OptimizerCaller { | ||||
| public: | public: | ||||
// Old version: creates one instance of each elimination rule and registers them in
// eliminaters_ in the order listed in the constructor body.
| ArithmeticSimplify() | |||||
| : multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()), | |||||
| tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()), | |||||
| add_by_zero_(std::make_shared<AddByZero>()), | |||||
| tensor_add_by_zero_(std::make_shared<TensorAddByZero>()), | |||||
| identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)), | |||||
| opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()), | |||||
| constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()), | |||||
| power_one_(std::make_shared<PowerOneEliminate>()) { | |||||
| eliminaters_.emplace_back(multiply_by_zero_or_one_); | |||||
| eliminaters_.emplace_back(tensor_multiply_by_one_); | |||||
| eliminaters_.emplace_back(add_by_zero_); | |||||
| eliminaters_.emplace_back(tensor_add_by_zero_); | |||||
| eliminaters_.emplace_back(identity_); | |||||
| eliminaters_.emplace_back(opt_update_zero_tensor_); | |||||
| eliminaters_.emplace_back(constant_duplicate_mul_); | |||||
| eliminaters_.emplace_back(power_one_); | |||||
| } | |||||
| ~ArithmeticSimplify() = default; | |||||
// Old declaration (named `optimizer` parameter).
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; | |||||
| private: | |||||
// One handle per elimination rule, plus the ordered list of registered rules.
| OptimizerCallerPtr multiply_by_zero_or_one_; | |||||
| OptimizerCallerPtr tensor_multiply_by_one_; | |||||
| OptimizerCallerPtr add_by_zero_; | |||||
| OptimizerCallerPtr tensor_add_by_zero_; | |||||
| OptimizerCallerPtr identity_; | |||||
| OptimizerCallerPtr opt_update_zero_tensor_; | |||||
| OptimizerCallerPtr constant_duplicate_mul_; | |||||
| OptimizerCallerPtr power_one_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
// New declaration (unnamed optimizer parameter) — presumably the only member left in the new
// version of the class.
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| }; | }; | ||||
| // Arithmetic Simplifications should be done after step_parallel. | // Arithmetic Simplifications should be done after step_parallel. | ||||
| @@ -242,17 +66,9 @@ class ArithmeticSimplify : public OptimizerCaller { | |||||
| // ArithmeticSimplify and deferred until step_parallel. | // ArithmeticSimplify and deferred until step_parallel. | ||||
| class ArithmeticSimplify2 : public OptimizerCaller { | class ArithmeticSimplify2 : public OptimizerCaller { | ||||
| public: | public: | ||||
| ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) { | |||||
| eliminaters_.emplace_back(tensor_multiply_by_zero_); | |||||
| } | |||||
| ~ArithmeticSimplify2() = default; | |||||
| AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; | |||||
| private: | |||||
| OptimizerCallerPtr tensor_multiply_by_zero_; | |||||
| std::vector<OptimizerCallerPtr> eliminaters_{}; | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| }; | }; | ||||
| } // namespace irpass | } // namespace irpass | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,14 +17,16 @@ | |||||
| #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ | #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ | ||||
| #define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ | #define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ | ||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <tuple> | #include <tuple> | ||||
| #include <vector> | #include <vector> | ||||
| #include "ir/anf.h" | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "ir/anf.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| /// | /// | ||||
| /// Base class for all recognizable patterns. | /// Base class for all recognizable patterns. | ||||
| /// We implement an Expression Template approach using static polymorphism based on | /// We implement an Expression Template approach using static polymorphism based on | ||||
| @@ -60,7 +62,7 @@ class PIsEqual { | |||||
| bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } | bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } | ||||
| }; | }; | ||||
| template <typename T> | |||||
| template <typename T = AnfNodePtr> | |||||
| class PatternNode : public PBase<PatternNode<T> > { | class PatternNode : public PBase<PatternNode<T> > { | ||||
| public: | public: | ||||
| T GetNode(const AnfNodePtr &node) const { | T GetNode(const AnfNodePtr &node) const { | ||||
| @@ -90,12 +92,13 @@ class PatternNode : public PBase<PatternNode<T> > { | |||||
| template <typename T, typename T2> | template <typename T, typename T2> | ||||
| class PBinOperation : public PBase<PBinOperation<T, T2> > { | class PBinOperation : public PBase<PBinOperation<T, T2> > { | ||||
| public: | public: | ||||
| PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} | |||||
| PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false) | |||||
| : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {} | |||||
| AnfNodePtr GetNode(const AnfNodePtr &node) const { | AnfNodePtr GetNode(const AnfNodePtr &node) const { | ||||
| AnfNodePtr lhs = x_.GetNode(node->func_graph()); | AnfNodePtr lhs = x_.GetNode(node->func_graph()); | ||||
| AnfNodePtr rhs = y_.GetNode(node->func_graph()); | AnfNodePtr rhs = y_.GetNode(node->func_graph()); | ||||
| AnfNodePtrList list = {prim_->cast<AnfNodePtr>(), lhs, rhs}; | |||||
| AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; | |||||
| return NewCNode(list, node->func_graph()); | return NewCNode(list, node->func_graph()); | ||||
| } | } | ||||
| @@ -106,6 +109,14 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { | |||||
| if (inputs.size() == 3) { | if (inputs.size() == 3) { | ||||
| // Binary Prim assumes only two inputs | // Binary Prim assumes only two inputs | ||||
| if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { | if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { | ||||
| // If the operation is commutative, then check with inversed operands | |||||
| if (is_commutative_) { | |||||
| Reset(); | |||||
| if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -113,7 +124,6 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| void Reset() const { | void Reset() const { | ||||
| x_.Reset(); | x_.Reset(); | ||||
| y_.Reset(); | y_.Reset(); | ||||
| @@ -123,6 +133,7 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > { | |||||
| const PrimitivePtr prim_; | const PrimitivePtr prim_; | ||||
| typename T::Internal x_; | typename T::Internal x_; | ||||
| typename T2::Internal y_; | typename T2::Internal y_; | ||||
| bool is_commutative_{false}; | |||||
| }; | }; | ||||
| /// | /// | ||||
| @@ -214,7 +225,6 @@ class PCNode : public PBase<PCNode<TArgs...> > { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void Reset() const { | void Reset() const { | ||||
| tuple_utils::PTupleResetCapture reset; | tuple_utils::PTupleResetCapture reset; | ||||
| tuple_utils::apply_func_tuple(&reset, args_); | tuple_utils::apply_func_tuple(&reset, args_); | ||||
| @@ -255,6 +265,12 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||||
| return false; | return false; | ||||
| } | } | ||||
| // If set to true, TryCapture will also try to capture the inputs in inverted order (only for the two-input case) | |||||
| const PPrimitive<TArgs...> &Commutative(const bool &is_commutative = true) const { | |||||
| is_commutative_ = is_commutative; | |||||
| return *this; | |||||
| } | |||||
| void Reset() const { | void Reset() const { | ||||
| tuple_utils::PTupleResetCapture reset; | tuple_utils::PTupleResetCapture reset; | ||||
| tuple_utils::apply_func_tuple(&reset, args_); | tuple_utils::apply_func_tuple(&reset, args_); | ||||
| @@ -263,46 +279,457 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > { | |||||
| private: | private: | ||||
| const PrimitivePtr prim_; | const PrimitivePtr prim_; | ||||
| std::tuple<typename TArgs::Internal...> args_; | std::tuple<typename TArgs::Internal...> args_; | ||||
| mutable bool is_commutative_{false}; | |||||
| }; | |||||
| /// | |||||
| /// PConstant class can capture a value node of a specified value (check_value_) | |||||
| /// or a non-specified one (any_value = true). | |||||
| /// It can be configured to capture a scalar constant as well (is_scalar_ = true) | |||||
| /// | |||||
| template <typename T = AnfNodePtr> | |||||
| class PConstant : public PBase<PConstant<T> > { | |||||
| public: | |||||
| /// Builds a constant pattern token. | |||||
| /// \param as_node stored both as the shape reference (as_node_) and as the initial captured_node_ | |||||
| /// \param any_value when true, TryCapture_ accepts any ValueNode without looking at its contents | |||||
| /// \param check_value the value compared against when any_value is false | |||||
| /// \param is_scalar makes TryCapture_ use IsTensorScalarConstant instead of IsTensorConstant; required by NewValue | |||||
| explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int check_value = 0, | |||||
|                    const bool is_scalar = false) | |||||
|     : as_node_(as_node), | |||||
|       captured_node_(as_node), | |||||
|       any_value_(any_value), | |||||
|       check_value_(check_value), | |||||
|       is_scalar_(is_scalar) {} | |||||
| // Stores `node` as the shape reference (as_node_) and marks the shape as changed, | |||||
| // so that GetNode builds its result through NewTensorFilledWithData(as_node_, ...). | |||||
| // Raises ValueError (via MS_EXCEPTION) when `node` is nullptr. | |||||
| const PConstant<T> &WithShapeAs(const AnfNodePtr &node) const { | |||||
|   const bool missing_node = (node == nullptr); | |||||
|   if (missing_node) { | |||||
|     MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a nullptr node."; | |||||
|   } | |||||
|   as_node_ = node; | |||||
|   changed_shape_ = true; | |||||
|   return *this; | |||||
| } | |||||
| // Stores the node captured by the pattern token `pnode` as the shape reference (as_node_) | |||||
| // and marks the shape as changed for the next GetNode call. | |||||
| // Raises ValueError (via MS_EXCEPTION) when this PConstant's captured_node_ is nullptr. | |||||
| const PConstant<T> &WithShapeAs(const PatternNode<T> &pnode) const { | |||||
|   const bool nothing_captured = (captured_node_ == nullptr); | |||||
|   if (nothing_captured) { | |||||
|     MS_EXCEPTION(ValueError) << "WithShapeAs is trying to use a Pattern token without previously capturing a node."; | |||||
|   } | |||||
|   auto shape_source = pnode.GetNode(captured_node_); | |||||
|   as_node_ = shape_source; | |||||
|   changed_shape_ = true; | |||||
|   return *this; | |||||
| } | |||||
| /// Replaces captured_node_ with the node captured by `pnode`, so that GetNode produces | |||||
| /// its result from that node's contents. The shape is marked as changed as well. | |||||
| /// Raises ValueError (via MS_EXCEPTION) when this PConstant was not built with `any_value = true`, | |||||
| /// or when captured_node_ is nullptr. | |||||
| const PConstant<T> &WithValueOf(const PatternNode<T> &pnode) const { | |||||
|   if (any_value_ == false) { | |||||
|     MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node."; | |||||
|   } | |||||
|   const bool nothing_captured = (captured_node_ == nullptr); | |||||
|   if (nothing_captured) { | |||||
|     MS_EXCEPTION(ValueError) << "WithValueOf is trying to use a Pattern token without previously capturing a node."; | |||||
|   } | |||||
|   auto value_source = pnode.GetNode(captured_node_); | |||||
|   captured_node_ = value_source; | |||||
|   changed_shape_ = true; | |||||
|   return *this; | |||||
| } | |||||
| /// Creates a fresh value node holding check_value_ and makes it the node returned by GetNode. | |||||
| /// Call it immediately before GetNode, otherwise the expected result may be replaced. | |||||
| /// Valid only for scalar PConstants (raises ValueError via MS_EXCEPTION otherwise); | |||||
| /// for tensors use WithShapeAs or WithValueOf. | |||||
| const PConstant<T> &NewValue() const { | |||||
|   if (is_scalar_ == false) { | |||||
|     MS_EXCEPTION(ValueError) << "NewValue is valid only for scalar PConstants."; | |||||
|   } | |||||
|   auto scalar_value = MakeValue(check_value_); | |||||
|   is_new_value_node_ = true; | |||||
|   captured_node_ = NewValueNode(scalar_value); | |||||
|   return *this; | |||||
| } | |||||
| /// Produces the node this constant stands for (the `node` argument itself is not used). | |||||
| /// - After NewValue(): the value node created there. | |||||
| /// - Captured and shape untouched: the captured node as-is. | |||||
| /// - Otherwise (never captured, or shape changed): a new tensor shaped like as_node_, filled with | |||||
| ///   check_value_ when the constant has a specific value, or built from captured_node_ when any_value_ is set. | |||||
| ///   This covers results such as zero constants (x - x => zero). | |||||
| AnfNodePtr GetNode(const AnfNodePtr &node) const { | |||||
|   const bool reuse_captured = is_new_value_node_ || (captured_ && !changed_shape_); | |||||
|   if (reuse_captured) { | |||||
|     return captured_node_; | |||||
|   } | |||||
|   if (any_value_) { | |||||
|     return NewTensorFilledWithData(as_node_, captured_node_); | |||||
|   } | |||||
|   return NewTensorFilledWithData(as_node_, check_value_); | |||||
| } | |||||
| /// Tries to capture `node`; only ValueNodes can be captured. | |||||
| /// With any_value_ set the node is captured without inspecting its value. Otherwise it is captured when | |||||
| /// it is a (scalar, if is_scalar_) tensor constant equal to check_value_, or when its value compares | |||||
| /// equal to MakeValue(check_value_). On success captured_node_ and captured_ are updated. | |||||
| bool TryCapture_(const AnfNodePtr &node) const { | |||||
|   if (!node->isa<ValueNode>()) { | |||||
|     return false; | |||||
|   } | |||||
|   bool matched = any_value_; | |||||
|   if (!matched) { | |||||
|     auto value = node->cast<ValueNodePtr>()->value(); | |||||
|     matched = is_scalar_ ? IsTensorScalarConstant(value) : IsTensorConstant(value); | |||||
|   } | |||||
|   if (!matched) { | |||||
|     // Last resort: direct comparison against a freshly made value of check_value_. | |||||
|     auto expected = MakeValue(check_value_); | |||||
|     matched = (*GetValueNode(node) == *expected); | |||||
|   } | |||||
|   if (matched) { | |||||
|     captured_node_ = node; | |||||
|     captured_ = true; | |||||
|   } | |||||
|   return matched; | |||||
| } | |||||
| /// Clears the capture state flags; as_node_ and captured_node_ themselves are left untouched. | |||||
| void Reset() const { | |||||
|   is_new_value_node_ = false; | |||||
|   changed_shape_ = false; | |||||
|   captured_ = false; | |||||
| } | |||||
| // Tells whether `value` is a tensor whose every element equals `check_value_`. | |||||
| // Handled element types: float/float32 (tolerance FLT_EPSILON), double (tolerance DBL_EPSILON), | |||||
| // int/int32 (exact comparison). Non-tensor values and any other element type yield false. | |||||
| bool IsTensorConstant(const ValuePtr &value) const { | |||||
|   if (!value->isa<tensor::Tensor>()) { | |||||
|     return false; | |||||
|   } | |||||
|   auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
|   const TypeId elem_type = tensor_ptr->Dtype()->type_id(); | |||||
|   const bool is_float32 = (elem_type == TypeId::kNumberTypeFloat32) || (elem_type == TypeId::kNumberTypeFloat); | |||||
|   if (is_float32) { | |||||
|     float *elems = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||||
|     int idx = 0; | |||||
|     while (idx < tensor_ptr->DataSize()) { | |||||
|       if (fabs(elems[idx] - check_value_) > FLT_EPSILON) { | |||||
|         return false; | |||||
|       } | |||||
|       idx++; | |||||
|     } | |||||
|     return true; | |||||
|   } | |||||
|   if (elem_type == TypeId::kNumberTypeFloat64) { | |||||
|     double *elems = reinterpret_cast<double *>(tensor_ptr->data_c()); | |||||
|     int idx = 0; | |||||
|     while (idx < tensor_ptr->DataSize()) { | |||||
|       if (fabs(elems[idx] - check_value_) > DBL_EPSILON) { | |||||
|         return false; | |||||
|       } | |||||
|       idx++; | |||||
|     } | |||||
|     return true; | |||||
|   } | |||||
|   const bool is_int32 = (elem_type == TypeId::kNumberTypeInt32) || (elem_type == TypeId::kNumberTypeInt); | |||||
|   if (is_int32) { | |||||
|     int *elems = reinterpret_cast<int *>(tensor_ptr->data_c()); | |||||
|     int idx = 0; | |||||
|     while (idx < tensor_ptr->DataSize()) { | |||||
|       if (elems[idx] != check_value_) { | |||||
|         return false; | |||||
|       } | |||||
|       idx++; | |||||
|     } | |||||
|     return true; | |||||
|   } | |||||
|   // Element type is not supported | |||||
|   return false; | |||||
| } | |||||
| // Like IsTensorConstant, but additionally requires the tensor to be a scalar: | |||||
| // at most one element (DataSize) and no dimensions (DataDim). | |||||
| bool IsTensorScalarConstant(const ValuePtr &value) const { | |||||
|   if (!value->isa<tensor::Tensor>()) { | |||||
|     return false; | |||||
|   } | |||||
|   auto tensor_ptr = dyn_cast<tensor::Tensor>(value); | |||||
|   const bool is_scalar_tensor = (tensor_ptr->DataSize() <= 1) && (tensor_ptr->DataDim() <= 0); | |||||
|   return is_scalar_tensor && IsTensorConstant(value); | |||||
| } | |||||
| // Returns the raw data pointer (data_c) of the tensor held by `node`, | |||||
| // or nullptr when `node` is not a ValueNode or its value is not a tensor. | |||||
| // The `writable` flag is accepted but not used by this implementation. | |||||
| void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) const { | |||||
|   if (node->isa<ValueNode>()) { | |||||
|     auto held = node->cast<ValueNodePtr>()->value(); | |||||
|     if (held->isa<tensor::Tensor>()) { | |||||
|       tensor::TensorPtr held_tensor = dyn_cast<tensor::Tensor>(held); | |||||
|       return held_tensor->data_c(); | |||||
|     } | |||||
|   } | |||||
|   return nullptr; | |||||
| } | |||||
| // Makes a new tensor value node (when possible) with the dtype and shape of `node`, taken from its abstract. | |||||
| // - x == nullptr: the new tensor is filled with zeros. | |||||
| // - x is a CNode: x itself is returned when its abstract shape equals the shape of `node`, nullptr otherwise. | |||||
| // - x is a tensor ValueNode with a single element: that element is replicated into every position. | |||||
| // - x is a tensor ValueNode with as many elements as the new tensor: its data is copied over. | |||||
| // Returns nullptr in every other case, including when the element type of x differs from that of `node` | |||||
| // (the copy is done with the target element size, so a mismatching source would be misread). | |||||
| AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const { | |||||
|   if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
|   TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); | |||||
|   std::vector<int> tensor_shape = tensor_abstract->shape()->shape(); | |||||
|   auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | |||||
|   const auto elem_bytes = GetTypeByte(tensor_type_ptr); | |||||
|   size_t mem_size = elem_bytes * IntToSize(new_tensor_ptr->ElementsNum()); | |||||
|   char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); | |||||
|   if (x == nullptr) { | |||||
|     memset_s(data, mem_size, 0, mem_size); | |||||
|     auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
|     new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
|     return new_vnode; | |||||
|   } | |||||
|   // x is not nullptr | |||||
|   if (x->isa<CNode>()) { | |||||
|     if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) { | |||||
|       return nullptr; | |||||
|     } | |||||
|     auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
|     std::vector<int> x_shape = x_abstract->shape()->shape(); | |||||
|     if (x_shape != tensor_shape) { | |||||
|       return nullptr; | |||||
|     } | |||||
|     return x; | |||||
|   } | |||||
|   if (!x->isa<ValueNode>()) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto x_value = x->cast<ValueNodePtr>()->value(); | |||||
|   if (!x_value->isa<tensor::Tensor>()) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value); | |||||
|   if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   // The bytes of x are copied verbatim using the target element size, so the element types must agree. | |||||
|   if (x_tensor_ptr->Dtype()->type_id() != tensor_type_ptr->type_id()) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x)); | |||||
|   if (x_tensor_ptr->DataSize() == 1) { | |||||
|     // Broadcast the single source element over the whole output. | |||||
|     for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { | |||||
|       memcpy_s(data + i * elem_bytes, elem_bytes, source_data, elem_bytes); | |||||
|     } | |||||
|   } else { | |||||
|     memcpy_s(data, mem_size, source_data, mem_size); | |||||
|   } | |||||
|   auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
|   new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
|   return new_vnode; | |||||
| } | |||||
| // Makes a new tensor value node with the dtype and shape of `node` (taken from its abstract), | |||||
| // with every element set to `value`. | |||||
| // Returns nullptr when `node` has no AbstractTensor abstract, or when `value` is non-zero and the | |||||
| // element type is not float/float32, float64 or int/int32 (the types handled elsewhere in this class). | |||||
| AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const { | |||||
|   if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
|   TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); | |||||
|   std::vector<int> tensor_shape = tensor_abstract->shape()->shape(); | |||||
|   auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape); | |||||
|   const size_t elem_num = IntToSize(new_tensor_ptr->ElementsNum()); | |||||
|   const size_t mem_size = GetTypeByte(tensor_type_ptr) * elem_num; | |||||
|   void *data = new_tensor_ptr->data_c(); | |||||
|   const TypeId type_id = tensor_type_ptr->type_id(); | |||||
|   // Writes `value`, converted to the element type, into each of the elem_num elements. | |||||
|   auto fill_elements = [elem_num, value](auto *elems) { | |||||
|     for (size_t i = 0; i < elem_num; i++) { | |||||
|       elems[i] = value; | |||||
|     } | |||||
|   }; | |||||
|   if (value == 0) { | |||||
|     // An all-zero byte pattern is the number zero for every supported element type, so a byte fill is enough. | |||||
|     memset_s(data, mem_size, 0, mem_size); | |||||
|   } else if ((type_id == TypeId::kNumberTypeFloat32) || (type_id == TypeId::kNumberTypeFloat)) { | |||||
|     // A byte fill would be wrong here: memset repeats one byte, it does not store the number `value`. | |||||
|     fill_elements(reinterpret_cast<float *>(data)); | |||||
|   } else if (type_id == TypeId::kNumberTypeFloat64) { | |||||
|     fill_elements(reinterpret_cast<double *>(data)); | |||||
|   } else if ((type_id == TypeId::kNumberTypeInt32) || (type_id == TypeId::kNumberTypeInt)) { | |||||
|     fill_elements(reinterpret_cast<int *>(data)); | |||||
|   } else { | |||||
|     // Unsupported element type for a non-zero fill: report "no node" rather than return garbage data. | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
|   new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
|   return new_vnode; | |||||
| } | |||||
| // Support function to multiply two constant tensors element-wise: partially supports broadcasting shapes. | |||||
| // An operand with exactly one element is broadcast over the whole output; any other operand must | |||||
| // provide at least out_data_size elements, otherwise ValueError is raised (via MS_EXCEPTION). | |||||
| // On return *out_data points to a buffer allocated with new TM[out_data_size]; | |||||
| // the caller owns it and must release it with delete[] on a TM pointer. | |||||
| template <typename TM> | |||||
| void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, | |||||
|               int out_data_size) const { | |||||
|   // Validate both operands before allocating, so that raising cannot leak the output buffer | |||||
|   // and neither input is read past its end. | |||||
|   if ((in_data_1_size != 1) && (in_data_1_size < out_data_size)) { | |||||
|     MS_EXCEPTION(ValueError) << "in_data_1_size is smaller than out_data_size."; | |||||
|   } | |||||
|   if ((in_data_2_size != 1) && (in_data_2_size < out_data_size)) { | |||||
|     MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size."; | |||||
|   } | |||||
|   const TM *data_1 = reinterpret_cast<const TM *>(in_data_1); | |||||
|   const TM *data_2 = reinterpret_cast<const TM *>(in_data_2); | |||||
|   const bool broadcast_1 = (in_data_1_size == 1); | |||||
|   const bool broadcast_2 = (in_data_2_size == 1); | |||||
|   TM *data_out = new TM[out_data_size]; | |||||
|   for (int i = 0; i < out_data_size; i++) { | |||||
|     const TM lhs = broadcast_1 ? data_1[0] : data_1[i]; | |||||
|     const TM rhs = broadcast_2 ? data_2[0] : data_2[i]; | |||||
|     data_out[i] = lhs * rhs; | |||||
|   } | |||||
|   *out_data = reinterpret_cast<void *>(data_out); | |||||
| } | |||||
| // Multiplies the node of this constant by the node of `vpnode_2` through MulConstantTensors; | |||||
| // `node_3` supplies the dtype and shape of the result. Both operand nodes are obtained with GetNode. | |||||
| AnfNodePtr MulByPatternConst(const PConstant<T> &vpnode_2, const AnfNodePtr &node_3) const { | |||||
|   AnfNodePtr lhs_const = this->GetNode(captured_node_); | |||||
|   AnfNodePtr rhs_const = vpnode_2.GetNode(captured_node_); | |||||
|   AnfNodePtr product = MulConstantTensors(lhs_const, rhs_const, node_3); | |||||
|   return product; | |||||
| } | |||||
| // Multiplies two constant tensor value nodes element-wise and returns the product as a new value node | |||||
| // whose dtype and shape are taken from the abstract of `node_3`. | |||||
| // Each operand must hold either one element (it is broadcast) or exactly as many elements as the output. | |||||
| // Returns nullptr when an operand is not a tensor ValueNode, an abstract is missing or is not a tensor | |||||
| // abstract, an operand's element type differs from node_3's, the sizes are incompatible, or the | |||||
| // element type is not float/float32, float64 or int/int32. | |||||
| // Raises (via MS_LOG(EXCEPTION)) when copying the product into the new tensor fails. | |||||
| AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const { | |||||
|   if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) || | |||||
|       (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto value_1 = GetValueNode(vnode_1); | |||||
|   auto value_2 = GetValueNode(vnode_2); | |||||
|   if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1); | |||||
|   auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2); | |||||
|   auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
|   // The second operand's abstract must come from vnode_2; reading it from vnode_1 would leave | |||||
|   // the element type of the second operand unchecked. | |||||
|   auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
|   auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>(); | |||||
|   // Guard against abstracts that are not tensor abstracts before dereferencing the cast results. | |||||
|   if ((tensor_1_abstract == nullptr) || (tensor_2_abstract == nullptr) || (tensor_3_abstract == nullptr)) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); | |||||
|   TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); | |||||
|   TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); | |||||
|   const TypeId out_type = tensor_3_type_ptr->type_id(); | |||||
|   if ((tensor_1_type_ptr->type_id() != out_type) || (tensor_2_type_ptr->type_id() != out_type)) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape(); | |||||
|   int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>()); | |||||
|   // Only "single element" (broadcast) or "same element count as the output" operands are supported. | |||||
|   if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { | |||||
|     return nullptr; | |||||
|   } | |||||
|   auto new_tensor_ptr = std::make_shared<tensor::Tensor>(out_type, tensor_out_shape); | |||||
|   size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); | |||||
|   char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c()); | |||||
|   int ret = 0; | |||||
|   // Multiply<TM> allocates data_out with new TM[]; it is released with the matching delete[] below. | |||||
|   void *data_out = nullptr; | |||||
|   if ((out_type == TypeId::kNumberTypeFloat32) || (out_type == TypeId::kNumberTypeFloat)) { | |||||
|     Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
|                     tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
|     ret = memcpy_s(data, mem_size, data_out, mem_size); | |||||
|     delete[] reinterpret_cast<float *>(data_out); | |||||
|   } else if (out_type == TypeId::kNumberTypeFloat64) { | |||||
|     Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
|                      tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
|     ret = memcpy_s(data, mem_size, data_out, mem_size); | |||||
|     delete[] reinterpret_cast<double *>(data_out); | |||||
|   } else if ((out_type == TypeId::kNumberTypeInt32) || (out_type == TypeId::kNumberTypeInt)) { | |||||
|     Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), | |||||
|                   tensor_ptr_2->DataSize(), &data_out, data_out_size); | |||||
|     ret = memcpy_s(data, mem_size, data_out, mem_size); | |||||
|     delete[] reinterpret_cast<int *>(data_out); | |||||
|   } else { | |||||
|     // Unsupported data types | |||||
|     return nullptr; | |||||
|   } | |||||
|   if (ret != 0) { | |||||
|     MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << ", dest size " | |||||
|                       << new_tensor_ptr->DataSize(); | |||||
|   } | |||||
|   auto new_vnode = NewValueNode(new_tensor_ptr); | |||||
|   new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); | |||||
|   return new_vnode; | |||||
| } | |||||
| using Internal = const PConstant<T> &; | |||||
| protected: | |||||
| mutable AnfNodePtr as_node_; | |||||
| mutable AnfNodePtr captured_node_; | |||||
| bool any_value_{true}; | |||||
| int check_value_{0}; | |||||
| bool is_scalar_{false}; | |||||
| mutable bool is_new_value_node_{false}; | |||||
| mutable bool captured_{false}; | |||||
| mutable bool changed_shape_{false}; | |||||
| }; | }; | ||||
| // Macro for binary operation functions | // Macro for binary operation functions | ||||
| #define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \ | |||||
| template <typename T, typename T2> \ | |||||
| inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \ | |||||
| return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \ | |||||
| #define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative) \ | |||||
| template <typename T, typename T2> \ | |||||
| inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) { \ | |||||
| return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \ | |||||
| } | } | ||||
| // Arithmetic operations | // Arithmetic operations | ||||
| BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd); | |||||
| BIN_OPERATION_PATTERN(operator*, prim::kPrimMul); | |||||
| BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true); | |||||
| BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); | |||||
| // Macros for match and replace | // Macros for match and replace | ||||
| #define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ | #define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ | ||||
| if ((CaptureNode).TryCapture(OrigNode)) { \ | if ((CaptureNode).TryCapture(OrigNode)) { \ | ||||
| return (ReplaceWith).GetNode(OrigNode); \ | |||||
| auto rep = (ReplaceWith).GetNode(OrigNode); \ | |||||
| if (rep != nullptr) { \ | |||||
| return rep; \ | |||||
| } \ | |||||
| } | } | ||||
| #define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ | #define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ | ||||
| if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ | if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ | ||||
| return (ReplaceWith).GetNode(OrigNode); \ | |||||
| auto rep = (ReplaceWith).GetNode(OrigNode); \ | |||||
| if (rep != nullptr) { \ | |||||
| return rep; \ | |||||
| } \ | |||||
| } | } | ||||
| #define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ | #define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ | ||||
| if ((CaptureNode).TryCapture(OrigNode)) { \ | if ((CaptureNode).TryCapture(OrigNode)) { \ | ||||
| if ((Condition)) { \ | if ((Condition)) { \ | ||||
| return (ReplaceWith).GetNode(OrigNode); \ | |||||
| auto rep = (ReplaceWith).GetNode(OrigNode); \ | |||||
| if (rep != nullptr) { \ | |||||
| return (ReplaceWith).GetNode(OrigNode); \ | |||||
| } \ | |||||
| } else { \ | |||||
| auto rep = (ElseNode).GetNode(OrigNode); \ | |||||
| if (rep != nullptr) { \ | |||||
| return (ElseNode).GetNode(OrigNode); \ | |||||
| } \ | |||||
| } \ | } \ | ||||
| return (ElseNode).GetNode(OrigNode); \ | |||||
| } | } | ||||
| #define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ | #define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ | ||||
| if ((CaptureNode).TryCapture(OrigNode)) { \ | if ((CaptureNode).TryCapture(OrigNode)) { \ | ||||
| return (Lambda)(); \ | |||||
| auto rep = (Lambda)(); \ | |||||
| if (rep != nullptr) { \ | |||||
| return rep; \ | |||||
| } \ | |||||
| } | } | ||||
| #define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ | #define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ | ||||
| if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ | if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ | ||||
| return (Lambda)(); \ | |||||
| auto rep = (Lambda)(); \ | |||||
| if (rep != nullptr) { \ | |||||
| return rep; \ | |||||
| } \ | |||||
| } | } | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common { | |||||
| }; | }; | ||||
| void SetUp() { | void SetUp() { | ||||
| elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd); | |||||
| elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd); | |||||
| elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R); | elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R); | ||||
| idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P); | idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P); | ||||
| Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q); | Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q); | ||||