|
|
|
@@ -38,6 +38,79 @@ using TensorPtr = mindspore::tensor::TensorPtr; |
|
|
|
using MetaTensor = mindspore::tensor::MetaTensor; |
|
|
|
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; |
|
|
|
|
|
|
|
using InstanceCheckFunc = std::function<bool(const py::object &)>; |
|
|
|
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>; |
|
|
|
class DataConverter { |
|
|
|
public: |
|
|
|
explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {} |
|
|
|
virtual ~DataConverter() = default; |
|
|
|
virtual bool Matched(const py::object &obj) = 0; |
|
|
|
virtual ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype) { |
|
|
|
if (convert_func_ == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "convert func is null"; |
|
|
|
} |
|
|
|
return convert_func_(obj, use_sig, dtype); |
|
|
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
InstanceConvertFunc convert_func_ = nullptr; |
|
|
|
}; |
|
|
|
using DataConverterPtr = std::shared_ptr<DataConverter>; |
|
|
|
|
|
|
|
using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>; |
|
|
|
using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>; |
|
|
|
using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>; |
|
|
|
|
|
|
|
// Convert the data according instance type |
|
|
|
template <typename T> |
|
|
|
class ByTypeDataConverter : public DataConverter { |
|
|
|
public: |
|
|
|
explicit ByTypeDataConverter(const InstanceConvertFunc &convert_func) |
|
|
|
: DataConverter(convert_func), check_func_(py::isinstance<T>) {} |
|
|
|
explicit ByTypeDataConverter(const ValuePtr &converted_type) |
|
|
|
: DataConverter( |
|
|
|
[converted_type](const py::object &, bool, const TypePtr &) -> ValuePtr { return converted_type; }), |
|
|
|
check_func_(py::isinstance<T>) {} |
|
|
|
explicit ByTypeDataConverter(const ArgsObjConvertFunc &convert_func) |
|
|
|
: DataConverter( |
|
|
|
[convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }), |
|
|
|
check_func_(py::isinstance<T>) {} |
|
|
|
explicit ByTypeDataConverter(const ArgsObjSigConvertFunc &convert_func) |
|
|
|
: DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr { |
|
|
|
return convert_func(obj, use_sig); |
|
|
|
}), |
|
|
|
check_func_(py::isinstance<T>) {} |
|
|
|
explicit ByTypeDataConverter(const ArgsOjbTypeConvertFunc &convert_func) |
|
|
|
: DataConverter([convert_func](const py::object &obj, bool, const TypePtr &dtype) -> ValuePtr { |
|
|
|
return convert_func(obj, dtype); |
|
|
|
}), |
|
|
|
check_func_(py::isinstance<T>) {} |
|
|
|
~ByTypeDataConverter() override = default; |
|
|
|
|
|
|
|
bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; } |
|
|
|
|
|
|
|
private: |
|
|
|
InstanceCheckFunc check_func_ = nullptr; |
|
|
|
}; |
|
|
|
|
|
|
|
// Convert the data according object attribute. |
|
|
|
class ByAttrDataConverter : public DataConverter { |
|
|
|
public: |
|
|
|
ByAttrDataConverter(const char *attr_name, const ArgsObjConvertFunc &convert_func) |
|
|
|
: DataConverter( |
|
|
|
[convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }), |
|
|
|
attr_name_(attr_name) {} |
|
|
|
ByAttrDataConverter(const char *attr_name, const ArgsObjSigConvertFunc &convert_func) |
|
|
|
: DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr { |
|
|
|
return convert_func(obj, use_sig); |
|
|
|
}), |
|
|
|
attr_name_(attr_name) {} |
|
|
|
bool Matched(const py::object &obj) override { return py::hasattr(obj, attr_name_); } |
|
|
|
|
|
|
|
private: |
|
|
|
const char *attr_name_ = nullptr; |
|
|
|
}; |
|
|
|
|
|
|
|
FuncGraphPtr ConvertToBpropCut(const py::object &obj) { |
|
|
|
std::vector<std::string> results = data_converter::GetObjKey(obj); |
|
|
|
std::string obj_key = results[0]; |
|
|
|
@@ -69,7 +142,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) { |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { |
|
|
|
ValuePtr ConvertTuple(const py::object &obj, bool use_signature) { |
|
|
|
MS_LOG(DEBUG) << "Converting python tuple"; |
|
|
|
auto tuple = obj.cast<py::tuple>(); |
|
|
|
std::vector<ValuePtr> value_list; |
|
|
|
@@ -77,16 +150,14 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur |
|
|
|
ValuePtr out = nullptr; |
|
|
|
bool success = ConvertData(tuple[it], &out, use_signature); |
|
|
|
if (!success) { |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
value_list.push_back(out); |
|
|
|
} |
|
|
|
*data = std::make_shared<ValueTuple>(value_list); |
|
|
|
|
|
|
|
return true; |
|
|
|
return std::make_shared<ValueTuple>(value_list); |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { |
|
|
|
ValuePtr ConvertList(const py::object &obj, bool use_signature) { |
|
|
|
MS_LOG(DEBUG) << "Converting python list"; |
|
|
|
|
|
|
|
auto list = obj.cast<py::list>(); |
|
|
|
@@ -95,15 +166,14 @@ bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature |
|
|
|
ValuePtr out = nullptr; |
|
|
|
bool success = ConvertData(list[it], &out, use_signature); |
|
|
|
if (!success) { |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
value_list.push_back(out); |
|
|
|
} |
|
|
|
*data = std::make_shared<ValueList>(value_list); |
|
|
|
return true; |
|
|
|
return std::make_shared<ValueList>(value_list); |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { |
|
|
|
ValuePtr ConvertCellList(const py::object &obj, bool use_signature) { |
|
|
|
MS_LOG(DEBUG) << "Converting cell list"; |
|
|
|
py::sequence list = obj; |
|
|
|
std::vector<ValuePtr> value_list; |
|
|
|
@@ -111,15 +181,14 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa |
|
|
|
ValuePtr out = nullptr; |
|
|
|
bool success = ConvertData(list[it], &out, use_signature); |
|
|
|
if (!success) { |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
value_list.push_back(out); |
|
|
|
} |
|
|
|
*data = std::make_shared<ValueTuple>(value_list); |
|
|
|
return true; |
|
|
|
return std::make_shared<ValueTuple>(value_list); |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { |
|
|
|
ValuePtr ConvertDict(const py::object &obj, bool use_signature) { |
|
|
|
MS_LOG(DEBUG) << "Converting python dict"; |
|
|
|
|
|
|
|
auto dict_values = obj.cast<py::dict>(); |
|
|
|
@@ -127,36 +196,37 @@ bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { |
|
|
|
for (auto item : dict_values) { |
|
|
|
if (!py::isinstance<py::str>(item.first)) { |
|
|
|
MS_LOG(ERROR) << "The key of dict is only support str."; |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
std::string key = py::str(item.first); |
|
|
|
ValuePtr out = nullptr; |
|
|
|
bool success = ConvertData(dict_values[item.first], &out, use_signature); |
|
|
|
if (!success) { |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
key_values.emplace_back(key, out); |
|
|
|
} |
|
|
|
*data = std::make_shared<ValueDictionary>(key_values); |
|
|
|
return true; |
|
|
|
return std::make_shared<ValueDictionary>(key_values); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { |
|
|
|
ValuePtr ConvertNameSpace(const py::object &obj) { |
|
|
|
MS_LOG(DEBUG) << "Converting python module"; |
|
|
|
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); |
|
|
|
py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); |
|
|
|
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace)); |
|
|
|
auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace)); |
|
|
|
return converted; |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertDataClass(py::object obj, ValuePtr *const data) { |
|
|
|
ValuePtr ConvertDataClass(const py::object &obj) { |
|
|
|
MS_LOG(DEBUG) << "Converting dataclass"; |
|
|
|
// Maybe the obj is dataclass define |
|
|
|
auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); |
|
|
|
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1; |
|
|
|
*data = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1)); |
|
|
|
auto converted = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1)); |
|
|
|
return converted; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { |
|
|
|
ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) { |
|
|
|
MS_LOG(DEBUG) << "Converting primitive object" << use_signature; |
|
|
|
|
|
|
|
// need check the primitive is class type or instance |
|
|
|
@@ -164,96 +234,81 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = |
|
|
|
if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { |
|
|
|
auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); |
|
|
|
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1; |
|
|
|
*data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1)); |
|
|
|
} else { |
|
|
|
auto primitive = obj.cast<PrimitivePyPtr>(); |
|
|
|
if (primitive == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (py::hasattr(obj, "__setattr_flag__")) { |
|
|
|
if (py::hasattr(obj, "_clone")) { |
|
|
|
auto clone_fn = obj.attr("_clone"); |
|
|
|
py::object new_obj = clone_fn(); |
|
|
|
primitive = new_obj.cast<PrimitivePyPtr>(); |
|
|
|
} |
|
|
|
} |
|
|
|
if (use_signature) { |
|
|
|
*data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive); |
|
|
|
} else { |
|
|
|
*data = primitive; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString(); |
|
|
|
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1)); |
|
|
|
} |
|
|
|
auto primitive = obj.cast<PrimitivePyPtr>(); |
|
|
|
if (primitive == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return true; |
|
|
|
if (py::hasattr(obj, "__setattr_flag__") && py::hasattr(obj, "_clone")) { |
|
|
|
auto clone_fn = obj.attr("_clone"); |
|
|
|
py::object new_obj = clone_fn(); |
|
|
|
primitive = new_obj.cast<PrimitivePyPtr>(); |
|
|
|
} |
|
|
|
if (use_signature) { |
|
|
|
return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive); |
|
|
|
} |
|
|
|
return primitive; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { |
|
|
|
ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) { |
|
|
|
MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; |
|
|
|
auto meta = obj.cast<MetaFuncGraphPtr>(); |
|
|
|
if (meta == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null"; |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (use_signature) { |
|
|
|
*data = std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta); |
|
|
|
} else { |
|
|
|
*data = meta; |
|
|
|
return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta); |
|
|
|
} |
|
|
|
return true; |
|
|
|
return meta; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) { |
|
|
|
ValuePtr ConvertFuncGraph(const py::object &obj) { |
|
|
|
MS_LOG(DEBUG) << "Converting FuncGraph object"; |
|
|
|
auto func_graph = obj.cast<FuncGraphPtr>(); |
|
|
|
if (func_graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null"; |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto new_fg = BasicClone(func_graph); |
|
|
|
new_fg->set_attr("is_load", MakeValue(true)); |
|
|
|
*data = new_fg; |
|
|
|
return true; |
|
|
|
return new_fg; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertSlice(const py::object &obj, ValuePtr *const data) { |
|
|
|
ValuePtr ConvertSlice(const py::object &obj) { |
|
|
|
MS_LOG(DEBUG) << "Converting slice object"; |
|
|
|
|
|
|
|
auto slice_obj = obj.cast<py::slice>(); |
|
|
|
auto convert_func = [obj](std::string attr) -> ValuePtr { |
|
|
|
auto convert_func = [obj](const std::string &attr) -> ValuePtr { |
|
|
|
auto py_attr = py::getattr(obj, attr.c_str()); |
|
|
|
if (py::isinstance<py::none>(py_attr)) { |
|
|
|
return kNone; |
|
|
|
} else if (py::isinstance<py::int_>(py_attr)) { |
|
|
|
int64_t value = py::cast<int64_t>(py_attr); |
|
|
|
} |
|
|
|
if (py::isinstance<py::int_>(py_attr)) { |
|
|
|
auto value = py::cast<int64_t>(py_attr); |
|
|
|
return MakeValue(value); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none"; |
|
|
|
} |
|
|
|
MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none"; |
|
|
|
}; |
|
|
|
ValuePtr start = convert_func("start"); |
|
|
|
ValuePtr stop = convert_func("stop"); |
|
|
|
ValuePtr step = convert_func("step"); |
|
|
|
*data = std::make_shared<ValueSlice>(start, stop, step); |
|
|
|
return true; |
|
|
|
return std::make_shared<ValueSlice>(start, stop, step); |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) { |
|
|
|
auto obj = py::cast(cell); |
|
|
|
ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) { |
|
|
|
FuncGraphPtr func_graph = ConvertToFuncGraph(obj); |
|
|
|
if (func_graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Parse resolve function error."; |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
// if the cell object has specified bprop, it has user-defined bprop function parse and record it |
|
|
|
if (py::hasattr(obj, CUSTOM_BPROP_NAME)) { |
|
|
|
FuncGraphPtr bprop_graph = nullptr; |
|
|
|
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug")); |
|
|
|
if (enable_bprop_debug) { |
|
|
|
bprop_graph = ConvertToBpropCut(obj); |
|
|
|
} else { |
|
|
|
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); |
|
|
|
} |
|
|
|
FuncGraphPtr bprop_graph = |
|
|
|
enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); |
|
|
|
if (bprop_graph != nullptr) { |
|
|
|
(void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); |
|
|
|
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); |
|
|
|
@@ -264,200 +319,183 @@ bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) { |
|
|
|
auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME)); |
|
|
|
func_graph->set_stage(stage); |
|
|
|
} |
|
|
|
*data = func_graph; |
|
|
|
return true; |
|
|
|
return func_graph; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertOtherObj(py::object obj, ValuePtr *const data) { |
|
|
|
ValuePtr ConvertOtherObj(const py::object &obj) { |
|
|
|
auto obj_type = data_converter::GetObjType(obj); |
|
|
|
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; |
|
|
|
if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { |
|
|
|
MS_LOG(DEBUG) << "Resolve the class type, need create class instance."; |
|
|
|
std::string desc = py::str(obj); |
|
|
|
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1; |
|
|
|
*data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1)); |
|
|
|
return true; |
|
|
|
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1)); |
|
|
|
} |
|
|
|
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) { |
|
|
|
MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type; |
|
|
|
FuncGraphPtr func_graph = ConvertToFuncGraph(obj); |
|
|
|
if (func_graph == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Parse resolve function error."; |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
*data = func_graph; |
|
|
|
return true; |
|
|
|
return func_graph; |
|
|
|
} |
|
|
|
if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { |
|
|
|
// Create the namespace for common class instance |
|
|
|
// When the obj is Cell, default parse the 'construct' |
|
|
|
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); |
|
|
|
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); |
|
|
|
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); |
|
|
|
return true; |
|
|
|
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); |
|
|
|
} |
|
|
|
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); |
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
bool ConvertNumberWithType(const T &obj, ValuePtr *const data, TypePtr dtype) { |
|
|
|
ValuePtr ConvertNumberWithType(const T &obj, TypePtr dtype) { |
|
|
|
ValuePtr data = nullptr; |
|
|
|
auto int_dypte = dyn_cast<Int>(dtype); |
|
|
|
if (int_dypte != nullptr) { |
|
|
|
switch (int_dypte->nbits()) { |
|
|
|
case 8: |
|
|
|
*data = std::make_shared<Int8Imm>(obj); |
|
|
|
data = std::make_shared<Int8Imm>(obj); |
|
|
|
break; |
|
|
|
case 16: |
|
|
|
*data = std::make_shared<Int16Imm>(obj); |
|
|
|
data = std::make_shared<Int16Imm>(obj); |
|
|
|
break; |
|
|
|
case 32: |
|
|
|
*data = std::make_shared<Int32Imm>(obj); |
|
|
|
data = std::make_shared<Int32Imm>(obj); |
|
|
|
break; |
|
|
|
case 64: |
|
|
|
*data = std::make_shared<Int64Imm>(obj); |
|
|
|
data = std::make_shared<Int64Imm>(obj); |
|
|
|
break; |
|
|
|
default: |
|
|
|
*data = std::make_shared<Int64Imm>(obj); |
|
|
|
data = std::make_shared<Int64Imm>(obj); |
|
|
|
} |
|
|
|
return true; |
|
|
|
return data; |
|
|
|
} |
|
|
|
|
|
|
|
auto uint_dypte = dyn_cast<UInt>(dtype); |
|
|
|
if (uint_dypte != nullptr) { |
|
|
|
switch (uint_dypte->nbits()) { |
|
|
|
case 8: |
|
|
|
*data = std::make_shared<UInt8Imm>(obj); |
|
|
|
data = std::make_shared<UInt8Imm>(obj); |
|
|
|
break; |
|
|
|
case 16: |
|
|
|
*data = std::make_shared<UInt16Imm>(obj); |
|
|
|
data = std::make_shared<UInt16Imm>(obj); |
|
|
|
break; |
|
|
|
case 32: |
|
|
|
*data = std::make_shared<UInt32Imm>(obj); |
|
|
|
data = std::make_shared<UInt32Imm>(obj); |
|
|
|
break; |
|
|
|
case 64: |
|
|
|
*data = std::make_shared<UInt64Imm>(obj); |
|
|
|
data = std::make_shared<UInt64Imm>(obj); |
|
|
|
break; |
|
|
|
default: |
|
|
|
*data = std::make_shared<UInt32Imm>(obj); |
|
|
|
data = std::make_shared<UInt32Imm>(obj); |
|
|
|
} |
|
|
|
return true; |
|
|
|
return data; |
|
|
|
} |
|
|
|
|
|
|
|
auto float_dypte = dyn_cast<Float>(dtype); |
|
|
|
if (float_dypte != nullptr) { |
|
|
|
switch (float_dypte->nbits()) { |
|
|
|
case 32: |
|
|
|
*data = std::make_shared<FP32Imm>(obj); |
|
|
|
data = std::make_shared<FP32Imm>(obj); |
|
|
|
break; |
|
|
|
case 64: |
|
|
|
*data = std::make_shared<FP64Imm>(obj); |
|
|
|
data = std::make_shared<FP64Imm>(obj); |
|
|
|
break; |
|
|
|
default: |
|
|
|
*data = std::make_shared<FP32Imm>(obj); |
|
|
|
data = std::make_shared<FP32Imm>(obj); |
|
|
|
} |
|
|
|
return true; |
|
|
|
return data; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertIntegerWithType(const int64_t &obj, ValuePtr *const data, TypePtr dtype = nullptr) { |
|
|
|
ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) { |
|
|
|
auto obj_int64 = py::cast<int64_t>(obj); |
|
|
|
if (dtype == nullptr) { |
|
|
|
*data = std::make_shared<Int64Imm>(obj); |
|
|
|
return true; |
|
|
|
return std::make_shared<Int64Imm>(obj_int64); |
|
|
|
} |
|
|
|
|
|
|
|
return ConvertNumberWithType<int64_t>(obj, data, dtype); |
|
|
|
return ConvertNumberWithType<int64_t>(obj_int64, dtype); |
|
|
|
} |
|
|
|
|
|
|
|
bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) { |
|
|
|
ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) { |
|
|
|
auto obj_float64 = py::cast<float>(obj); |
|
|
|
if (dtype == nullptr) { |
|
|
|
*data = std::make_shared<FP32Imm>(obj); |
|
|
|
return true; |
|
|
|
return std::make_shared<FP32Imm>(obj_float64); |
|
|
|
} |
|
|
|
return ConvertNumberWithType<float>(obj_float64, dtype); |
|
|
|
} |
|
|
|
|
|
|
|
return ConvertNumberWithType<float>(obj, data, dtype); |
|
|
|
template <typename T, typename U> |
|
|
|
ValuePtr PyCast(const py::object &obj) { |
|
|
|
return std::make_shared<T>(py::cast<U>(obj)); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool ConvertSingleData(const py::object &obj, ValuePtr *const data) { |
|
|
|
MS_EXCEPTION_IF_NULL(data); |
|
|
|
ValuePtr converted = nullptr; |
|
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
|
converted = kNone; |
|
|
|
} else if (py::isinstance<py::bool_>(obj)) { |
|
|
|
converted = std::make_shared<BoolImm>(py::cast<bool>(obj)); |
|
|
|
} else if (py::isinstance<py::str>(obj)) { |
|
|
|
converted = std::make_shared<StringImm>(py::cast<std::string>(obj)); |
|
|
|
} else if (py::isinstance<py::ellipsis>(obj)) { |
|
|
|
converted = kEllipsis; |
|
|
|
} else if (py::isinstance<py::module>(obj)) { |
|
|
|
ConvertNameSpace(obj, &converted); |
|
|
|
} else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { |
|
|
|
ConvertDataClass(obj, &converted); |
|
|
|
} else if (py::isinstance<Type>(obj)) { |
|
|
|
converted = obj.cast<TypePtr>(); |
|
|
|
} else if (py::isinstance<Tensor>(obj)) { |
|
|
|
converted = obj.cast<TensorPtr>(); |
|
|
|
} else if (py::isinstance<MetaTensor>(obj)) { |
|
|
|
converted = obj.cast<MetaTensorPtr>(); |
|
|
|
} else if (py::isinstance<UMonad>(obj)) { |
|
|
|
converted = obj.cast<UMonadPtr>(); |
|
|
|
} else if (py::isinstance<IOMonad>(obj)) { |
|
|
|
converted = obj.cast<IOMonadPtr>(); |
|
|
|
} else if (py::isinstance<EnvInstance>(obj)) { |
|
|
|
auto env = obj.cast<std::shared_ptr<EnvInstance>>(); |
|
|
|
converted = env; |
|
|
|
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) { |
|
|
|
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj); |
|
|
|
} else { |
|
|
|
return false; |
|
|
|
} |
|
|
|
*data = converted; |
|
|
|
return true; |
|
|
|
template <typename T> |
|
|
|
ValuePtr ObjCast(const py::object &obj) { |
|
|
|
return obj.cast<T>(); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<DataConverterPtr> GetDataConverters() { |
|
|
|
static std::vector<DataConverterPtr> data_converters = { |
|
|
|
// Convert data by python object type. |
|
|
|
std::make_shared<ByTypeDataConverter<py::none>>(kNone), |
|
|
|
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>), |
|
|
|
std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>), |
|
|
|
std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis), |
|
|
|
std::make_shared<ByTypeDataConverter<py::module>>(ConvertNameSpace), |
|
|
|
std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass), |
|
|
|
std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>), |
|
|
|
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>), |
|
|
|
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>), |
|
|
|
std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>), |
|
|
|
std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>), |
|
|
|
std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>), |
|
|
|
std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE, |
|
|
|
[](const py::object &obj) -> ValuePtr { |
|
|
|
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, |
|
|
|
obj); |
|
|
|
}), |
|
|
|
std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType), |
|
|
|
std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType), |
|
|
|
std::make_shared<ByTypeDataConverter<py::dict>>(ConvertDict), |
|
|
|
std::make_shared<ByTypeDataConverter<py::slice>>(ConvertSlice), |
|
|
|
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple), |
|
|
|
std::make_shared<ByAttrDataConverter>(PYTHON_CELL_AS_LIST, ConvertCellList), |
|
|
|
std::make_shared<ByTypeDataConverter<Cell>>(ConvertCellObjToFuncGraph), |
|
|
|
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList), |
|
|
|
std::make_shared<ByAttrDataConverter>(PYTHON_PRIMITIVE_FLAG, ConvertPrimitive), |
|
|
|
std::make_shared<ByTypeDataConverter<MetaFuncGraph>>(ConvertMetaFuncGraph), |
|
|
|
std::make_shared<ByTypeDataConverter<FuncGraph>>(ConvertFuncGraph), |
|
|
|
}; |
|
|
|
return data_converters; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) { |
|
|
|
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) { |
|
|
|
// check parameter valid |
|
|
|
if (data == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Data is null pointer"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
ValuePtr converted = nullptr; |
|
|
|
bool ret = ConvertSingleData(obj, &converted); |
|
|
|
if (ret) { |
|
|
|
*data = converted; |
|
|
|
return true; |
|
|
|
bool matched = false; |
|
|
|
auto &&converters = GetDataConverters(); |
|
|
|
for (auto &converter : converters) { |
|
|
|
if (converter->Matched(obj)) { |
|
|
|
converted = converter->ConvertPyObject(obj, use_signature, dtype); |
|
|
|
matched = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (py::isinstance<py::int_>(obj)) { |
|
|
|
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype); |
|
|
|
} else if (py::isinstance<py::float_>(obj)) { |
|
|
|
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype); |
|
|
|
} else if (py::isinstance<py::dict>(obj)) { |
|
|
|
ret = ConvertDict(obj, &converted, use_signature); |
|
|
|
} else if (py::isinstance<py::slice>(obj)) { |
|
|
|
ret = ConvertSlice(obj, &converted); |
|
|
|
} else if (py::isinstance<py::tuple>(obj)) { |
|
|
|
ret = ConvertTuple(obj, &converted, use_signature); |
|
|
|
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { |
|
|
|
ret = ConvertCellList(obj, &converted, use_signature); |
|
|
|
} else if (py::isinstance<Cell>(obj)) { |
|
|
|
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data); |
|
|
|
} else if (py::isinstance<py::list>(obj)) { |
|
|
|
ret = ConvertList(obj, &converted, use_signature); |
|
|
|
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) { |
|
|
|
ret = ConvertPrimitive(obj, &converted, use_signature); |
|
|
|
} else if (py::isinstance<MetaFuncGraph>(obj)) { |
|
|
|
ret = ConvertMetaFuncGraph(obj, &converted, use_signature); |
|
|
|
} else if (py::isinstance<FuncGraph>(obj)) { |
|
|
|
ret = ConvertFuncGraph(obj, &converted); |
|
|
|
} else { |
|
|
|
ret = ConvertOtherObj(obj, &converted); |
|
|
|
if (!matched) { |
|
|
|
converted = ConvertOtherObj(obj); |
|
|
|
} |
|
|
|
*data = converted; |
|
|
|
return ret; |
|
|
|
return converted != nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
// convert data to graph |
|
|
|
@@ -488,7 +526,6 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python |
|
|
|
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); |
|
|
|
data_converter::SetObjGraphValue(obj_key, func_graph); |
|
|
|
} |
|
|
|
|
|
|
|
return func_graph; |
|
|
|
} |
|
|
|
namespace data_converter { |
|
|
|
@@ -549,13 +586,8 @@ bool IsCellInstance(const py::object &obj) { |
|
|
|
// create the python class instance |
|
|
|
py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { |
|
|
|
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); |
|
|
|
py::object obj; |
|
|
|
if (params.empty()) { |
|
|
|
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type); |
|
|
|
} else { |
|
|
|
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); |
|
|
|
} |
|
|
|
return obj; |
|
|
|
return params.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type) |
|
|
|
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); |
|
|
|
} |
|
|
|
|
|
|
|
// Generate an appropriate name and set to graph debuginfo |
|
|
|
|