GitOrigin-RevId: e82e5de480
tags/v1.6.0
| @@ -7,24 +7,19 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from ..._imperative_rt.ops import _custom | |||||
| from .._imperative_rt.ops._custom import _install, _uninstall, _get_custom_op_list, _make_custom_op | |||||
| __all__ = [] | |||||
| __all__ = ["load"] | |||||
| for k, v in _custom.__dict__.items(): | |||||
| globals()[k] = v | |||||
| __all__.append(k) | |||||
| def gen_custom_op_maker(custom_op_name): | |||||
| def _gen_custom_op_maker(custom_op_name): | |||||
| def op_maker(**kwargs): | def op_maker(**kwargs): | ||||
| return make_custom_op(custom_op_name, kwargs) | |||||
| return _make_custom_op(custom_op_name, kwargs) | |||||
| return op_maker | return op_maker | ||||
| def load(lib_path): | def load(lib_path): | ||||
| op_in_this_lib = install(lib_path[0:-3], lib_path) | |||||
| op_in_this_lib = _install(lib_path[0:-3], lib_path) | |||||
| for op in op_in_this_lib: | for op in op_in_this_lib: | ||||
| op_maker = gen_custom_op_maker(op) | |||||
| op_maker = _gen_custom_op_maker(op) | |||||
| globals()[op] = op_maker | globals()[op] = op_maker | ||||
| __all__.append(op) | __all__.append(op) | ||||
| @@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) { | |||||
| for (const auto &op: ops_in_lib) { | for (const auto &op: ops_in_lib) { | ||||
| ret.append(op); | ret.append(op); | ||||
| } | } | ||||
| return std::move(ret); | |||||
| return ret; | |||||
| } | } | ||||
| bool uninstall_custom(const std::string &name) { | bool uninstall_custom(const std::string &name) { | ||||
| @@ -701,12 +701,12 @@ py::list get_custom_op_list(void) { | |||||
| } | } | ||||
| void init_custom(pybind11::module m) { | void init_custom(pybind11::module m) { | ||||
| m.def("install", &install_custom); | |||||
| m.def("uninstall", &uninstall_custom); | |||||
| m.def("get_custom_op_list", &get_custom_op_list); | |||||
| m.def("_install", &install_custom); | |||||
| m.def("_uninstall", &uninstall_custom); | |||||
| m.def("_get_custom_op_list", &get_custom_op_list); | |||||
| static PyMethodDef method_def = { | static PyMethodDef method_def = { | ||||
| "make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | |||||
| "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | |||||
| }; | }; | ||||
| auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | ||||
| pybind11::setattr(m, method_def.ml_name, func); | pybind11::setattr(m, method_def.ml_name, func); | ||||
| @@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) { | |||||
| return op.name(); | return op.name(); | ||||
| } | } | ||||
| } // custom_opdef | |||||
| OP_TRAIT_REG(CustomOpDef, CustomOpDef) | OP_TRAIT_REG(CustomOpDef, CustomOpDef) | ||||
| .apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor) | |||||
| .apply_on_var_node(imperative::custom_opdef::apply_on_var_node) | |||||
| .apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd) | |||||
| .infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible) | |||||
| .infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc) | |||||
| .hash(imperative::custom_opdef::hash) | |||||
| .is_same_st(imperative::custom_opdef::is_same_st) | |||||
| .props(imperative::custom_opdef::props) | |||||
| .make_name(imperative::custom_opdef::make_name) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .apply_on_device_tensornd(apply_on_device_tensornd) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
| .infer_output_mem_desc(infer_output_mem_desc) | |||||
| .hash(hash) | |||||
| .is_same_st(is_same_st) | |||||
| .props(props) | |||||
| .make_name(make_name) | |||||
| .fallback(); | .fallback(); | ||||
| } // custom_opdef | |||||
| } // imperative | } // imperative | ||||
| } // mgb | } // mgb | ||||
| @@ -60,18 +60,5 @@ public: | |||||
| std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const; | std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&, const custom::Param&) const; | ||||
| }; | }; | ||||
| namespace custom_opdef { // avoid name conflict | |||||
| void apply_on_device_tensornd(const OpDef&, const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*); | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor(const OpDef&, const SmallVector<TensorPtr>&); | |||||
| VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&); | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef&, const SmallVector<LogicalTensorDesc>&); | |||||
| size_t hash(const OpDef&); | |||||
| bool is_same_st(const OpDef&, const OpDef&); | |||||
| std::vector<std::pair<const char*, std::string>> props(const OpDef&); | |||||
| std::string make_name(const OpDef&); | |||||
| } // custom_opdef | |||||
| } // imperative | } // imperative | ||||
| } // mgb | } // mgb | ||||
| @@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() { | |||||
| } | } | ||||
| cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const { | cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const { | ||||
| // auto ret = &const_cast<OperatorNodeBase::NodeProp&>(node_prop()); | |||||
| // for (auto &&inp_var: input()) | |||||
| // ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE); | |||||
| // ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE); | |||||
| // return ret; | |||||
| return OperatorNodeBase::do_make_node_prop(); | return OperatorNodeBase::do_make_node_prop(); | ||||
| } | } | ||||