| @@ -24,7 +24,7 @@ | |||
| #include <unordered_map> | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive_base.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "operator/ops.h" | |||
| @@ -15,108 +15,57 @@ | |||
| */ | |||
| #include "ir/primitive.h" | |||
| #include <mutex> | |||
| #include <utility> | |||
| #include "ir/signature.h" | |||
| #include "operator/ops.h" | |||
| #include "./common.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "pybind11/pytypes.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "utils/primitive_utils.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| #include <utility> | |||
| namespace mindspore { | |||
| static ValuePtr PyArgToValue(const py::object &arg) { | |||
| if (py::isinstance<SignatureEnumKind>(arg) && | |||
| py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { | |||
| return nullptr; | |||
| } | |||
| return parse::data_converter::PyDataToValue(arg); | |||
| } | |||
| void PrimitivePy::set_signatures( | |||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | |||
| signatures_.clear(); | |||
| for (auto &signature : signatures) { | |||
| auto [name, rw, kind, arg_default, dtype] = signature; | |||
| auto default_value = PyArgToValue(arg_default); | |||
| signatures_.emplace_back(name, rw, kind, default_value, dtype); | |||
| } | |||
| set_has_signature(true); | |||
| } | |||
| py::function PrimitivePy::GetBpropFunction() { | |||
| static const char *const get_bprop_func_name = "get_bprop"; | |||
| if (py::hasattr(python_obj_, get_bprop_func_name)) { | |||
| py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>(); | |||
| return fn; | |||
| bool Primitive::operator==(const Value &other) const { | |||
| if (other.isa<Primitive>()) { | |||
| auto other_prim = static_cast<const Primitive &>(other); | |||
| return *this == other_prim; | |||
| } else { | |||
| auto fn = GetBpropFunctionByObj(python_obj_); | |||
| return fn; | |||
| return false; | |||
| } | |||
| } | |||
| py::function PrimitivePy::GetComputeFunction() { | |||
| static const char *const compute_func_name = "vm_impl"; | |||
| if (py::hasattr(python_obj_, compute_func_name)) { | |||
| MS_LOG(INFO) << name() << " compute_func_name"; | |||
| py::function fn = python_obj_.attr(compute_func_name).cast<py::function>(); | |||
| return fn; | |||
| bool Primitive::operator==(const Primitive &other) const { | |||
| if (name() != other.name()) { | |||
| return false; | |||
| } | |||
| static const std::string vm_module = "mindspore.ops.vm_impl_registry"; | |||
| static const std::string get_vm_impl_fn = "get_vm_impl_fn"; | |||
| MS_LOG(INFO) << name() << ": get_vm_impl_fn"; | |||
| py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); | |||
| py::function vm_fn = get_fn(python_obj_); | |||
| if (py::isinstance<py::none>(vm_fn)) { | |||
| MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>(); | |||
| vm_fn = mindspore::GetComputeFunction(Primitive::name()); | |||
| if (attrs_.size() != other.attrs_.size()) { | |||
| return false; | |||
| } | |||
| return vm_fn; | |||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||
| if (item.second == nullptr) { | |||
| return false; | |||
| } | |||
| auto iter = other.attrs_.find(item.first); | |||
| if (iter == other.attrs_.end()) { | |||
| return false; | |||
| } | |||
| return *item.second == *iter->second; | |||
| }); | |||
| return all; | |||
| } | |||
| void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { | |||
| std::string attr_name = name; | |||
| ValuePtr converted_ret = nullptr; | |||
| if (py::isinstance<py::module>(obj)) { | |||
| MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; | |||
| } | |||
| bool converted = parse::ConvertData(obj, &converted_ret); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); | |||
| std::string Primitive::GetAttrsText() const { | |||
| if (attrs_.empty()) { | |||
| return ""; | |||
| } | |||
| (void)this->AddAttr(attr_name, converted_ret); | |||
| } | |||
| py::dict PrimitivePy::GetAttrDict() { | |||
| py::dict attr_dict; | |||
| std::ostringstream oss; | |||
| oss << "["; | |||
| bool is_first = true; | |||
| for (auto &attr : attrs_) { | |||
| attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); | |||
| if (is_first) { | |||
| is_first = false; | |||
| } else { | |||
| oss << ", "; | |||
| } | |||
| oss << attr.first << "=" << attr.second->DumpText(); | |||
| } | |||
| return attr_dict; | |||
| } | |||
| oss << "]"; | |||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | |||
| .value("unknown", PrimType::kPrimTypeUnknown) | |||
| .value("builtin", PrimType::kPrimTypeBuiltIn) | |||
| .value("py_infer_shape", PrimType::kPrimTypePyInferShape) | |||
| .value("user_custom", PrimType::kPrimTypeUserCustom); | |||
| (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_") | |||
| .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) | |||
| .def(py::init<py::str &, py::object>()) | |||
| .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | |||
| .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | |||
| .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | |||
| .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") | |||
| .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") | |||
| .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); | |||
| })); | |||
| return oss.str(); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -23,45 +23,129 @@ | |||
| #include <string> | |||
| #include <tuple> | |||
| #include "ir/dtype/type.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "utils/misc.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "ir/primitive_base.h" | |||
| #include "ir/signature.h" | |||
| #include "parallel/ops_info/operator_info.h" | |||
| #include "utils/base_ref_extends.h" | |||
| namespace mindspore { | |||
| class PrimitivePy : public Primitive { | |||
| // Supported meta type | |||
| enum PrimType { | |||
| kPrimTypeUnknown = 0, | |||
| kPrimTypeBegin = kTypeUnknown, | |||
| kPrimTypeBuiltIn, // Built-in primitive operator | |||
| kPrimTypePyInferShape, // Primitive operator defined by custom | |||
| kPrimTypePyInferTensor, // Primitive operator defined by custom | |||
| kPrimTypeUserCustom | |||
| }; | |||
| class Primitive : public Named { | |||
| public: | |||
| PrimitivePy(const py::str &name, const py::object &python_obj) | |||
| : Primitive(name, false), python_obj_(python_obj), signatures_() {} | |||
| ~PrimitivePy() override = default; | |||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | |||
| py::function GetBpropFunction(); | |||
| py::function GetComputeFunction(); | |||
| explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) | |||
| : Named(name), | |||
| is_base_(is_base), | |||
| has_signature_(false), | |||
| prim_type_(prim_type), | |||
| record_evaluate_add_attr_(false) {} | |||
| Primitive(const Primitive &prim) | |||
| : Named(prim), | |||
| attrs_(prim.attrs_), | |||
| instance_name_(prim.instance_name_), | |||
| is_base_(prim.is_base_), | |||
| has_signature_(prim.has_signature_), | |||
| prim_type_(prim.prim_type_), | |||
| record_evaluate_add_attr_(false) {} | |||
| MS_DECLARE_PARENT(Primitive, Named); | |||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||
| std::string ToString() const override { return name(); } | |||
| void BeginRecordAddAttr() { | |||
| evaluate_added_attrs_.clear(); | |||
| record_evaluate_add_attr_ = true; | |||
| } | |||
| void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } | |||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | |||
| attrs_[name] = attr; | |||
| if (record_evaluate_add_attr_) { | |||
| evaluate_added_attrs_[name] = attr; | |||
| } | |||
| return *this; | |||
| } | |||
| Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||
| for (auto &attr : attrs) { | |||
| attrs_[attr.first] = attr.second; | |||
| } | |||
| return *this; | |||
| } | |||
| void set_signatures( | |||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | |||
| signatures); | |||
| void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } | |||
| void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } | |||
| const std::vector<Signature> &signatures() const { return signatures_; } | |||
| ValuePtr GetAttr(const std::string &attrName) const { | |||
| auto iter = attrs_.find(attrName); | |||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||
| } | |||
| void AddPyAttr(const py::str &name, const py::object &obj); | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; } | |||
| py::dict GetAttrDict(); | |||
| void set_hook(const py::function &hook) { hook_ = hook; } | |||
| py::function hook() const { return hook_; } | |||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | |||
| bool HasAttr() const { return !attrs_.empty(); } | |||
| bool HasAttr(const std::string &attrName) const { | |||
| auto iter = attrs_.find(attrName); | |||
| return !(iter == attrs_.cend()); | |||
| } | |||
| void set_prim_type(const PrimType t) { prim_type_ = t; } | |||
| void set_instance_name(const std::string s) { instance_name_ = s; } | |||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | |||
| bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | |||
| bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } | |||
| const bool parse_info_ = true; | |||
| const py::object &GetPyObj() const { return python_obj_; } | |||
| bool is_tuple_input_ = false; | |||
| PrimType prim_type() const { return prim_type_; } | |||
| std::string instance_name() const { return instance_name_; } | |||
| std::string GetAttrsText() const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Primitive &other) const; | |||
| ~Primitive() override = default; | |||
| void set_has_signature(bool has_signature) { has_signature_ = has_signature; } | |||
| bool has_signature() const { return has_signature_; } | |||
| bool is_base() const { return is_base_; } | |||
| virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; } | |||
| virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } | |||
| protected: | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_; | |||
| private: | |||
| py::object python_obj_; | |||
| py::function hook_; | |||
| std::vector<Signature> signatures_; | |||
| std::string instance_name_; | |||
| bool is_base_; | |||
| bool has_signature_; | |||
| PrimType prim_type_; | |||
| bool record_evaluate_add_attr_; | |||
| }; | |||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | |||
| os << *p; | |||
| return os; | |||
| } | |||
| struct PrimitiveEqual { | |||
| bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| MS_EXCEPTION_IF_NULL(t2); | |||
| return t1->name() == t2->name(); | |||
| } | |||
| }; | |||
| using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | |||
| struct PrimitiveHasher { | |||
| std::size_t operator()(PrimitivePtr const &prim) const { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return prim->Hash(); | |||
| } | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ | |||
| @@ -1,71 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/primitive_base.h" | |||
| #include <utility> | |||
| namespace mindspore { | |||
| bool Primitive::operator==(const Value &other) const { | |||
| if (other.isa<Primitive>()) { | |||
| auto other_prim = static_cast<const Primitive &>(other); | |||
| return *this == other_prim; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Primitive::operator==(const Primitive &other) const { | |||
| if (name() != other.name()) { | |||
| return false; | |||
| } | |||
| if (attrs_.size() != other.attrs_.size()) { | |||
| return false; | |||
| } | |||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||
| if (item.second == nullptr) { | |||
| return false; | |||
| } | |||
| auto iter = other.attrs_.find(item.first); | |||
| if (iter == other.attrs_.end()) { | |||
| return false; | |||
| } | |||
| return *item.second == *iter->second; | |||
| }); | |||
| return all; | |||
| } | |||
| std::string Primitive::GetAttrsText() const { | |||
| if (attrs_.empty()) { | |||
| return ""; | |||
| } | |||
| std::ostringstream oss; | |||
| oss << "["; | |||
| bool is_first = true; | |||
| for (auto &attr : attrs_) { | |||
| if (is_first) { | |||
| is_first = false; | |||
| } else { | |||
| oss << ", "; | |||
| } | |||
| oss << attr.first << "=" << attr.second->DumpText(); | |||
| } | |||
| oss << "]"; | |||
| return oss.str(); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -1,150 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ | |||
| #define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include "ir/dtype/type.h" | |||
| #include "pybind11/pybind11.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| // Supported meta type | |||
| enum PrimType { | |||
| kPrimTypeUnknown = 0, | |||
| kPrimTypeBegin = kTypeUnknown, | |||
| kPrimTypeBuiltIn, // Built-in primitive operator | |||
| kPrimTypePyInferShape, // Primitive operator defined by custom | |||
| kPrimTypePyInferTensor, // Primitive operator defined by custom | |||
| kPrimTypeUserCustom | |||
| }; | |||
| class Primitive : public Named { | |||
| public: | |||
| explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) | |||
| : Named(name), | |||
| is_base_(is_base), | |||
| has_signature_(false), | |||
| prim_type_(prim_type), | |||
| record_evaluate_add_attr_(false) {} | |||
| Primitive(const Primitive &prim) | |||
| : Named(prim), | |||
| attrs_(prim.attrs_), | |||
| instance_name_(prim.instance_name_), | |||
| is_base_(prim.is_base_), | |||
| has_signature_(prim.has_signature_), | |||
| prim_type_(prim.prim_type_), | |||
| record_evaluate_add_attr_(false) {} | |||
| MS_DECLARE_PARENT(Primitive, Named); | |||
| abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); | |||
| std::string ToString() const override { return name(); } | |||
| void BeginRecordAddAttr() { | |||
| evaluate_added_attrs_.clear(); | |||
| record_evaluate_add_attr_ = true; | |||
| } | |||
| void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } | |||
| Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { | |||
| attrs_[name] = attr; | |||
| if (record_evaluate_add_attr_) { | |||
| evaluate_added_attrs_[name] = attr; | |||
| } | |||
| return *this; | |||
| } | |||
| Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||
| for (auto &attr : attrs) { | |||
| attrs_[attr.first] = attr.second; | |||
| } | |||
| return *this; | |||
| } | |||
| void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } | |||
| void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } | |||
| ValuePtr GetAttr(const std::string &attrName) const { | |||
| auto iter = attrs_.find(attrName); | |||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||
| } | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; } | |||
| // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. | |||
| bool HasAttr() const { return !attrs_.empty(); } | |||
| bool HasAttr(const std::string &attrName) const { | |||
| auto iter = attrs_.find(attrName); | |||
| return !(iter == attrs_.cend()); | |||
| } | |||
| void set_prim_type(const PrimType t) { prim_type_ = t; } | |||
| void set_instance_name(const std::string s) { instance_name_ = s; } | |||
| bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } | |||
| bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } | |||
| bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } | |||
| PrimType prim_type() const { return prim_type_; } | |||
| std::string instance_name() const { return instance_name_; } | |||
| std::string GetAttrsText() const; | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Primitive &other) const; | |||
| ~Primitive() override = default; | |||
| void set_has_signature(bool has_signature) { has_signature_ = has_signature; } | |||
| bool has_signature() const { return has_signature_; } | |||
| bool is_base() const { return is_base_; } | |||
| protected: | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_; | |||
| private: | |||
| std::string instance_name_; | |||
| bool is_base_; | |||
| bool has_signature_; | |||
| PrimType prim_type_; | |||
| bool record_evaluate_add_attr_; | |||
| }; | |||
| inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { | |||
| os << *p; | |||
| return os; | |||
| } | |||
| struct PrimitiveEqual { | |||
| bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { | |||
| MS_EXCEPTION_IF_NULL(t1); | |||
| MS_EXCEPTION_IF_NULL(t2); | |||
| return t1->name() == t2->name(); | |||
| } | |||
| }; | |||
| struct PrimitiveHasher { | |||
| std::size_t operator()(PrimitivePtr const &prim) const { | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return prim->Hash(); | |||
| } | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ | |||
| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/primitive_base.h" | |||
| #include "ir/primitive.h" | |||
| #include "pipeline/static_analysis/abstract_function.h" | |||
| namespace mindspore { | |||
| @@ -0,0 +1,195 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/primitive_py.h" | |||
| #include <mutex> | |||
| #include <utility> | |||
| #include "ir/signature.h" | |||
| #include "operator/ops.h" | |||
| #include "./common.h" | |||
| #include "pipeline/parse/python_adapter.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| #include "pybind11/pytypes.h" | |||
| #include "utils/convert_utils_base.h" | |||
| #include "utils/primitive_utils.h" | |||
| #include "utils/base_ref_py.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| constexpr auto kBpropAttrName = "bprop"; | |||
| constexpr auto kCellHookAttrName = "cell_hook"; | |||
| constexpr auto kCellIDAttrName = "cell_id"; | |||
| void SyncData(const py::object &arg) { | |||
| if (py::isinstance<py::tuple>(arg)) { | |||
| py::tuple arg_list = py::cast<py::tuple>(arg); | |||
| for (size_t i = 0; i < arg_list.size(); i++) { | |||
| SyncData(arg_list[i]); | |||
| } | |||
| } | |||
| if (py::isinstance<tensor::Tensor>(arg)) { | |||
| auto tensor = py::cast<tensor::TensorPtr>(arg); | |||
| (void)tensor->data_sync(); | |||
| } | |||
| } | |||
| } // namespace | |||
| std::map<std::string, py::object> PrimitivePy::hook_grad_; | |||
| static ValuePtr PyArgToValue(const py::object &arg) { | |||
| if (py::isinstance<SignatureEnumKind>(arg) && | |||
| py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { | |||
| return nullptr; | |||
| } | |||
| return parse::data_converter::PyDataToValue(arg); | |||
| } | |||
| void PrimitivePy::set_signatures( | |||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) { | |||
| signatures_.clear(); | |||
| for (auto &signature : signatures) { | |||
| auto [name, rw, kind, arg_default, dtype] = signature; | |||
| auto default_value = PyArgToValue(arg_default); | |||
| signatures_.emplace_back(name, rw, kind, default_value, dtype); | |||
| } | |||
| set_has_signature(true); | |||
| } | |||
| py::function PrimitivePy::GetBpropFunction() { | |||
| static const char *const get_bprop_func_name = "get_bprop"; | |||
| if (py::hasattr(python_obj_, get_bprop_func_name)) { | |||
| py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>(); | |||
| return fn; | |||
| } else { | |||
| auto fn = GetBpropFunctionByObj(python_obj_); | |||
| return fn; | |||
| } | |||
| } | |||
| BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { | |||
| auto py_args = py::tuple(args.size()); | |||
| size_t i = 0; | |||
| for (auto &arg : args) { | |||
| py_args[i] = BaseRefToPyData(arg); | |||
| MS_LOG(DEBUG) << "arg:" << i << ":"; | |||
| i++; | |||
| } | |||
| py::object obj; | |||
| bool is_bprop = this->HasAttr(kBpropAttrName); | |||
| if (is_bprop) { | |||
| SyncData(py_args); | |||
| obj = hook_(*py_args); | |||
| return std::make_shared<PyObjectRef>(obj); | |||
| } | |||
| SyncData(py_args[2]); | |||
| bool is_cell = this->HasAttr(kCellHookAttrName); | |||
| if (is_cell) { | |||
| auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName)); | |||
| auto iter = hook_grad_.find(cell_id); | |||
| if (iter != hook_grad_.end()) { | |||
| auto hook_args = py::tuple(3); | |||
| hook_args[0] = cell_id; | |||
| hook_args[1] = py::make_tuple(iter->second); | |||
| hook_args[2] = py::make_tuple(py_args[2]); | |||
| obj = hook_(*hook_args); | |||
| if (py::isinstance<py::none>(obj)) { | |||
| obj = py_args[2]; | |||
| } | |||
| hook_grad_.erase(cell_id); | |||
| } else { | |||
| hook_grad_[cell_id] = py_args[2]; | |||
| obj = py_args[2]; | |||
| } | |||
| } else { | |||
| // Hook operator for execute variable hook function | |||
| obj = hook_(py::make_tuple(py_args[2])); | |||
| if (py::isinstance<py::none>(obj)) { | |||
| obj = py_args[2]; | |||
| } | |||
| } | |||
| obj = py::make_tuple(obj); | |||
| return std::make_shared<PyObjectRef>(obj); | |||
| } | |||
| py::function PrimitivePy::GetComputeFunction() { | |||
| static const char *const compute_func_name = "vm_impl"; | |||
| if (py::hasattr(python_obj_, compute_func_name)) { | |||
| MS_LOG(INFO) << name() << " compute_func_name"; | |||
| py::function fn = python_obj_.attr(compute_func_name).cast<py::function>(); | |||
| return fn; | |||
| } | |||
| static const std::string vm_module = "mindspore.ops.vm_impl_registry"; | |||
| static const std::string get_vm_impl_fn = "get_vm_impl_fn"; | |||
| MS_LOG(INFO) << name() << ": get_vm_impl_fn"; | |||
| py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); | |||
| py::function vm_fn = get_fn(python_obj_); | |||
| if (py::isinstance<py::none>(vm_fn)) { | |||
| MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>(); | |||
| vm_fn = mindspore::GetComputeFunction(Primitive::name()); | |||
| } | |||
| return vm_fn; | |||
| } | |||
| void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { | |||
| std::string attr_name = name; | |||
| ValuePtr converted_ret = nullptr; | |||
| if (py::isinstance<py::module>(obj)) { | |||
| MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; | |||
| } | |||
| bool converted = parse::ConvertData(obj, &converted_ret); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); | |||
| } | |||
| (void)this->AddAttr(attr_name, converted_ret); | |||
| } | |||
| py::dict PrimitivePy::GetAttrDict() { | |||
| py::dict attr_dict; | |||
| for (auto &attr : attrs_) { | |||
| attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); | |||
| } | |||
| return attr_dict; | |||
| } | |||
| void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| if (!primitive->isa<PrimitivePy>()) { | |||
| MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!"; | |||
| } | |||
| auto primitive_py = primitive->cast<PrimitivePyPtr>(); | |||
| MS_EXCEPTION_IF_NULL(primitive_py); | |||
| this->set_hook(primitive_py->hook()); | |||
| } | |||
| REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { | |||
| (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic()) | |||
| .value("unknown", PrimType::kPrimTypeUnknown) | |||
| .value("builtin", PrimType::kPrimTypeBuiltIn) | |||
| .value("py_infer_shape", PrimType::kPrimTypePyInferShape) | |||
| .value("user_custom", PrimType::kPrimTypeUserCustom); | |||
| (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_") | |||
| .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) | |||
| .def(py::init<py::str &, py::object>()) | |||
| .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") | |||
| .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") | |||
| .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") | |||
| .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") | |||
| .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") | |||
| .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ | |||
| #define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <map> | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "utils/misc.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/signature.h" | |||
| #include "parallel/ops_info/operator_info.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| class PrimitivePy : public Primitive { | |||
| public: | |||
| PrimitivePy(const py::str &name, const py::object &python_obj) | |||
| : Primitive(name, false), python_obj_(python_obj), signatures_() {} | |||
| ~PrimitivePy() override = default; | |||
| MS_DECLARE_PARENT(PrimitivePy, Primitive); | |||
| py::function GetBpropFunction(); | |||
| py::function GetComputeFunction(); | |||
| void set_signatures( | |||
| std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> | |||
| signatures); | |||
| const std::vector<Signature> &signatures() const { return signatures_; } | |||
| void CopyHookFunction(const PrimitivePtr &primitive) override; | |||
| void AddPyAttr(const py::str &name, const py::object &obj); | |||
| py::dict GetAttrDict(); | |||
| void set_hook(const py::function &hook) { hook_ = hook; } | |||
| py::function hook() const { return hook_; } | |||
| BaseRef RunHookFunction(const VectorRef &args) const override; | |||
| const bool parse_info_ = true; | |||
| const py::object &GetPyObj() const { return python_obj_; } | |||
| bool is_tuple_input_ = false; | |||
| private: | |||
| py::object python_obj_; | |||
| py::function hook_; | |||
| std::vector<Signature> signatures_; | |||
| static std::map<std::string, py::object> hook_grad_; | |||
| }; | |||
| using PrimitivePyPtr = std::shared_ptr<PrimitivePy>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ | |||
| @@ -16,7 +16,6 @@ | |||
| #include "kernel/cpu/addn_cpu_kernel.h" | |||
| #include "device/cpu/cpu_device_address.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -16,7 +16,6 @@ | |||
| #include "kernel/cpu/allgather_cpu_kernel.h" | |||
| #include "device/cpu/cpu_device_address.h" | |||
| #include "device/cpu/mpi/mpi_adapter.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace mindspore { | |||
| @@ -16,7 +16,6 @@ | |||
| #include "kernel/cpu/concat_cpu_kernel.h" | |||
| #include "device/cpu/cpu_device_address.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -17,7 +17,6 @@ | |||
| #include "kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h" | |||
| #include "device/cpu/cpu_device_address.h" | |||
| #include "device/cpu/mpi/mpi_adapter.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "kernel/cpu/gather_cpu_kernel.h" | |||
| #include "device/cpu/cpu_device_address.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "kernel/cpu/slice_cpu_kernel.h" | |||
| #include "device/cpu/cpu_device_address.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -21,7 +21,7 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive_base.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| // namespace to support primitive operators | |||
| @@ -20,7 +20,7 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/primitive_py.h" | |||
| #include "ir/meta_func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/manager.h" | |||
| @@ -232,10 +232,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res | |||
| std::vector<AnfNodePtr> outputs; | |||
| auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object()); | |||
| if (!prim->is_base()) { | |||
| PrimitivePyPtr prim_py = dyn_cast<PrimitivePy>(prim); | |||
| bprop_cut->set_hook(prim_py->hook()); | |||
| } | |||
| bprop_cut->CopyHookFunction(prim); | |||
| auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id")); | |||
| if (cell_id != "") { | |||
| @@ -23,7 +23,7 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/primitive_py.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "common/utils.h" | |||
| @@ -33,7 +33,7 @@ | |||
| #include "utils/log_adapter.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/primitive_py.h" | |||
| #include "pipeline/static_analysis/analysis_context.h" | |||
| #include "pipeline/static_analysis/abstract_function.h" | |||
| #include "pipeline/parse/parse.h" | |||
| @@ -27,7 +27,6 @@ | |||
| #include "utils/any.h" | |||
| #include "utils/misc.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "ir/primitive.h" | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| @@ -181,15 +181,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo | |||
| if (AnfAlgo::IsGraphKernel(node)) { | |||
| return ProcessGraphKernelOp(func_graph, node); | |||
| } else { | |||
| // insert cast for single op. | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| // process input | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto new_node = InsertCastForInput(func_graph, cnode); | |||
| // process output | |||
| return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true)); | |||
| } | |||
| // insert cast for single op. | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" | |||
| #include "pre_activate/common/helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { | |||
| @@ -26,7 +26,7 @@ | |||
| #include <unordered_set> | |||
| #include "pybind11/pybind11.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/primitive_py.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| namespace mindspore { | |||
| @@ -29,7 +29,6 @@ | |||
| #include "ir/primitive.h" | |||
| #include "ir/value.h" | |||
| #include "transform/types.h" | |||
| #ifdef ENABLE_GE | |||
| #ifdef OPEN_SOURCE | |||
| #include "graph/types.h" | |||
| @@ -29,7 +29,7 @@ | |||
| #include <string> | |||
| #include "ir/anf.h" | |||
| #include "ir/primitive_base.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/scalar.h" | |||
| #include "ir/tensor.h" | |||
| #include "debug/label.h" | |||
| @@ -648,57 +648,8 @@ void FinalVM::SyncData(const py::object &arg) { | |||
| BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "input for operation:"; | |||
| auto prim_py = dyn_cast<PrimitivePy>(prim); | |||
| std::size_t args_size = args.size(); | |||
| auto py_args = py::tuple(args_size); | |||
| size_t i = 0; | |||
| for (auto &arg : args) { | |||
| py_args[i] = BaseRefToPyData(arg); | |||
| MS_LOG(DEBUG) << "arg: " << i << ":"; | |||
| i++; | |||
| } | |||
| // Hook operator for execute cell custom bprop function | |||
| py::object obj; | |||
| bool is_bprop = prim->HasAttr("bprop"); | |||
| if (is_bprop) { | |||
| SyncData(py_args); | |||
| py::function fn_bprop = prim_py->hook(); | |||
| obj = fn_bprop(*py_args); | |||
| return obj; | |||
| } | |||
| // Sync gradient data from device to host | |||
| SyncData(py_args[2]); | |||
| bool is_cell = prim->HasAttr("cell_hook"); | |||
| if (is_cell) { | |||
| // Hook operator for execute cell hook function | |||
| std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id")); | |||
| if (_hook_grad.find(cell_id) != _hook_grad.end()) { | |||
| std::size_t hook_args_size = 3; | |||
| auto hook_args = py::tuple(hook_args_size); | |||
| hook_args[0] = cell_id; | |||
| hook_args[1] = py::make_tuple(_hook_grad[cell_id]); | |||
| hook_args[2] = py::make_tuple(py_args[2]); | |||
| py::function fn_hook = prim_py->hook(); | |||
| obj = fn_hook(*hook_args); | |||
| if (py::isinstance<py::none>(obj)) { | |||
| obj = py_args[2]; | |||
| } | |||
| _hook_grad.erase(cell_id); | |||
| } else { | |||
| _hook_grad[cell_id] = py_args[2]; | |||
| obj = py_args[2]; | |||
| } | |||
| } else { | |||
| // Hook operator for execute variable hook function | |||
| py::function fn_hook = prim_py->hook(); | |||
| obj = fn_hook(py::make_tuple(py_args[2])); | |||
| if (py::isinstance<py::none>(obj)) { | |||
| obj = py_args[2]; | |||
| } | |||
| } | |||
| obj = py::make_tuple(obj); | |||
| return obj; | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| return prim->RunHookFunction(args); | |||
| } | |||
| } // namespace compile | |||
| } // namespace mindspore | |||
| @@ -161,7 +161,6 @@ class FinalVM { | |||
| {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, | |||
| {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, | |||
| {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; | |||
| std::map<std::string, py::object> _hook_grad; | |||
| }; | |||
| using FinalVMPtr = std::shared_ptr<FinalVM>; | |||
| @@ -30,7 +30,7 @@ | |||
| #include "operator/ops.h" | |||
| #include "ir/manager.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/primitive_py.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/primitive_utils.h" | |||
| #include "debug/draw.h" | |||
| @@ -19,7 +19,7 @@ | |||
| #include "common/common_test.h" | |||
| #include "ir/value.h" | |||
| #include "ir/primitive.h" | |||
| #include "ir/primitive_py.h" | |||
| #include "operator/ops.h" | |||
| #include "./common.h" | |||