Merge pull request !1542 from vlne-v1/I1GZ0B-multitype-funcgraph-bugtags/v0.5.0-beta
| @@ -98,25 +98,29 @@ TypePtr TensorType::DeepCopy() const { | |||
| MS_EXCEPTION_IF_NULL(element_type_); | |||
| if (IsGeneric()) { | |||
| return std::make_shared<TensorType>(); | |||
| } else { | |||
| return std::make_shared<TensorType>(element_type_->DeepCopy()); | |||
| } | |||
| return std::make_shared<TensorType>(element_type_->DeepCopy()); | |||
| } | |||
| std::string TensorType::ToReprString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "tensor"; | |||
| } | |||
| return "tensor[" + element_type_->ToReprString() + "]"; | |||
| } | |||
| std::string TensorType::ToString() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Tensor"; | |||
| } else { | |||
| return "Tensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| return "Tensor[" + element_type_->ToString() + "]"; | |||
| } | |||
| std::string TensorType::DumpText() const { | |||
| if (element_type_ == nullptr) { | |||
| return "Tensor"; | |||
| } else { | |||
| return "Tensor(" + element_type_->DumpText() + ")"; | |||
| } | |||
| return "Tensor(" + element_type_->DumpText() + ")"; | |||
| } | |||
| bool TensorType::operator==(const Type &other) const { | |||
| @@ -121,7 +121,7 @@ class TensorType : public Object { | |||
| TypePtr DeepCopy() const override; | |||
| std::string ToString() const override; | |||
| std::string ToReprString() const override { return "tensor"; } | |||
| std::string ToReprString() const override; | |||
| std::string DumpText() const override; | |||
| bool operator==(const Type &other) const override; | |||
| @@ -363,6 +363,7 @@ REGISTER_PYBIND_DEFINE( | |||
| (void)m_sub.def("load_type", &TypeIdToType, "load type"); | |||
| (void)m_sub.def( | |||
| "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); | |||
| (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); | |||
| (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | |||
| .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | |||
| .def("__eq__", | |||
| @@ -649,115 +649,6 @@ REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { | |||
| py::arg("get_by_list"), py::arg("sens_param")); | |||
| })); | |||
| MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { | |||
| fn_cache_.clear(); | |||
| signatures_ = std::vector<Signature>({// def multitype(*args:ref): | |||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | |||
| } | |||
| void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { | |||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; | |||
| auto fn = fn_cache_.find(types); | |||
| if (fn != fn_cache_.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; | |||
| } | |||
| fn_cache_[types] = s_fn; | |||
| } | |||
| void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { | |||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; | |||
| auto fn = fn_cache_.find(types); | |||
| if (fn != fn_cache_.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; | |||
| } | |||
| fn_cache_py_[types] = py_fn; | |||
| } | |||
| void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) { | |||
| TypePtrList types; | |||
| for (auto &type_name : types_name) { | |||
| auto type_ptr = StringToType(type_name); | |||
| if (type_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error "; | |||
| } | |||
| types.push_back(type_ptr); | |||
| } | |||
| Register(types, py_fn); | |||
| } | |||
| void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { | |||
| std::vector<std::string> types_name; | |||
| for (size_t it = 0; it < tuple.size(); ++it) { | |||
| py::object name_py = tuple[it]; | |||
| if (py::isinstance<py::str>(name_py)) { | |||
| types_name.push_back(name_py.cast<std::string>()); | |||
| continue; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Register must be string"; | |||
| } | |||
| Register(types_name, py_fn); | |||
| } | |||
| static TypePtr UnwrapRef(const TypePtr &type) { | |||
| if (type->isa<RefType>()) { | |||
| return type->cast<RefTypePtr>()->subtype(); | |||
| } | |||
| return type; | |||
| } | |||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||
| bool find_fn = false; | |||
| py::function py_fn; | |||
| for (auto &item : fn_cache_py_) { | |||
| TypePtrList sign = item.first; | |||
| if (sign.size() != types.size()) { | |||
| continue; | |||
| } | |||
| bool match = true; | |||
| for (size_t i = 0; i < sign.size(); ++i) { | |||
| if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { | |||
| match = false; | |||
| break; | |||
| } | |||
| } | |||
| if (!match) { | |||
| continue; | |||
| } | |||
| find_fn = true; | |||
| py_fn = item.second; | |||
| break; | |||
| } | |||
| std::ostringstream buffer; | |||
| buffer << types; | |||
| if (find_fn) { | |||
| FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); | |||
| } | |||
| MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); | |||
| return func_graph; | |||
| } | |||
| std::ostringstream oss; | |||
| oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ | |||
| << "`, corresponding location info:\n"; | |||
| int idx = 0; | |||
| for (auto &item : fn_cache_py_) { | |||
| FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; | |||
| continue; | |||
| } | |||
| oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; | |||
| } | |||
| MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n" | |||
| << oss.str(); | |||
| } | |||
| REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { | |||
| (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>( | |||
| *m, "MultitypeFuncGraph_") | |||
| .def(py::init<std::string &>()) | |||
| .def("register_fn", &MultitypeFuncGraph::PyRegister); | |||
| })); | |||
| // Generate the ListMap func graph. | |||
| FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { | |||
| size_t args_num = args_spec_list.size(); | |||
| @@ -30,6 +30,7 @@ | |||
| #include "operator/composite/list_append_operation.h" | |||
| #include "operator/composite/do_signature.h" | |||
| #include "operator/composite/unpack_call.h" | |||
| #include "operator/composite/multitype_funcgraph.h" | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "utils/misc.h" | |||
| #include "utils/any.h" | |||
| @@ -45,31 +46,6 @@ using AbstractTensorPtr = abstract::AbstractTensorPtr; | |||
| using ElemwiseMap = std::unordered_map<std::string, PrimitivePtr>; | |||
| using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; | |||
| class MultitypeFuncGraph : public MetaFuncGraph { | |||
| public: | |||
| explicit MultitypeFuncGraph(const std::string &name); | |||
| ~MultitypeFuncGraph() override = default; | |||
| MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) | |||
| using specialize_fn = FuncGraph *(*)(TypePtrList); | |||
| // Register a method which specialize based on types vectors; | |||
| virtual void Register(const TypePtrList &types, specialize_fn s_fn); | |||
| virtual void Register(const TypePtrList &types, const py::function &py_fn); | |||
| virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn); | |||
| virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); | |||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; | |||
| size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } | |||
| const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const { | |||
| return fn_cache_py_; | |||
| } | |||
| private: | |||
| std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_; | |||
| std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_; | |||
| }; | |||
| using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; | |||
| class HyperMap : public MetaFuncGraph { | |||
| public: | |||
| explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr); | |||
| @@ -0,0 +1,153 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "operator/composite/multitype_funcgraph.h" | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include <sstream> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "pipeline/static_analysis/abstract_function.h" | |||
| #include "pipeline/static_analysis/dshape.h" | |||
| #include "pipeline/static_analysis/param_validator.h" | |||
| #include "operator/cc_implementations.h" | |||
| #include "optimizer/opt.h" | |||
| #include "utils/symbolic.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "./common.h" | |||
| #include "ir/signature.h" | |||
| #include "debug/trace.h" | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| namespace prim { | |||
| MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { | |||
| fn_cache_.clear(); | |||
| signatures_ = std::vector<Signature>({// def multitype(*args:ref): | |||
| {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); | |||
| } | |||
| void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { | |||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; | |||
| auto fn = fn_cache_.find(types); | |||
| if (fn != fn_cache_.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; | |||
| } | |||
| fn_cache_[types] = s_fn; | |||
| } | |||
| void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { | |||
| MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; | |||
| auto fn = fn_cache_.find(types); | |||
| if (fn != fn_cache_.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; | |||
| } | |||
| fn_cache_py_[types] = py_fn; | |||
| } | |||
| void MultitypeFuncGraph::Register(const std::vector<std::string> &types_name, const py::function &py_fn) { | |||
| TypePtrList types; | |||
| for (auto &type_name : types_name) { | |||
| auto type_ptr = StringToType(type_name); | |||
| if (type_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << type_name << " convert from string error "; | |||
| } | |||
| types.push_back(type_ptr); | |||
| } | |||
| Register(types, py_fn); | |||
| } | |||
| void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { | |||
| std::vector<std::string> types_name; | |||
| for (size_t it = 0; it < tuple.size(); ++it) { | |||
| py::object name_py = tuple[it]; | |||
| if (py::isinstance<py::str>(name_py)) { | |||
| types_name.push_back(name_py.cast<std::string>()); | |||
| continue; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Register must be string"; | |||
| } | |||
| Register(types_name, py_fn); | |||
| } | |||
| static TypePtr UnwrapRef(const TypePtr &type) { | |||
| if (type->isa<RefType>()) { | |||
| return type->cast<RefTypePtr>()->subtype(); | |||
| } | |||
| return type; | |||
| } | |||
| FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { | |||
| bool find_fn = false; | |||
| py::function py_fn; | |||
| for (auto &item : fn_cache_py_) { | |||
| TypePtrList sign = item.first; | |||
| if (sign.size() != types.size()) { | |||
| continue; | |||
| } | |||
| bool match = true; | |||
| for (size_t i = 0; i < sign.size(); ++i) { | |||
| if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { | |||
| match = false; | |||
| break; | |||
| } | |||
| } | |||
| if (!match) { | |||
| continue; | |||
| } | |||
| find_fn = true; | |||
| py_fn = item.second; | |||
| break; | |||
| } | |||
| std::ostringstream buffer; | |||
| buffer << types; | |||
| if (find_fn) { | |||
| FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); | |||
| } | |||
| MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); | |||
| return func_graph; | |||
| } | |||
| std::ostringstream oss; | |||
| oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ | |||
| << "`, corresponding location info:\n"; | |||
| int idx = 0; | |||
| for (auto &item : fn_cache_py_) { | |||
| FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; | |||
| continue; | |||
| } | |||
| oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; | |||
| } | |||
| MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n" | |||
| << oss.str(); | |||
| } | |||
| REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { | |||
| (void)py::class_<MultitypeFuncGraph, MetaFuncGraph, std::shared_ptr<MultitypeFuncGraph>>( | |||
| *m, "MultitypeFuncGraph_") | |||
| .def(py::init<std::string &>()) | |||
| .def("register_fn", &MultitypeFuncGraph::PyRegister); | |||
| })); | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ | |||
| #define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <map> | |||
| #include <set> | |||
| #include <memory> | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "utils/misc.h" | |||
| #include "ir/dtype.h" | |||
| #include "ir/meta_func_graph.h" | |||
| namespace mindspore { | |||
| // namespace to support composite operators definition | |||
| namespace prim { | |||
| class MultitypeFuncGraph : public MetaFuncGraph { | |||
| public: | |||
| explicit MultitypeFuncGraph(const std::string &name); | |||
| ~MultitypeFuncGraph() override = default; | |||
| MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) | |||
| using specialize_fn = FuncGraph *(*)(TypePtrList); | |||
| // Register a method which specialize based on types vectors; | |||
| virtual void Register(const TypePtrList &types, specialize_fn s_fn); | |||
| virtual void Register(const TypePtrList &types, const py::function &py_fn); | |||
| virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn); | |||
| virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); | |||
| FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; | |||
| size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } | |||
| const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const { | |||
| return fn_cache_py_; | |||
| } | |||
| private: | |||
| std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_; | |||
| std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_; | |||
| }; | |||
| using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; | |||
| } // namespace prim | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ | |||
| @@ -156,7 +156,7 @@ def pytype_to_dtype(obj): | |||
| return obj | |||
| if isinstance(obj, type) and obj in _simple_types: | |||
| return _simple_types[obj] | |||
| raise NotImplementedError() | |||
| raise NotImplementedError(f"Unsupported type {obj} for `pytype_to_dtype`.") | |||
| def get_py_obj_dtype(obj): | |||
| @@ -169,7 +169,11 @@ def get_py_obj_dtype(obj): | |||
| Returns: | |||
| Type of MindSpore type. | |||
| """ | |||
| # Tensor | |||
| if hasattr(obj, 'dtype'): | |||
| return tensor_type(obj.dtype()) | |||
| if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): | |||
| return function | |||
| if isinstance(obj, (typing.Type, type)): | |||
| return pytype_to_dtype(obj) | |||
| return pytype_to_dtype(type(obj)) | |||
| @@ -359,6 +359,4 @@ def tensor_grad_scale(scale, grad): | |||
| """Get grad with scale.""" | |||
| if scale == 1.0: | |||
| return grad | |||
| cast_op = P.Cast() | |||
| type_op = P.DType() | |||
| return grad * cast_op(F.scalar_to_array(scale), type_op(grad)) | |||
| return grad * scale | |||
| @@ -16,6 +16,7 @@ | |||
| # ============================================================================ | |||
| """Basic composite operations.""" | |||
| from functools import partial | |||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ | |||
| TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ | |||
| @@ -23,6 +24,7 @@ from ...common import dtype as mstype | |||
| from ...common.api import ms_function | |||
| from .. import functional as F | |||
| from .. import operations as P | |||
| from ...common.parameter import Parameter | |||
| __all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | |||
| @@ -144,7 +146,6 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||
| >>> # `add` is a metagraph object which will add two objects according to | |||
| >>> # input type using ".register" decorator. | |||
| >>> add = MultitypeFuncGraph('add') | |||
| """ | |||
| def __init__(self, name): | |||
| @@ -152,8 +153,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||
| self.entries = list() | |||
| def __call__(self, *args): | |||
| for sig, fn in self.entries: | |||
| if len(sig) != len(args): | |||
| def unwrap(arg): | |||
| if isinstance(arg, Parameter): | |||
| return arg.data | |||
| return arg | |||
| types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args)) | |||
| for sigs, fn in self.entries: | |||
| if len(sigs) != len(types): | |||
| continue | |||
| if any(not mstype.issubclass_(type_, sig) for sig, type_ in zip(sigs, types)): | |||
| continue | |||
| output = fn(*args) | |||
| return output | |||
| @@ -162,8 +170,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): | |||
| def register(self, *type_names): | |||
| """Register a function for the given type string.""" | |||
| def deco(fn): | |||
| types = tuple(map(mstype.typing.str_to_type, type_names)) | |||
| self.register_fn(type_names, fn) | |||
| self.entries.append((type_names, fn)) | |||
| self.entries.append((types, fn)) | |||
| return fn | |||
| return deco | |||
| @@ -198,38 +207,17 @@ class HyperMap(HyperMap_): | |||
| HyperMap_.__init__(self) | |||
| def __call__(self, *args): | |||
| func = args[0] | |||
| count = 0 | |||
| count_max = 1 | |||
| args_list = args[1:] | |||
| if self.ops is not None: | |||
| func = self.ops | |||
| args_list = args | |||
| for item in args_list: | |||
| if isinstance(item, (tuple, list)): | |||
| count_max = len(item) | |||
| break | |||
| def get_item(x): | |||
| nonlocal count | |||
| if isinstance(x, (tuple, list)): | |||
| return x[count] | |||
| return x | |||
| for i in range(count_max): | |||
| true_args = tuple(map(get_item, args_list)) | |||
| func(*true_args) | |||
| count = i + 1 | |||
| return True | |||
| def register(self, *type_names): | |||
| """Register a function for the given type string.""" | |||
| def deco(fn): | |||
| self.register_fn(type_names, fn) | |||
| return fn | |||
| return deco | |||
| func = self.ops | |||
| args_list = args | |||
| hypermap = self | |||
| if self.ops is None: | |||
| func = args[0] | |||
| args_list = args[1:] | |||
| hypermap = partial(self, func) | |||
| # is leaf | |||
| if not isinstance(args_list[0], (tuple, list)): | |||
| return func(*args_list) | |||
| return tuple(map(hypermap, *args_list)) | |||
| class _ListAppend(ListAppend_): | |||
| """ | |||