Merge pull request !1296 from leopz/fix_primitivetags/v0.3.0-alpha
| @@ -24,75 +24,13 @@ | |||||
| #include "pipeline/parse/data_converter.h" | #include "pipeline/parse/data_converter.h" | ||||
| #include "pybind11/pytypes.h" | #include "pybind11/pytypes.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/primitive_utils.h" | |||||
| #include "pybind_api/api_register.h" | #include "pybind_api/api_register.h" | ||||
| #include "pybind_api/export_flags.h" | #include "pybind_api/export_flags.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| using mindspore::abstract::AbstractFunction; | |||||
| abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { | |||||
| auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node); | |||||
| return prim_func; | |||||
| } | |||||
| static py::function GetBpropFunctionByObj(py::object obj) { | |||||
| static const std::string get_bprop_fn = "get_bprop_fn"; | |||||
| static const std::string ad_module = "mindspore.ops._grad"; | |||||
| py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj); | |||||
| return fn; | |||||
| } | |||||
| py::function Primitive::GetBpropFunction() { | |||||
| auto fn = GetBpropFunctionByObj(py::str(name())); | |||||
| if (fn.is_none()) { | |||||
| MS_LOG(WARNING) << "Can't find bprop function for " << name(); | |||||
| } | |||||
| return fn; | |||||
| } | |||||
| py::function Primitive::GetComputeFunction() { | |||||
| static const std::string module = "mindspore._extends.builtin_operations"; | |||||
| py::module mod = py::module::import(common::SafeCStr(module)); | |||||
| if (!py::hasattr(mod, common::SafeCStr(name()))) { | |||||
| PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name())); | |||||
| // If raise AttributeError, user can't understand. This case need raise NotImplementedError. | |||||
| throw py::error_already_set(); | |||||
| } | |||||
| py::object fn = mod.attr(common::SafeCStr(name())); | |||||
| return fn; | |||||
| } | |||||
| bool Primitive::operator==(const Value &other) const { | |||||
| if (other.isa<Primitive>()) { | |||||
| auto other_prim = static_cast<const Primitive &>(other); | |||||
| return *this == other_prim; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| bool Primitive::operator==(const Primitive &other) const { | |||||
| if (name() != other.name()) { | |||||
| return false; | |||||
| } | |||||
| if (attrs_.size() != other.attrs_.size()) { | |||||
| return false; | |||||
| } | |||||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||||
| if (item.second == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto iter = other.attrs_.find(item.first); | |||||
| if (iter == other.attrs_.end()) { | |||||
| return false; | |||||
| } | |||||
| return *item.second == *iter->second; | |||||
| }); | |||||
| return all; | |||||
| } | |||||
| void Primitive::set_signatures( | |||||
| void PrimitivePy::set_signatures( | |||||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | ||||
| signatures_.clear(); | signatures_.clear(); | ||||
| for (auto &signature : signatures) { | for (auto &signature : signatures) { | ||||
| @@ -104,27 +42,7 @@ void Primitive::set_signatures( | |||||
| std::tie(name, rw, kind, default_value, dtype) = signature; | std::tie(name, rw, kind, default_value, dtype) = signature; | ||||
| signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype)); | signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype)); | ||||
| } | } | ||||
| } | |||||
| std::string Primitive::GetAttrsText() const { | |||||
| if (attrs_.empty()) { | |||||
| return ""; | |||||
| } | |||||
| std::ostringstream oss; | |||||
| oss << "["; | |||||
| bool is_first = true; | |||||
| for (auto &attr : attrs_) { | |||||
| if (is_first) { | |||||
| is_first = false; | |||||
| } else { | |||||
| oss << ", "; | |||||
| } | |||||
| oss << attr.first << "=" << attr.second->DumpText(); | |||||
| } | |||||
| oss << "]"; | |||||
| return oss.str(); | |||||
| set_has_signature(true); | |||||
| } | } | ||||
| py::function PrimitivePy::GetBpropFunction() { | py::function PrimitivePy::GetBpropFunction() { | ||||
| @@ -158,7 +76,7 @@ py::function PrimitivePy::GetComputeFunction() { | |||||
| if (py::isinstance<py::none>(vm_fn)) { | if (py::isinstance<py::none>(vm_fn)) { | ||||
| MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>(); | MS_LOG(DEBUG) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>(); | ||||
| vm_fn = Primitive::GetComputeFunction(); | |||||
| vm_fn = mindspore::GetComputeFunction(Primitive::name()); | |||||
| } | } | ||||
| return vm_fn; | return vm_fn; | ||||
| } | } | ||||
| @@ -22,59 +22,26 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <tuple> | #include <tuple> | ||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pipeline/static_analysis/abstract_value.h" | #include "pipeline/static_analysis/abstract_value.h" | ||||
| #include "utils/misc.h" | #include "utils/misc.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "ir/primitive_base.h" | |||||
| #include "ir/signature.h" | #include "ir/signature.h" | ||||
| #include "parallel/ops_info/operator_info.h" | #include "parallel/ops_info/operator_info.h" | ||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| using abstract::AbstractBasePtr; | |||||
| using abstract::AbstractBasePtrList; | |||||
| // Supported meta type | |||||
| enum PrimType { | |||||
| kPrimTypeUnknown = 0, | |||||
| kPrimTypeBegin = kTypeUnknown, | |||||
| kPrimTypeBuiltIn, // Built-in primitive operator | |||||
| kPrimTypePyInferShape, // Primitive operator defined by custom | |||||
| kPrimTypePyInferTensor, // Primitive operator defined by custom | |||||
| kPrimTypeUserCustom | |||||
| }; | |||||
| class Primitive : public Named { | |||||
| class PrimitivePy : public Primitive { | |||||
| public: | public: | ||||
| explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn) | |||||
| : Named(name), signatures_(), prim_type_(prim_type) {} | |||||
| Primitive(const Primitive &prim) | |||||
| : Named(prim), | |||||
| attrs_(prim.attrs_), | |||||
| signatures_(prim.signatures_), | |||||
| instance_name_(prim.instance_name_), | |||||
| prim_type_(prim.prim_type_) {} | |||||
| MS_DECLARE_PARENT(Primitive, Named); | |||||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||||
| std::string ToString() const override { return name(); } | |||||
| virtual py::function GetBpropFunction(); | |||||
| virtual py::function GetComputeFunction(); | |||||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | |||||
| attrs_[name] = attr; | |||||
| return *this; | |||||
| } | |||||
| Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||||
| for (auto &attr : attrs) { | |||||
| attrs_[attr.first] = attr.second; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| PrimitivePy(const py::str &name, const py::object &python_obj) | |||||
| : Primitive(name, false), python_obj_(python_obj), signatures_() {} | |||||
| ~PrimitivePy() override = default; | |||||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | |||||
| py::function GetBpropFunction(); | |||||
| py::function GetComputeFunction(); | |||||
| void set_signatures( | void set_signatures( | ||||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | ||||
| @@ -82,52 +49,6 @@ class Primitive : public Named { | |||||
| const std::vector<Signature> &signatures() const { return signatures_; } | const std::vector<Signature> &signatures() const { return signatures_; } | ||||
| void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } | |||||
| void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } | |||||
| ValuePtr GetAttr(const std::string &attrName) const { | |||||
| auto iter = attrs_.find(attrName); | |||||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||||
| } | |||||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | |||||
| bool HasAttr() const { return !attrs_.empty(); } | |||||
| bool HasAttr(const std::string &attrName) const { | |||||
| auto iter = attrs_.find(attrName); | |||||
| return !(iter == attrs_.cend()); | |||||
| } | |||||
| void set_prim_type(const PrimType t) { prim_type_ = t; } | |||||
| void set_instance_name(const std::string s) { instance_name_ = s; } | |||||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | |||||
| bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | |||||
| bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } | |||||
| PrimType prim_type() const { return prim_type_; } | |||||
| std::string instance_name() const { return instance_name_; } | |||||
| std::string GetAttrsText() const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Primitive &other) const; | |||||
| ~Primitive() override = default; | |||||
| protected: | |||||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||||
| private: | |||||
| std::vector<Signature> signatures_; | |||||
| std::string instance_name_; | |||||
| PrimType prim_type_; | |||||
| }; | |||||
| class PrimitivePy : public Primitive { | |||||
| public: | |||||
| PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {} | |||||
| ~PrimitivePy() override = default; | |||||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | |||||
| py::function GetBpropFunction() override; | |||||
| py::function GetComputeFunction() override; | |||||
| void AddPyAttr(const py::str &name, const py::object &obj); | void AddPyAttr(const py::str &name, const py::object &obj); | ||||
| py::dict GetAttrDict(); | py::dict GetAttrDict(); | ||||
| @@ -138,25 +59,9 @@ class PrimitivePy : public Primitive { | |||||
| private: | private: | ||||
| py::object python_obj_; | py::object python_obj_; | ||||
| std::vector<Signature> signatures_; | |||||
| }; | }; | ||||
| using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | ||||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | |||||
| os << *p; | |||||
| return os; | |||||
| } | |||||
| struct PrimitiveEqual { | |||||
| bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { | |||||
| MS_EXCEPTION_IF_NULL(t1); | |||||
| MS_EXCEPTION_IF_NULL(t2); | |||||
| return t1->name() == t2->name(); | |||||
| } | |||||
| }; | |||||
| struct PrimitiveHasher { | |||||
| std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } | |||||
| }; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ | #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ | ||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ir/primitive_base.h" | |||||
| #include <utility> | |||||
| namespace mindspore { | |||||
| bool Primitive::operator==(const Value &other) const { | |||||
| if (other.isa<Primitive>()) { | |||||
| auto other_prim = static_cast<const Primitive &>(other); | |||||
| return *this == other_prim; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| bool Primitive::operator==(const Primitive &other) const { | |||||
| if (name() != other.name()) { | |||||
| return false; | |||||
| } | |||||
| if (attrs_.size() != other.attrs_.size()) { | |||||
| return false; | |||||
| } | |||||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||||
| if (item.second == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto iter = other.attrs_.find(item.first); | |||||
| if (iter == other.attrs_.end()) { | |||||
| return false; | |||||
| } | |||||
| return *item.second == *iter->second; | |||||
| }); | |||||
| return all; | |||||
| } | |||||
| std::string Primitive::GetAttrsText() const { | |||||
| if (attrs_.empty()) { | |||||
| return ""; | |||||
| } | |||||
| std::ostringstream oss; | |||||
| oss << "["; | |||||
| bool is_first = true; | |||||
| for (auto &attr : attrs_) { | |||||
| if (is_first) { | |||||
| is_first = false; | |||||
| } else { | |||||
| oss << ", "; | |||||
| } | |||||
| oss << attr.first << "=" << attr.second->DumpText(); | |||||
| } | |||||
| oss << "]"; | |||||
| return oss.str(); | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,128 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ | |||||
| #define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include "ir/dtype/type.h" | |||||
| namespace mindspore { | |||||
| // Supported meta type | |||||
| enum PrimType { | |||||
| kPrimTypeUnknown = 0, | |||||
| kPrimTypeBegin = kTypeUnknown, | |||||
| kPrimTypeBuiltIn, // Built-in primitive operator | |||||
| kPrimTypePyInferShape, // Primitive operator defined by custom | |||||
| kPrimTypePyInferTensor, // Primitive operator defined by custom | |||||
| kPrimTypeUserCustom | |||||
| }; | |||||
| class Primitive : public Named { | |||||
| public: | |||||
| explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) | |||||
| : Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {} | |||||
| Primitive(const Primitive &prim) | |||||
| : Named(prim), | |||||
| attrs_(prim.attrs_), | |||||
| instance_name_(prim.instance_name_), | |||||
| is_base_(prim.is_base_), | |||||
| has_signature_(prim.has_signature_), | |||||
| prim_type_(prim.prim_type_) {} | |||||
| MS_DECLARE_PARENT(Primitive, Named); | |||||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||||
| std::string ToString() const override { return name(); } | |||||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | |||||
| attrs_[name] = attr; | |||||
| return *this; | |||||
| } | |||||
| Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||||
| for (auto &attr : attrs) { | |||||
| attrs_[attr.first] = attr.second; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } | |||||
| void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } | |||||
| ValuePtr GetAttr(const std::string &attrName) const { | |||||
| auto iter = attrs_.find(attrName); | |||||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||||
| } | |||||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | |||||
| bool HasAttr() const { return !attrs_.empty(); } | |||||
| bool HasAttr(const std::string &attrName) const { | |||||
| auto iter = attrs_.find(attrName); | |||||
| return !(iter == attrs_.cend()); | |||||
| } | |||||
| void set_prim_type(const PrimType t) { prim_type_ = t; } | |||||
| void set_instance_name(const std::string s) { instance_name_ = s; } | |||||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | |||||
| bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | |||||
| bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } | |||||
| PrimType prim_type() const { return prim_type_; } | |||||
| std::string instance_name() const { return instance_name_; } | |||||
| std::string GetAttrsText() const; | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Primitive &other) const; | |||||
| ~Primitive() override = default; | |||||
| void set_has_signature(bool has_signature) { has_signature_ = has_signature; } | |||||
| bool has_signature() const { return has_signature_; } | |||||
| bool is_base() const { return is_base_; } | |||||
| protected: | |||||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||||
| private: | |||||
| std::string instance_name_; | |||||
| bool is_base_; | |||||
| bool has_signature_; | |||||
| PrimType prim_type_; | |||||
| }; | |||||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | |||||
| os << *p; | |||||
| return os; | |||||
| } | |||||
| struct PrimitiveEqual { | |||||
| bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { | |||||
| MS_EXCEPTION_IF_NULL(t1); | |||||
| MS_EXCEPTION_IF_NULL(t2); | |||||
| return t1->name() == t2->name(); | |||||
| } | |||||
| }; | |||||
| struct PrimitiveHasher { | |||||
| std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } | |||||
| }; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ | |||||
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ir/primitive_base.h" | |||||
| #include "pipeline/static_analysis/abstract_function.h" | |||||
| namespace mindspore { | |||||
| abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { | |||||
| auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node); | |||||
| return prim_func; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -36,8 +36,8 @@ using PatternListType = std::initializer_list<BaseRef>; | |||||
| const std::vector<Signature> &GetSignature(const ValuePtr &function) { | const std::vector<Signature> &GetSignature(const ValuePtr &function) { | ||||
| static const auto empty = std::vector<Signature>(); | static const auto empty = std::vector<Signature>(); | ||||
| if (function->isa<Primitive>()) { | |||||
| return function->cast<PrimitivePtr>()->signatures(); | |||||
| if (function->isa<Primitive>() && function->cast<PrimitivePtr>()->has_signature()) { | |||||
| return function->cast<PrimitivePyPtr>()->signatures(); | |||||
| } else if (function->isa<MetaFuncGraph>()) { | } else if (function->isa<MetaFuncGraph>()) { | ||||
| return function->cast<MetaFuncGraphPtr>()->signatures(); | return function->cast<MetaFuncGraphPtr>()->signatures(); | ||||
| } | } | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/primitive.h" | |||||
| #include "ir/meta_func_graph.h" | #include "ir/meta_func_graph.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| @@ -30,6 +31,7 @@ | |||||
| #include "operator/ops.h" | #include "operator/ops.h" | ||||
| #include "operator/composite/composite.h" | #include "operator/composite/composite.h" | ||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #include "utils/primitive_utils.h" | |||||
| #include "debug/info.h" | #include "debug/info.h" | ||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| @@ -49,7 +51,7 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { | |||||
| auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + | auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + | ||||
| grad_op_child_scope_prefix + prim->name()); | grad_op_child_scope_prefix + prim->name()); | ||||
| ScopeGuard scope_guard(scope); | ScopeGuard scope_guard(scope); | ||||
| py::function fn = prim->GetBpropFunction(); | |||||
| py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction(); | |||||
| if (fn == nullptr || py::isinstance<py::none>(fn)) { | if (fn == nullptr || py::isinstance<py::none>(fn)) { | ||||
| MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; | MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/primitive_utils.h" | |||||
| #include "pipeline/parse/python_adapter.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "common/utils.h" | |||||
| namespace mindspore { | |||||
| py::function GetBpropFunctionByObj(py::object obj) { | |||||
| static const std::string get_bprop_fn = "get_bprop_fn"; | |||||
| static const std::string ad_module = "mindspore.ops._grad"; | |||||
| py::function fn = parse::python_adapter::GetPyFn(ad_module, get_bprop_fn)(obj); | |||||
| return fn; | |||||
| } | |||||
| py::function GetBpropFunction(std::string name) { | |||||
| auto fn = GetBpropFunctionByObj(py::str(name)); | |||||
| if (fn.is_none()) { | |||||
| MS_LOG(WARNING) << "Can't find bprop function for " << name; | |||||
| } | |||||
| return fn; | |||||
| } | |||||
| py::function GetComputeFunction(std::string name) { | |||||
| static const std::string module = "mindspore._extends.builtin_operations"; | |||||
| py::module mod = py::module::import(common::SafeCStr(module)); | |||||
| if (!py::hasattr(mod, common::SafeCStr(name))) { | |||||
| PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name)); | |||||
| // If raise AttributeError, user can't understand. This case need raise NotImplementedError. | |||||
| throw py::error_already_set(); | |||||
| } | |||||
| py::object fn = mod.attr(common::SafeCStr(name)); | |||||
| return fn; | |||||
| } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ | |||||
| #define MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ | |||||
| #include <string> | |||||
| #include "pybind11/pybind11.h" | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | |||||
| py::function GetBpropFunctionByObj(py::object obj); | |||||
| py::function GetBpropFunction(std::string name); | |||||
| py::function GetComputeFunction(std::string name); | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ | |||||
| @@ -31,6 +31,7 @@ | |||||
| #include "ir/manager.h" | #include "ir/manager.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/primitive_utils.h" | |||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -443,7 +444,7 @@ BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { | |||||
| PrimitivePyPtr operation = dyn_cast<PrimitivePy>(prim); | PrimitivePyPtr operation = dyn_cast<PrimitivePy>(prim); | ||||
| MS_LOG(DEBUG) << "operation start " << prim->name(); | MS_LOG(DEBUG) << "operation start " << prim->name(); | ||||
| auto func = operation != nullptr ? operation->GetComputeFunction() : prim->GetComputeFunction(); | |||||
| auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name()); | |||||
| if (py::isinstance<py::none>(func)) { | if (py::isinstance<py::none>(func)) { | ||||
| MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented"; | MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented"; | ||||
| } | } | ||||
| @@ -390,7 +390,7 @@ TEST_F(TestOps, Conv2dAttrTest) { | |||||
| } | } | ||||
| TEST_F(TestOps, CustomOpAttrTest) { | TEST_F(TestOps, CustomOpAttrTest) { | ||||
| Primitive prim("CustomOp", kPrimTypePyInferShape); | |||||
| Primitive prim("CustomOp", true, kPrimTypePyInferShape); | |||||
| prim.SetAttrs({ | prim.SetAttrs({ | ||||
| {"attr1", MakeValue(3)}, | {"attr1", MakeValue(3)}, | ||||
| {"attr2", MakeValue(1)}, | {"attr2", MakeValue(1)}, | ||||