|
|
|
@@ -300,6 +300,46 @@ bool IsStateEquivalent(const MonadState &state1, const MonadState &state2) { |
|
|
|
(state1.io == nullptr || state2.io == nullptr || state1.io == state2.io); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) { |
|
|
|
MonadState state_matmul = GetMonadState(inner); |
|
|
|
MonadState state_node = GetMonadState(outer, inner); |
|
|
|
return IsStateEquivalent(state_matmul, state_node); |
|
|
|
} |
|
|
|
|
|
|
|
std::set<CNodePtr> GetLoadInputs(const AnfNodePtr &node) { |
|
|
|
std::set<CNodePtr> loads; |
|
|
|
auto cnode = dyn_cast<CNode>(node); |
|
|
|
if (cnode == nullptr) { |
|
|
|
return loads; |
|
|
|
} |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) { |
|
|
|
auto &input = inputs.at(i); |
|
|
|
if (IsPrimitiveCNode(input, prim::kPrimLoad)) { |
|
|
|
loads.insert(input->cast<CNodePtr>()); |
|
|
|
} else if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { |
|
|
|
loads.merge(GetLoadInputs(input)); |
|
|
|
} |
|
|
|
} |
|
|
|
return loads; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) { |
|
|
|
constexpr size_t kMonadInput = 2; |
|
|
|
auto outer_loads = GetLoadInputs(outer); |
|
|
|
if (outer_loads.empty()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
auto inner_loads = GetLoadInputs(inner); |
|
|
|
if (inner_loads.empty()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
outer_loads.merge(inner_loads); |
|
|
|
auto &monad = (*outer_loads.begin())->inputs().at(kMonadInput); |
|
|
|
return std::all_of(++outer_loads.begin(), outer_loads.end(), |
|
|
|
[&monad](const CNodePtr &load) { return load->inputs().at(kMonadInput) == monad; }); |
|
|
|
} |
|
|
|
|
|
|
|
size_t NewSeenGeneration() { |
|
|
|
static size_t seen_generation = 0; |
|
|
|
return ++seen_generation; |
|
|
|
@@ -353,6 +393,26 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { |
|
|
|
std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); |
|
|
|
return default_target; |
|
|
|
} |
|
|
|
|
|
|
|
std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_target, const AnfNodePtr &attr_input, |
|
|
|
const std::string &primitive_target, const std::string &default_target) { |
|
|
|
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimPartial)) { |
|
|
|
primitive->EraseAttr(primitive_target); |
|
|
|
return default_target; |
|
|
|
} |
|
|
|
if (!att_target->isa<StringImm>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; |
|
|
|
} |
|
|
|
auto target = GetValue<std::string>(att_target); |
|
|
|
if (kTargetSet.find(target) == kTargetSet.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target; |
|
|
|
} |
|
|
|
return target; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
std::string GetCNodeTarget(const AnfNodePtr &node) { |
|
|
|
@@ -387,22 +447,7 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { |
|
|
|
auto primitive = value->cast<PrimitivePtr>(); |
|
|
|
auto att_target = primitive->GetAttr(primitive_target); |
|
|
|
if (att_target != nullptr) { |
|
|
|
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || |
|
|
|
IsPrimitive(attr_input, prim::kPrimPartial)) { |
|
|
|
primitive->EraseAttr(primitive_target); |
|
|
|
return default_target; |
|
|
|
} |
|
|
|
if (!att_target->isa<StringImm>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; |
|
|
|
} |
|
|
|
auto target = GetValue<std::string>(att_target); |
|
|
|
if (kTargetSet.find(target) == kTargetSet.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target; |
|
|
|
} |
|
|
|
return target; |
|
|
|
return GetAttrTarget(primitive, att_target, attr_input, primitive_target, default_target); |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimDepend)) { |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
|