|
|
|
@@ -35,14 +35,40 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace ad { |
|
|
|
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr>; |
|
|
|
struct PrimitiveTotalEqual { |
|
|
|
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { |
|
|
|
if (t1->name() != t2->name()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto const &attrs1 = t1->attrs(); |
|
|
|
auto const &attrs2 = t2->attrs(); |
|
|
|
if (attrs1.size() != attrs2.size()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &attr : attrs1) { |
|
|
|
if (!t2->HasAttr(attr.first)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>; |
|
|
|
class KPrim; |
|
|
|
extern KPrim g_k_prims; |
|
|
|
class DFunctor; |
|
|
|
using DFunctorPtr = std::shared_ptr<DFunctor>; |
|
|
|
|
|
|
|
// D Functor's rules to map closure object and morphisms. |
|
|
|
class DFunctor { |
|
|
|
class DFunctor : public std::enable_shared_from_this<DFunctor> { |
|
|
|
public: |
|
|
|
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); |
|
|
|
~DFunctor() = default; |
|
|
|
@@ -54,7 +80,7 @@ class DFunctor { |
|
|
|
// Construct user defined k object. |
|
|
|
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); |
|
|
|
// Register functor objects to form a global view. |
|
|
|
void Init(const DFunctorPtr &functor, bool is_top = false); |
|
|
|
void Init(bool is_top = false); |
|
|
|
bool IsInScope(const AnfNodePtr &node); |
|
|
|
|
|
|
|
// Clear resources. |
|
|
|
|