/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_ #define MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_ #include #include #include #include #include #include "ir/anf.h" #include "pipeline/static_analysis/abstract_value.h" #include "utils/any.h" namespace mindspore { class SymbolicKeyInstance : public Value { public: SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) : node_(node), abstract_(abstract) {} ~SymbolicKeyInstance() override = default; MS_DECLARE_PARENT(SymbolicKeyInstance, Value); AnfNodePtr node() const { return node_; } abstract::AbstractBasePtr abstract() const { return abstract_; } bool operator==(const SymbolicKeyInstance &other) const { return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); } std::size_t hash() const override { return std::hash{}(node_); } friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &inst) { if (inst == nullptr) { os << "[Key][" << "Invalid symbolic key instance" << "]"; } else { os << "[Key][" << inst->node_->type_name() << "]" << inst->node_->ToString(); } return os; } std::string ToString() const override { return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); } bool operator==(const Value &other) const override { if (other.isa()) { auto other_ = static_cast(other); return *this == other_; } else { return false; } } abstract::AbstractBasePtr ToAbstract() override { return std::make_shared(shared_from_base(), std::make_shared()); } private: AnfNodePtr node_; abstract::AbstractBasePtr abstract_; }; using SymbolicKeyInstancePtr = std::shared_ptr; struct SymbolicKeyInstanceHash { std::size_t operator()(const SymbolicKeyInstancePtr s) const { if (s == nullptr) { return 0; } return s->abstract()->hash(); } }; struct SymbolicKeyInstanceEqual { bool operator()(const SymbolicKeyInstancePtr lhs, const SymbolicKeyInstancePtr rhs) const { if (lhs == nullptr || rhs == nullptr) { return false; } MS_EXCEPTION_IF_NULL(lhs->node()); MS_EXCEPTION_IF_NULL(rhs->node()); MS_EXCEPTION_IF_NULL(lhs->abstract()); MS_EXCEPTION_IF_NULL(rhs->abstract()); return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract()); } }; using EnvInstanceContentsMap = std::unordered_map; // Environment mapping keys to values. // Keys are SymbolicKeyInstances, which represent nodes in the graph along // with inferred properties. class EnvInstance : public Value { public: friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr &env); explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} ~EnvInstance() override = default; MS_DECLARE_PARENT(EnvInstance, Value); abstract::AbstractBasePtr ToAbstract() override { return std::make_shared(shared_from_base(), std::make_shared()); } bool operator==(const EnvInstance &other) const; bool operator==(const Value &other) const override; EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} EnvInstance(EnvInstance &&v) = default; EnvInstance &operator=(EnvInstance &&src) noexcept { if (&src != this) { contents_ = src.contents_; } return *this; }; // Get the sensitivity list for the given key const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { auto iterator = contents_.find(key); if (iterator != contents_.end()) { return iterator->second; } return def; } // Set a value for the given key. EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { EnvInstance rval(contents_); rval.contents_[key] = value; return rval; } // Add two EnvInstances. EnvInstance Add(const EnvInstance &other) const { EnvInstance rval(contents_); for (auto iter_other : other.contents_) { auto item_self = contents_.find(iter_other.first); if (item_self != contents_.end()) { MS_LOG(DEBUG) << "Need to use add"; } else { rval.contents_[iter_other.first] = iter_other.second; } } return rval; } size_t Len() const { return contents_.size(); } std::size_t hash() const override { // deterministic characteristic of member variables. return Len(); } const bool parse_info_ = true; private: EnvInstanceContentsMap contents_; }; using EnvInstancePtr = std::shared_ptr; extern std::shared_ptr newenv; } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_