diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc index 4c1d2bf50d..45cce7b473 100644 --- a/mindspore/ccsrc/ir/anf.cc +++ b/mindspore/ccsrc/ir/anf.cc @@ -24,7 +24,7 @@ #include #include "ir/func_graph.h" -#include "ir/primitive_base.h" +#include "ir/primitive.h" #include "utils/context/ms_context.h" #include "operator/ops.h" diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc index 3526e47f96..352c0f31ae 100644 --- a/mindspore/ccsrc/ir/primitive.cc +++ b/mindspore/ccsrc/ir/primitive.cc @@ -15,108 +15,57 @@ */ #include "ir/primitive.h" -#include -#include -#include "ir/signature.h" -#include "operator/ops.h" -#include "./common.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/data_converter.h" -#include "pybind11/pytypes.h" -#include "utils/convert_utils_base.h" -#include "utils/primitive_utils.h" -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" +#include namespace mindspore { -static ValuePtr PyArgToValue(const py::object &arg) { - if (py::isinstance(arg) && - py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { - return nullptr; - } - return parse::data_converter::PyDataToValue(arg); -} - -void PrimitivePy::set_signatures( - std::vector> signatures) { - signatures_.clear(); - for (auto &signature : signatures) { - auto [name, rw, kind, arg_default, dtype] = signature; - auto default_value = PyArgToValue(arg_default); - signatures_.emplace_back(name, rw, kind, default_value, dtype); - } - set_has_signature(true); -} - -py::function PrimitivePy::GetBpropFunction() { - static const char *const get_bprop_func_name = "get_bprop"; - if (py::hasattr(python_obj_, get_bprop_func_name)) { - py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); - return fn; +bool Primitive::operator==(const Value &other) const { + if (other.isa()) { + auto other_prim = static_cast(other); + return *this == other_prim; } else { - auto fn = GetBpropFunctionByObj(python_obj_); - return fn; + return false; } } -py::function PrimitivePy::GetComputeFunction() { - static const char *const compute_func_name = "vm_impl"; - - if (py::hasattr(python_obj_, compute_func_name)) { - MS_LOG(INFO) << name() << " compute_func_name"; - py::function fn = python_obj_.attr(compute_func_name).cast(); - return fn; +bool Primitive::operator==(const Primitive &other) const { + if (name() != other.name()) { + return false; } - - static const std::string vm_module = "mindspore.ops.vm_impl_registry"; - static const std::string get_vm_impl_fn = "get_vm_impl_fn"; - MS_LOG(INFO) << name() << ": get_vm_impl_fn"; - py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); - py::function vm_fn = get_fn(python_obj_); - - if (py::isinstance(vm_fn)) { - MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast(); - vm_fn = mindspore::GetComputeFunction(Primitive::name()); + if (attrs_.size() != other.attrs_.size()) { + return false; } - return vm_fn; + auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { + if (item.second == nullptr) { + return false; + } + auto iter = other.attrs_.find(item.first); + if (iter == other.attrs_.end()) { + return false; + } + return *item.second == *iter->second; + }); + return all; } -void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { - std::string attr_name = name; - ValuePtr converted_ret = nullptr; - if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; - } - bool converted = parse::ConvertData(obj, &converted_ret); - if (!converted) { - MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); +std::string Primitive::GetAttrsText() const { + if (attrs_.empty()) { + return ""; } - (void)this->AddAttr(attr_name, converted_ret); -} -py::dict PrimitivePy::GetAttrDict() { - py::dict attr_dict; + std::ostringstream oss; + oss << "["; + bool is_first = true; for (auto &attr : attrs_) { - attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); + if (is_first) { + is_first = false; + } else { + oss << ", "; + } + oss << attr.first << "=" << attr.second->DumpText(); } - return attr_dict; -} + oss << "]"; -REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { - (void)py::enum_(*m, "prim_type", py::arithmetic()) - .value("unknown", PrimType::kPrimTypeUnknown) - .value("builtin", PrimType::kPrimTypeBuiltIn) - .value("py_infer_shape", PrimType::kPrimTypePyInferShape) - .value("user_custom", PrimType::kPrimTypeUserCustom); - (void)py::class_>(*m, "Primitive_") - .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) - .def(py::init()) - .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") - .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") - .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") - .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") - .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") - .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); - })); + return oss.str(); +} } // namespace mindspore diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h index 257302c0c4..9732e173ac 100644 --- a/mindspore/ccsrc/ir/primitive.h +++ b/mindspore/ccsrc/ir/primitive.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,45 +23,129 @@ #include #include +#include "ir/dtype/type.h" #include "pipeline/static_analysis/abstract_value.h" -#include "utils/misc.h" -#include "utils/log_adapter.h" -#include "ir/primitive_base.h" -#include "ir/signature.h" #include "parallel/ops_info/operator_info.h" - +#include "utils/base_ref_extends.h" namespace mindspore { -class PrimitivePy : public Primitive { +// Supported meta type +enum PrimType { + kPrimTypeUnknown = 0, + kPrimTypeBegin = kTypeUnknown, + kPrimTypeBuiltIn, // Built-in primitive operator + kPrimTypePyInferShape, // Primitive operator defined by custom + kPrimTypePyInferTensor, // Primitive operator defined by custom + kPrimTypeUserCustom +}; + +class Primitive : public Named { public: - PrimitivePy(const py::str &name, const py::object &python_obj) - : Primitive(name, false), python_obj_(python_obj), signatures_() {} - ~PrimitivePy() override = default; - MS_DECLARE_PARENT(PrimitivePy, Primitive); - py::function GetBpropFunction(); - py::function GetComputeFunction(); + explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) + : Named(name), + is_base_(is_base), + has_signature_(false), + prim_type_(prim_type), + record_evaluate_add_attr_(false) {} + + Primitive(const Primitive &prim) + : Named(prim), + attrs_(prim.attrs_), + instance_name_(prim.instance_name_), + is_base_(prim.is_base_), + has_signature_(prim.has_signature_), + prim_type_(prim.prim_type_), + record_evaluate_add_attr_(false) {} + + MS_DECLARE_PARENT(Primitive, Named); + + abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); + std::string ToString() const override { return name(); } + void BeginRecordAddAttr() { + evaluate_added_attrs_.clear(); + record_evaluate_add_attr_ = true; + } + void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } + Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { + attrs_[name] = attr; + if (record_evaluate_add_attr_) { + evaluate_added_attrs_[name] = attr; + } + return *this; + } + + Primitive &SetAttrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { + attrs_[attr.first] = attr.second; + } + return *this; + } - void set_signatures( - std::vector> - signatures); + void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } + void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } - const std::vector &signatures() const { return signatures_; } + ValuePtr GetAttr(const std::string &attrName) const { + auto iter = attrs_.find(attrName); + return iter == attrs_.cend() ? nullptr : iter->second; + } - void AddPyAttr(const py::str &name, const py::object &obj); + const std::unordered_map &attrs() const { return attrs_; } + const std::unordered_map &evaluate_added_attrs() const { return evaluate_added_attrs_; } - py::dict GetAttrDict(); - void set_hook(const py::function &hook) { hook_ = hook; } - py::function hook() const { return hook_; } + // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. + bool HasAttr() const { return !attrs_.empty(); } + bool HasAttr(const std::string &attrName) const { + auto iter = attrs_.find(attrName); + return !(iter == attrs_.cend()); + } + void set_prim_type(const PrimType t) { prim_type_ = t; } + void set_instance_name(const std::string s) { instance_name_ = s; } + bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } + bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } + bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } - const bool parse_info_ = true; - const py::object &GetPyObj() const { return python_obj_; } - bool is_tuple_input_ = false; + PrimType prim_type() const { return prim_type_; } + std::string instance_name() const { return instance_name_; } + std::string GetAttrsText() const; + bool operator==(const Value &other) const override; + bool operator==(const Primitive &other) const; + ~Primitive() override = default; + + void set_has_signature(bool has_signature) { has_signature_ = has_signature; } + bool has_signature() const { return has_signature_; } + bool is_base() const { return is_base_; } + virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; } + virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } + + protected: + std::unordered_map attrs_; + std::unordered_map evaluate_added_attrs_; private: - py::object python_obj_; - py::function hook_; - std::vector signatures_; + std::string instance_name_; + bool is_base_; + bool has_signature_; + PrimType prim_type_; + bool record_evaluate_add_attr_; +}; + +inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { + os << *p; + return os; +} + +struct PrimitiveEqual { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return t1->name() == t2->name(); + } }; -using PrimitivePyPtr = std::shared_ptr; +struct PrimitiveHasher { + std::size_t operator()(PrimitivePtr const &prim) const { + MS_EXCEPTION_IF_NULL(prim); + return prim->Hash(); + } +}; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ diff --git a/mindspore/ccsrc/ir/primitive_base.cc b/mindspore/ccsrc/ir/primitive_base.cc deleted file mode 100644 index 864427fe13..0000000000 --- a/mindspore/ccsrc/ir/primitive_base.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/primitive_base.h" - -#include - -namespace mindspore { -bool Primitive::operator==(const Value &other) const { - if (other.isa()) { - auto other_prim = static_cast(other); - return *this == other_prim; - } else { - return false; - } -} - -bool Primitive::operator==(const Primitive &other) const { - if (name() != other.name()) { - return false; - } - if (attrs_.size() != other.attrs_.size()) { - return false; - } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { - if (item.second == nullptr) { - return false; - } - auto iter = other.attrs_.find(item.first); - if (iter == other.attrs_.end()) { - return false; - } - return *item.second == *iter->second; - }); - return all; -} - -std::string Primitive::GetAttrsText() const { - if (attrs_.empty()) { - return ""; - } - - std::ostringstream oss; - oss << "["; - bool is_first = true; - for (auto &attr : attrs_) { - if (is_first) { - is_first = false; - } else { - oss << ", "; - } - oss << attr.first << "=" << attr.second->DumpText(); - } - oss << "]"; - - return oss.str(); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/primitive_base.h b/mindspore/ccsrc/ir/primitive_base.h deleted file mode 100644 index b34c43d00e..0000000000 --- a/mindspore/ccsrc/ir/primitive_base.h +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ -#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ - -#include -#include -#include -#include -#include - -#include "ir/dtype/type.h" -#include "pybind11/pybind11.h" - -namespace py = pybind11; - -namespace mindspore { -// Supported meta type -enum PrimType { - kPrimTypeUnknown = 0, - kPrimTypeBegin = kTypeUnknown, - kPrimTypeBuiltIn, // Built-in primitive operator - kPrimTypePyInferShape, // Primitive operator defined by custom - kPrimTypePyInferTensor, // Primitive operator defined by custom - kPrimTypeUserCustom -}; - -class Primitive : public Named { - public: - explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) - : Named(name), - is_base_(is_base), - has_signature_(false), - prim_type_(prim_type), - record_evaluate_add_attr_(false) {} - - Primitive(const Primitive &prim) - : Named(prim), - attrs_(prim.attrs_), - instance_name_(prim.instance_name_), - is_base_(prim.is_base_), - has_signature_(prim.has_signature_), - prim_type_(prim.prim_type_), - record_evaluate_add_attr_(false) {} - - MS_DECLARE_PARENT(Primitive, Named); - - abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); - std::string ToString() const override { return name(); } - void BeginRecordAddAttr() { - evaluate_added_attrs_.clear(); - record_evaluate_add_attr_ = true; - } - void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } - Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { - attrs_[name] = attr; - if (record_evaluate_add_attr_) { - evaluate_added_attrs_[name] = attr; - } - return *this; - } - - Primitive &SetAttrs(const std::unordered_map &attrs) { - for (auto &attr : attrs) { - attrs_[attr.first] = attr.second; - } - return *this; - } - - void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } - void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } - - ValuePtr GetAttr(const std::string &attrName) const { - auto iter = attrs_.find(attrName); - return iter == attrs_.cend() ? nullptr : iter->second; - } - - const std::unordered_map &attrs() const { return attrs_; } - const std::unordered_map &evaluate_added_attrs() const { return evaluate_added_attrs_; } - - // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. - bool HasAttr() const { return !attrs_.empty(); } - bool HasAttr(const std::string &attrName) const { - auto iter = attrs_.find(attrName); - return !(iter == attrs_.cend()); - } - void set_prim_type(const PrimType t) { prim_type_ = t; } - void set_instance_name(const std::string s) { instance_name_ = s; } - bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } - bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } - bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } - - PrimType prim_type() const { return prim_type_; } - std::string instance_name() const { return instance_name_; } - std::string GetAttrsText() const; - bool operator==(const Value &other) const override; - bool operator==(const Primitive &other) const; - ~Primitive() override = default; - - void set_has_signature(bool has_signature) { has_signature_ = has_signature; } - bool has_signature() const { return has_signature_; } - bool is_base() const { return is_base_; } - - protected: - std::unordered_map attrs_; - std::unordered_map evaluate_added_attrs_; - - private: - std::string instance_name_; - bool is_base_; - bool has_signature_; - PrimType prim_type_; - bool record_evaluate_add_attr_; -}; - -inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { - os << *p; - return os; -} - -struct PrimitiveEqual { - bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { - MS_EXCEPTION_IF_NULL(t1); - MS_EXCEPTION_IF_NULL(t2); - return t1->name() == t2->name(); - } -}; - -struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const &prim) const { - MS_EXCEPTION_IF_NULL(prim); - return prim->Hash(); - } -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ diff --git a/mindspore/ccsrc/ir/primitive_base_extends.cc b/mindspore/ccsrc/ir/primitive_extends.cc similarity index 96% rename from mindspore/ccsrc/ir/primitive_base_extends.cc rename to mindspore/ccsrc/ir/primitive_extends.cc index 64bdafa4d1..9df46920bf 100644 --- a/mindspore/ccsrc/ir/primitive_base_extends.cc +++ b/mindspore/ccsrc/ir/primitive_extends.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "ir/primitive_base.h" +#include "ir/primitive.h" #include "pipeline/static_analysis/abstract_function.h" namespace mindspore { diff --git a/mindspore/ccsrc/ir/primitive_py.cc b/mindspore/ccsrc/ir/primitive_py.cc new file mode 100644 index 0000000000..b672f470c9 --- /dev/null +++ b/mindspore/ccsrc/ir/primitive_py.cc @@ -0,0 +1,195 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/primitive_py.h" +#include +#include +#include "ir/signature.h" +#include "operator/ops.h" +#include "./common.h" +#include "pipeline/parse/python_adapter.h" +#include "pipeline/parse/data_converter.h" +#include "pybind11/pytypes.h" +#include "utils/convert_utils_base.h" +#include "utils/primitive_utils.h" +#include "utils/base_ref_py.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +namespace { +constexpr auto kBpropAttrName = "bprop"; +constexpr auto kCellHookAttrName = "cell_hook"; +constexpr auto kCellIDAttrName = "cell_id"; +void SyncData(const py::object &arg) { + if (py::isinstance(arg)) { + py::tuple arg_list = py::cast(arg); + for (size_t i = 0; i < arg_list.size(); i++) { + SyncData(arg_list[i]); + } + } + if (py::isinstance(arg)) { + auto tensor = py::cast(arg); + (void)tensor->data_sync(); + } +} +} // namespace +std::map PrimitivePy::hook_grad_; +static ValuePtr PyArgToValue(const py::object &arg) { + if (py::isinstance(arg) && + py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { + return nullptr; + } + return parse::data_converter::PyDataToValue(arg); +} + +void PrimitivePy::set_signatures( + std::vector> signatures) { + signatures_.clear(); + for (auto &signature : signatures) { + auto [name, rw, kind, arg_default, dtype] = signature; + auto default_value = PyArgToValue(arg_default); + signatures_.emplace_back(name, rw, kind, default_value, dtype); + } + set_has_signature(true); +} + +py::function PrimitivePy::GetBpropFunction() { + static const char *const get_bprop_func_name = "get_bprop"; + if (py::hasattr(python_obj_, get_bprop_func_name)) { + py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); + return fn; + } else { + auto fn = GetBpropFunctionByObj(python_obj_); + return fn; + } +} + +BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { + auto py_args = py::tuple(args.size()); + size_t i = 0; + for (auto &arg : args) { + py_args[i] = BaseRefToPyData(arg); + MS_LOG(DEBUG) << "arg:" << i << ":"; + i++; + } + py::object obj; + bool is_bprop = this->HasAttr(kBpropAttrName); + if (is_bprop) { + SyncData(py_args); + obj = hook_(*py_args); + return std::make_shared(obj); + } + SyncData(py_args[2]); + bool is_cell = this->HasAttr(kCellHookAttrName); + if (is_cell) { + auto cell_id = GetValue(this->GetAttr(kCellIDAttrName)); + auto iter = hook_grad_.find(cell_id); + if (iter != hook_grad_.end()) { + auto hook_args = py::tuple(3); + hook_args[0] = cell_id; + hook_args[1] = py::make_tuple(iter->second); + hook_args[2] = py::make_tuple(py_args[2]); + obj = hook_(*hook_args); + if (py::isinstance(obj)) { + obj = py_args[2]; + } + hook_grad_.erase(cell_id); + } else { + hook_grad_[cell_id] = py_args[2]; + obj = py_args[2]; + } + } else { + // Hook operator for execute variable hook function + obj = hook_(py::make_tuple(py_args[2])); + if (py::isinstance(obj)) { + obj = py_args[2]; + } + } + obj = py::make_tuple(obj); + return std::make_shared(obj); +} + +py::function PrimitivePy::GetComputeFunction() { + static const char *const compute_func_name = "vm_impl"; + + if (py::hasattr(python_obj_, compute_func_name)) { + MS_LOG(INFO) << name() << " compute_func_name"; + py::function fn = python_obj_.attr(compute_func_name).cast(); + return fn; + } + + static const std::string vm_module = "mindspore.ops.vm_impl_registry"; + static const std::string get_vm_impl_fn = "get_vm_impl_fn"; + MS_LOG(INFO) << name() << ": get_vm_impl_fn"; + py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); + py::function vm_fn = get_fn(python_obj_); + + if (py::isinstance(vm_fn)) { + MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast(); + vm_fn = mindspore::GetComputeFunction(Primitive::name()); + } + return vm_fn; +} + +void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { + std::string attr_name = name; + ValuePtr converted_ret = nullptr; + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; + } + bool converted = parse::ConvertData(obj, &converted_ret); + if (!converted) { + MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); + } + (void)this->AddAttr(attr_name, converted_ret); +} + +py::dict PrimitivePy::GetAttrDict() { + py::dict attr_dict; + for (auto &attr : attrs_) { + attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); + } + return attr_dict; +} + +void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { + MS_EXCEPTION_IF_NULL(primitive); + if (!primitive->isa()) { + MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!"; + } + auto primitive_py = primitive->cast(); + MS_EXCEPTION_IF_NULL(primitive_py); + this->set_hook(primitive_py->hook()); +} + +REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { + (void)py::enum_(*m, "prim_type", py::arithmetic()) + .value("unknown", PrimType::kPrimTypeUnknown) + .value("builtin", PrimType::kPrimTypeBuiltIn) + .value("py_infer_shape", PrimType::kPrimTypePyInferShape) + .value("user_custom", PrimType::kPrimTypeUserCustom); + (void)py::class_>(*m, "Primitive_") + .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) + .def(py::init()) + .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") + .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") + .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") + .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") + .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") + .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/primitive_py.h b/mindspore/ccsrc/ir/primitive_py.h new file mode 100644 index 0000000000..96acc831f2 --- /dev/null +++ b/mindspore/ccsrc/ir/primitive_py.h @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ +#define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ + +#include +#include +#include +#include +#include +#include + +#include "pipeline/static_analysis/abstract_value.h" +#include "utils/misc.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" +#include "ir/primitive.h" +#include "ir/signature.h" +#include "parallel/ops_info/operator_info.h" +namespace py = pybind11; +namespace mindspore { +class PrimitivePy : public Primitive { + public: + PrimitivePy(const py::str &name, const py::object &python_obj) + : Primitive(name, false), python_obj_(python_obj), signatures_() {} + ~PrimitivePy() override = default; + MS_DECLARE_PARENT(PrimitivePy, Primitive); + py::function GetBpropFunction(); + py::function GetComputeFunction(); + + void set_signatures( + std::vector> + signatures); + + const std::vector &signatures() const { return signatures_; } + + void CopyHookFunction(const PrimitivePtr &primitive) override; + + void AddPyAttr(const py::str &name, const py::object &obj); + + py::dict GetAttrDict(); + void set_hook(const py::function &hook) { hook_ = hook; } + py::function hook() const { return hook_; } + BaseRef RunHookFunction(const VectorRef &args) const override; + const bool parse_info_ = true; + const py::object &GetPyObj() const { return python_obj_; } + bool is_tuple_input_ = false; + + private: + py::object python_obj_; + py::function hook_; + std::vector signatures_; + static std::map hook_grad_; +}; + +using PrimitivePyPtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ diff --git a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc index 5b3194608e..021b49e20c 100644 --- a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "kernel/cpu/addn_cpu_kernel.h" #include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc index 9cc5126c08..811ea3ea16 100644 --- a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "kernel/cpu/allgather_cpu_kernel.h" #include "device/cpu/cpu_device_address.h" #include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" #include "utils/log_adapter.h" namespace mindspore { diff --git a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc index d8f2ef421b..dac382f447 100644 --- a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc @@ -16,7 +16,6 @@ #include "kernel/cpu/concat_cpu_kernel.h" #include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc index 07da3dcc25..c9e60f0f4c 100644 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc @@ -17,7 +17,6 @@ #include "kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h" #include "device/cpu/cpu_device_address.h" #include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc index 28090817cb..8aad9d19e6 100644 --- a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include "kernel/cpu/gather_cpu_kernel.h" #include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc index d2530430e9..afb3e6a247 100644 --- a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc +++ b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc @@ -15,7 +15,6 @@ */ #include "kernel/cpu/slice_cpu_kernel.h" #include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" namespace mindspore { namespace kernel { diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 2a98cc7e15..02673d9373 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -21,7 +21,7 @@ #include #include #include "ir/anf.h" -#include "ir/primitive_base.h" +#include "ir/primitive.h" namespace mindspore { // namespace to support primitive operators diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc index 4141fb5413..bdec1dc93c 100644 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/optimizer/ad/kprim.cc @@ -20,7 +20,7 @@ #include #include #include "ir/anf.h" -#include "ir/primitive.h" +#include "ir/primitive_py.h" #include "ir/meta_func_graph.h" #include "ir/func_graph_cloner.h" #include "ir/manager.h" @@ -232,10 +232,7 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res std::vector outputs; auto bprop_cut = std::make_shared("bprop_cut", py::object()); - if (!prim->is_base()) { - PrimitivePyPtr prim_py = dyn_cast(prim); - bprop_cut->set_hook(prim_py->hook()); - } + bprop_cut->CopyHookFunction(prim); auto cell_id = GetValue(prim->GetAttr("cell_id")); if (cell_id != "") { diff --git a/mindspore/ccsrc/optimizer/py_pass_manager.h b/mindspore/ccsrc/optimizer/py_pass_manager.h index eaeefce213..f7218d5ab2 100644 --- a/mindspore/ccsrc/optimizer/py_pass_manager.h +++ b/mindspore/ccsrc/optimizer/py_pass_manager.h @@ -23,7 +23,7 @@ #include "ir/anf.h" #include "ir/func_graph.h" -#include "ir/primitive.h" +#include "ir/primitive_py.h" #include "utils/graph_utils.h" #include "common/utils.h" diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h index a0b7ee5478..c33ea9f588 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h @@ -33,7 +33,7 @@ #include "utils/log_adapter.h" #include "ir/anf.h" -#include "ir/primitive.h" +#include "ir/primitive_py.h" #include "pipeline/static_analysis/analysis_context.h" #include "pipeline/static_analysis/abstract_function.h" #include "pipeline/parse/parse.h" diff --git a/mindspore/ccsrc/pipeline/static_analysis/utils.h b/mindspore/ccsrc/pipeline/static_analysis/utils.h index 6a709ea99c..97227dbbe3 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/utils.h +++ b/mindspore/ccsrc/pipeline/static_analysis/utils.h @@ -27,7 +27,6 @@ #include "utils/any.h" #include "utils/misc.h" #include "utils/convert_utils.h" -#include "ir/primitive.h" namespace mindspore { namespace abstract { diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc index 3d09233d99..2b2749090a 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc @@ -181,15 +181,6 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo if (AnfAlgo::IsGraphKernel(node)) { return ProcessGraphKernelOp(func_graph, node); - } else { - // insert cast for single op. - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - // process input - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto new_node = InsertCastForInput(func_graph, cnode); - // process output - return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); } // insert cast for single op. AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc index 59be003b15..4db08d0859 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc @@ -15,7 +15,6 @@ */ #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" #include "pre_activate/common/helper.h" - namespace mindspore { namespace opt { AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h index 60ae869227..4b4d44858b 100644 --- a/mindspore/ccsrc/pynative/base.h +++ b/mindspore/ccsrc/pynative/base.h @@ -26,7 +26,7 @@ #include #include "pybind11/pybind11.h" -#include "ir/primitive.h" +#include "ir/primitive_py.h" #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { diff --git a/mindspore/ccsrc/transform/op_adapter_base.h b/mindspore/ccsrc/transform/op_adapter_base.h index 01f96e251d..16c8d1fa7c 100644 --- a/mindspore/ccsrc/transform/op_adapter_base.h +++ b/mindspore/ccsrc/transform/op_adapter_base.h @@ -29,7 +29,6 @@ #include "ir/primitive.h" #include "ir/value.h" #include "transform/types.h" - #ifdef ENABLE_GE #ifdef OPEN_SOURCE #include "graph/types.h" diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h index 93edda3e34..2a9240ac84 100644 --- a/mindspore/ccsrc/utils/graph_utils.h +++ b/mindspore/ccsrc/utils/graph_utils.h @@ -29,7 +29,7 @@ #include #include "ir/anf.h" -#include "ir/primitive_base.h" +#include "ir/primitive.h" #include "ir/scalar.h" #include "ir/tensor.h" #include "debug/label.h" diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index ed6f15ce70..047b330158 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -648,57 +648,8 @@ void FinalVM::SyncData(const py::object &arg) { BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { MS_LOG(DEBUG) << "input for operation:"; - auto prim_py = dyn_cast(prim); - std::size_t args_size = args.size(); - auto py_args = py::tuple(args_size); - size_t i = 0; - for (auto &arg : args) { - py_args[i] = BaseRefToPyData(arg); - MS_LOG(DEBUG) << "arg: " << i << ":"; - i++; - } - // Hook operator for execute cell custom bprop function - py::object obj; - bool is_bprop = prim->HasAttr("bprop"); - if (is_bprop) { - SyncData(py_args); - py::function fn_bprop = prim_py->hook(); - obj = fn_bprop(*py_args); - return obj; - } - // Sync gradient data from device to host - SyncData(py_args[2]); - bool is_cell = prim->HasAttr("cell_hook"); - if (is_cell) { - // Hook operator for execute cell hook function - std::string cell_id = GetValue(prim->GetAttr("cell_id")); - if (_hook_grad.find(cell_id) != _hook_grad.end()) { - std::size_t hook_args_size = 3; - auto hook_args = py::tuple(hook_args_size); - hook_args[0] = cell_id; - hook_args[1] = py::make_tuple(_hook_grad[cell_id]); - hook_args[2] = py::make_tuple(py_args[2]); - py::function fn_hook = prim_py->hook(); - obj = fn_hook(*hook_args); - if (py::isinstance(obj)) { - obj = py_args[2]; - } - _hook_grad.erase(cell_id); - } else { - _hook_grad[cell_id] = py_args[2]; - obj = py_args[2]; - } - } else { - // Hook operator for execute variable hook function - py::function fn_hook = prim_py->hook(); - obj = fn_hook(py::make_tuple(py_args[2])); - if (py::isinstance(obj)) { - obj = py_args[2]; - } - } - obj = py::make_tuple(obj); - return obj; + MS_EXCEPTION_IF_NULL(prim); + return prim->RunHookFunction(args); } - } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index e905ec528b..02a1ad4ddb 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -161,7 +161,6 @@ class FinalVM { {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; - std::map _hook_grad; }; using FinalVMPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index 51b2c9b3d5..cb23cdaf43 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -30,7 +30,7 @@ #include "operator/ops.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" -#include "ir/primitive.h" +#include "ir/primitive_py.h" #include "utils/convert_utils.h" #include "utils/primitive_utils.h" #include "debug/draw.h" diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 1d1389b54a..87d32f3e76 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -19,7 +19,7 @@ #include "common/common_test.h" #include "ir/value.h" -#include "ir/primitive.h" +#include "ir/primitive_py.h" #include "operator/ops.h" #include "./common.h"