| @@ -162,8 +162,7 @@ class EnvGetItemEliminater : public AnfVisitor { | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| PatternNode c1, c2, y; | |||
| MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y, | |||
| (IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) && | |||
| (GetValueNode<EnvInstancePtr>(c1.GetNode(node)))->Len() == 0)); | |||
| (IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node)))); | |||
| return nullptr; | |||
| } | |||
| }; | |||
| @@ -22,24 +22,12 @@ | |||
| namespace mindspore { | |||
| std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &objPtr) { | |||
| out << "("; | |||
| MS_EXCEPTION_IF_NULL(objPtr); | |||
| for (auto &iter : objPtr->contents_) { | |||
| out << iter.first << ":" << iter.second << ";"; | |||
| } | |||
| out << ")"; | |||
| out << "()"; | |||
| return out; | |||
| } | |||
| bool EnvInstance::operator==(const EnvInstance &other) const { | |||
| if (Len() != other.Len()) { | |||
| return false; | |||
| } | |||
| bool equal = std::all_of(contents_.begin(), contents_.end(), [&other](const auto &item) { | |||
| return other.contents_.find(item.first) != other.contents_.end(); | |||
| }); | |||
| return equal; | |||
| } | |||
| bool EnvInstance::operator==(const EnvInstance &other) const { return true; } | |||
| bool EnvInstance::operator==(const Value &other) const { | |||
| if (other.isa<EnvInstance>()) { | |||
| auto other_env_inst = static_cast<const EnvInstance *>(&other); | |||
| @@ -107,7 +107,7 @@ class EnvInstance : public Value { | |||
| public: | |||
| friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &env); | |||
| explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} | |||
| EnvInstance() = default; | |||
| ~EnvInstance() override = default; | |||
| MS_DECLARE_PARENT(EnvInstance, Value); | |||
| abstract::AbstractBasePtr ToAbstract() override { | |||
| @@ -115,53 +115,14 @@ class EnvInstance : public Value { | |||
| } | |||
| bool operator==(const EnvInstance &other) const; | |||
| bool operator==(const Value &other) const override; | |||
| EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} | |||
| EnvInstance(const EnvInstance &v) : Value(v) {} | |||
| EnvInstance(EnvInstance &&v) = default; | |||
| EnvInstance &operator=(EnvInstance &&src) noexcept { | |||
| if (&src != this) { | |||
| contents_ = src.contents_; | |||
| } | |||
| return *this; | |||
| }; | |||
| // Get the sensitivity list for the given key | |||
| const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { | |||
| auto iterator = contents_.find(key); | |||
| if (iterator != contents_.end()) { | |||
| return iterator->second; | |||
| } | |||
| return def; | |||
| } | |||
| // Set a value for the given key. | |||
| EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { | |||
| EnvInstance rval(contents_); | |||
| rval.contents_[key] = value; | |||
| return rval; | |||
| } | |||
| EnvInstance &operator=(EnvInstance &&src) noexcept { return *this; }; | |||
| // Add two EnvInstances. | |||
| EnvInstance Add(const EnvInstance &other) const { | |||
| EnvInstance rval(contents_); | |||
| for (auto iter_other : other.contents_) { | |||
| auto item_self = contents_.find(iter_other.first); | |||
| if (item_self != contents_.end()) { | |||
| MS_LOG(DEBUG) << "Need to use add"; | |||
| } else { | |||
| rval.contents_[iter_other.first] = iter_other.second; | |||
| } | |||
| } | |||
| return rval; | |||
| } | |||
| size_t Len() const { return contents_.size(); } | |||
| std::size_t hash() const override { | |||
| // deterministic characteristic of member variables. | |||
| return Len(); | |||
| return tid(); | |||
| } | |||
| private: | |||
| EnvInstanceContentsMap contents_; | |||
| }; | |||
| using EnvInstancePtr = std::shared_ptr<EnvInstance>; | |||
| @@ -17,41 +17,20 @@ | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| #include "utils/symbolic.h" | |||
| using std::cout; | |||
| using std::endl; | |||
| using std::string; | |||
| namespace mindspore { | |||
| class TestSymbolic : public UT::Common { | |||
| public: | |||
| TestSymbolic() {} | |||
| }; | |||
| TEST_F(TestSymbolic, test_env) { | |||
| /// Feature: Test the basic functionality of SymbolicKeyInstance class. | |||
| /// Description: Test equality of two SymbolicKeyInstance. | |||
| /// Expectation: True | |||
| TEST_F(TestSymbolic, test_symbolic) { | |||
| auto sk1 = std::make_shared<SymbolicKeyInstance>(NewValueNode(static_cast<int64_t>(1)), abstract::FromValue(1234)); | |||
| auto sk1b = std::make_shared<SymbolicKeyInstance>(NewValueNode(static_cast<int64_t>(1)), abstract::FromValue(1234)); | |||
| ASSERT_EQ(*sk1, *sk1b); | |||
| auto sk2 = std::make_shared<SymbolicKeyInstance>(NewValueNode(static_cast<int64_t>(2)), abstract::FromValue(1234)); | |||
| EnvInstance e = newenv->Set(sk1, 100); | |||
| ASSERT_FALSE(e == *newenv); | |||
| ASSERT_EQ(newenv->Len(), 0); | |||
| ASSERT_EQ(e.Len(), 1); | |||
| ASSERT_EQ(e.Get(sk1, 0), 100); | |||
| ASSERT_EQ(e.Get(sk2, 0), 0); | |||
| EnvInstance e2 = e.Set(sk1b, 200); | |||
| ASSERT_EQ(e2.Len(), 1); | |||
| ASSERT_EQ(e2.Get(sk1, 0), 200); | |||
| ASSERT_EQ(e2.Get(sk2, 0), 0); | |||
| EnvInstance e3 = e2.Set(sk2, 300); | |||
| ASSERT_EQ(e3.Len(), 2); | |||
| ASSERT_EQ(e3.Get(sk1, 0), 200); | |||
| ASSERT_EQ(e3.Get(sk2, 0), 300); | |||
| } | |||
| } // namespace mindspore | |||