Browse Source

!2223 Make those AdamXX and LambXX fusion pass not work for unexpect data type

Merge pull request !2223 from huanghui/TMP
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3c1b8308cf
11 changed files with 41 additions and 1 deletions
  1. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc
  2. +3
    -1
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc
  3. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc
  4. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc
  5. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc
  6. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc
  7. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc
  8. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc
  9. +10
    -0
      mindspore/ccsrc/pre_activate/common/helper.cc
  10. +4
    -0
      mindspore/ccsrc/pre_activate/common/helper.h
  11. +3
    -0
      mindspore/ccsrc/utils/utils.h

+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc View File

@@ -109,6 +109,9 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto new_node = CreateAdamApplyOneNode(func_graph, equiv);
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_scope(node->scope());


+ 3
- 1
mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc View File

@@ -146,7 +146,9 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c
if (graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}

if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv);
auto fusion_node = graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(fusion_node);


+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc View File

@@ -108,6 +108,9 @@ bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2

const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
std::vector<AnfNodePtr> old_pattern_outputs;
if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
return nullptr;


+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc View File

@@ -88,6 +88,9 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_);
MS_EXCEPTION_IF_NULL(mul4);
// Get add3 and match the add3 pattern


+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc View File

@@ -153,6 +153,9 @@ const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_gra
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
AnfNodePtr mul4 = nullptr;
AnfNodePtr real_div0 = nullptr;
AnfNodePtr real_div1 = nullptr;


+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc View File

@@ -61,6 +61,9 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto new_node = CreateLambNextRightNode(func_graph, equiv);
MS_EXCEPTION_IF_NULL(new_node);
// Set abstract of new node


+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc View File

@@ -50,6 +50,9 @@ const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph,
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto input0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
auto input1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
auto input2 = utils::cast<AnfNodePtr>((*equiv)[input2_]);


+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc View File

@@ -42,6 +42,9 @@ const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, con
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kLambUpdateWithLrV2OpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
(void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs),


+ 10
- 0
mindspore/ccsrc/pre_activate/common/helper.cc View File

@@ -765,5 +765,15 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) {
MS_EXCEPTION_IF_NULL(cnode);
return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr_name);
}

bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set) {
MS_EXCEPTION_IF_NULL(node);
TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0);
if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) {
return true;
}
MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
return false;
}
} // namespace opt
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/pre_activate/common/helper.h View File

@@ -20,6 +20,7 @@
#include <memory>
#include <utility>
#include <string>
#include <set>
#include <unordered_set>
#include "ir/func_graph.h"
#include "session/kernel_graph.h"
@@ -189,6 +190,9 @@ bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2);

// Get attr which is bool from cnode
bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);

// Check node's data type is in supported data type set
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_

+ 3
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -25,6 +25,7 @@
#include <set>

#include "utils/log_adapter.h"
#include "ir/dtype/type.h"

namespace mindspore {
// op name. Op which not exists in operator/ops.h, so define it's name here
@@ -270,6 +271,8 @@ const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFo
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04};

const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};

static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
try {
if (chmod(file_name.c_str(), mode) != 0) {


Loading…
Cancel
Save