|
|
|
@@ -49,7 +49,8 @@ void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) |
|
|
|
} |
|
|
|
|
|
|
|
void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { |
|
|
|
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; |
|
|
|
MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << py::str(py_fn.cast<py::object>()) |
|
|
|
<< ")."; |
|
|
|
auto fn = fn_cache_.find(types); |
|
|
|
if (fn != fn_cache_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; |
|
|
|
@@ -116,7 +117,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { |
|
|
|
auto py_fn = SignMatch(types); |
|
|
|
std::ostringstream buffer; |
|
|
|
buffer << types; |
|
|
|
if (py_fn != py::none()) { |
|
|
|
if (!py_fn.is_none()) { |
|
|
|
FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); |
|
|
|
if (func_graph == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); |
|
|
|
|