| @@ -21,19 +21,10 @@ namespace opt { | |||||
| namespace irpass { | namespace irpass { | ||||
| #define UPPER_FLT_LIMIT (FLT_MAX / 2.0) | #define UPPER_FLT_LIMIT (FLT_MAX / 2.0) | ||||
| #define LOWER_FLT_LIMIT (-FLT_MAX / 2.0) | #define LOWER_FLT_LIMIT (-FLT_MAX / 2.0) | ||||
| // Define the checking mode | |||||
| enum ScalarCheckingMode : int { GREATER_EQUAL = 0, LESS }; | |||||
| bool IsCNodePositive(const AnfNodePtr &node) { | |||||
| if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { | |||||
| return IsCNodePositive(node->cast<CNodePtr>()->input(1)); | |||||
| } | |||||
| if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| // check if a value is bigger than UPPER_FLT_LIMIT | |||||
| bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { | |||||
| bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) { | |||||
| auto value_node = node->cast<ValueNodePtr>(); | auto value_node = node->cast<ValueNodePtr>(); | ||||
| if (value_node == nullptr) { | if (value_node == nullptr) { | ||||
| return false; | return false; | ||||
| @@ -47,7 +38,10 @@ bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { | |||||
| auto scalar = value->cast<ScalarPtr>(); | auto scalar = value->cast<ScalarPtr>(); | ||||
| if (scalar != nullptr) { | if (scalar != nullptr) { | ||||
| if (scalar->isa<FloatImm>()) { | if (scalar->isa<FloatImm>()) { | ||||
| return GetValue<float>(scalar) > UPPER_FLT_LIMIT; | |||||
| if (checking_mode == GREATER_EQUAL) { | |||||
| return GetValue<float>(scalar) >= check_value; | |||||
| } | |||||
| return GetValue<float>(scalar) < check_value; | |||||
| } | } | ||||
| } | } | ||||
| // Check for Tensor [] or Tensor [1] | // Check for Tensor [] or Tensor [1] | ||||
| @@ -62,48 +56,42 @@ bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { | |||||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | ||||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | ||||
| float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | ||||
| return data[0] > UPPER_FLT_LIMIT; | |||||
| if (checking_mode == GREATER_EQUAL) { | |||||
| return data[0] >= check_value; | |||||
| } | |||||
| return data[0] < check_value; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| // check if a value is smaller than LOWER_FLT_LIMIT | |||||
| bool IsNodeScalarMinFLT(const AnfNodePtr &node) { | |||||
| auto value_node = node->cast<ValueNodePtr>(); | |||||
| if (value_node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| // check if a value is greater or equal 0.0 | |||||
| bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); } | |||||
| auto value = value_node->value(); | |||||
| if (value == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto scalar = value->cast<ScalarPtr>(); | |||||
| if (scalar != nullptr) { | |||||
| if (scalar->isa<FloatImm>()) { | |||||
| return GetValue<float>(scalar) < LOWER_FLT_LIMIT; | |||||
| } | |||||
| } | |||||
| // Check for Tensor [] or Tensor [1] | |||||
| auto tensor_ptr = value->cast<tensor::TensorPtr>(); | |||||
| if (tensor_ptr == nullptr) { | |||||
| return false; | |||||
| bool IsCNodePositive(const AnfNodePtr &node) { | |||||
| if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { | |||||
| return IsCNodePositive(node->cast<CNodePtr>()->input(1)); | |||||
| } | } | ||||
| if (tensor_ptr->DataSize() > 1) { | |||||
| return false; | |||||
| if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { | |||||
| return true; | |||||
| } | } | ||||
| TypeId tensor_type = tensor_ptr->Dtype()->type_id(); | |||||
| if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { | |||||
| float *data = reinterpret_cast<float *>(tensor_ptr->data_c()); | |||||
| return data[0] < LOWER_FLT_LIMIT; | |||||
| if (IsPrimitiveCNode(node, prim::kPrimMinimum) || IsPrimitiveCNode(node, prim::kPrimRealDiv)) { | |||||
| auto first_node_positive = | |||||
| IsCNodePositive(node->cast<CNodePtr>()->input(1)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(1)); | |||||
| auto second_node_positive = | |||||
| IsCNodePositive(node->cast<CNodePtr>()->input(2)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(2)); | |||||
| return first_node_positive && second_node_positive; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| // check if a value is greater or equal UPPER_FLT_LIMIT | |||||
| bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); } | |||||
| // check if a value is smaller than LOWER_FLT_LIMIT | |||||
| bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); } | |||||
| AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | ||||
| PatternNode x, y, z; | PatternNode x, y, z; | ||||
| PConstant zero_(node, false, 0); | PConstant zero_(node, false, 0); | ||||
| @@ -116,10 +104,15 @@ AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePt | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, | MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, | ||||
| IsCNodePositive(x.GetNode(node))); | IsCNodePositive(x.GetNode(node))); | ||||
| // {prim::kPrimMaximum, X, LOWER_FLT_LIMIT}} -> X | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node))); | MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node))); | ||||
| // {prim::kPrimMinimum, X, UPPER_FLT_LIMIT}} -> X | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node))); | MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node))); | ||||
| // {prim::kPrimMaximum, X, 0}} -> X when X is always greater or equal 0 | |||||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, zero_), x, IsCNodePositive(x.GetNode(node))); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||