Browse Source

!14846 [ME]Remove some redudunt codes

From: @chenfei52
Reviewed-by: 
Signed-off-by:
pull/14846/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
1e4e26c74b
3 changed files with 218 additions and 228 deletions
  1. +217
    -185
      mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
  2. +0
    -42
      mindspore/ccsrc/pipeline/jit/parse/parse.cc
  3. +1
    -1
      mindspore/ccsrc/pipeline/jit/parse/parse_base.h

+ 217
- 185
mindspore/ccsrc/pipeline/jit/parse/data_converter.cc View File

@@ -38,6 +38,79 @@ using TensorPtr = mindspore::tensor::TensorPtr;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;

using InstanceCheckFunc = std::function<bool(const py::object &)>;
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
class DataConverter {
public:
explicit DataConverter(InstanceConvertFunc convert_func) : convert_func_(std::move(convert_func)) {}
virtual ~DataConverter() = default;
virtual bool Matched(const py::object &obj) = 0;
virtual ValuePtr ConvertPyObject(const py::object &obj, bool use_sig, const TypePtr &dtype) {
if (convert_func_ == nullptr) {
MS_LOG(EXCEPTION) << "convert func is null";
}
return convert_func_(obj, use_sig, dtype);
}

private:
InstanceConvertFunc convert_func_ = nullptr;
};
using DataConverterPtr = std::shared_ptr<DataConverter>;

using ArgsObjConvertFunc = std::function<ValuePtr(const py::object &)>;
using ArgsObjSigConvertFunc = std::function<ValuePtr(const py::object &, bool)>;
using ArgsOjbTypeConvertFunc = std::function<ValuePtr(const py::object &, const TypePtr &)>;

// Convert the data according instance type
template <typename T>
class ByTypeDataConverter : public DataConverter {
public:
explicit ByTypeDataConverter(const InstanceConvertFunc &convert_func)
: DataConverter(convert_func), check_func_(py::isinstance<T>) {}
explicit ByTypeDataConverter(const ValuePtr &converted_type)
: DataConverter(
[converted_type](const py::object &, bool, const TypePtr &) -> ValuePtr { return converted_type; }),
check_func_(py::isinstance<T>) {}
explicit ByTypeDataConverter(const ArgsObjConvertFunc &convert_func)
: DataConverter(
[convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
check_func_(py::isinstance<T>) {}
explicit ByTypeDataConverter(const ArgsObjSigConvertFunc &convert_func)
: DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
return convert_func(obj, use_sig);
}),
check_func_(py::isinstance<T>) {}
explicit ByTypeDataConverter(const ArgsOjbTypeConvertFunc &convert_func)
: DataConverter([convert_func](const py::object &obj, bool, const TypePtr &dtype) -> ValuePtr {
return convert_func(obj, dtype);
}),
check_func_(py::isinstance<T>) {}
~ByTypeDataConverter() override = default;

bool Matched(const py::object &obj) override { return check_func_ != nullptr ? check_func_(obj) : false; }

private:
InstanceCheckFunc check_func_ = nullptr;
};

// Convert the data according object attribute.
class ByAttrDataConverter : public DataConverter {
public:
ByAttrDataConverter(const char *attr_name, const ArgsObjConvertFunc &convert_func)
: DataConverter(
[convert_func](const py::object &obj, bool, const TypePtr &) -> ValuePtr { return convert_func(obj); }),
attr_name_(attr_name) {}
ByAttrDataConverter(const char *attr_name, const ArgsObjSigConvertFunc &convert_func)
: DataConverter([convert_func](const py::object &obj, bool use_sig, const TypePtr &) -> ValuePtr {
return convert_func(obj, use_sig);
}),
attr_name_(attr_name) {}
bool Matched(const py::object &obj) override { return py::hasattr(obj, attr_name_); }

private:
const char *attr_name_ = nullptr;
};

FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
@@ -69,7 +142,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
}

namespace {
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
ValuePtr ConvertTuple(const py::object &obj, bool use_signature) {
MS_LOG(DEBUG) << "Converting python tuple";
auto tuple = obj.cast<py::tuple>();
std::vector<ValuePtr> value_list;
@@ -77,16 +150,14 @@ bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signatur
ValuePtr out = nullptr;
bool success = ConvertData(tuple[it], &out, use_signature);
if (!success) {
return false;
return nullptr;
}
value_list.push_back(out);
}
*data = std::make_shared<ValueTuple>(value_list);

return true;
return std::make_shared<ValueTuple>(value_list);
}

bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
ValuePtr ConvertList(const py::object &obj, bool use_signature) {
MS_LOG(DEBUG) << "Converting python list";

auto list = obj.cast<py::list>();
@@ -95,15 +166,14 @@ bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature
ValuePtr out = nullptr;
bool success = ConvertData(list[it], &out, use_signature);
if (!success) {
return false;
return nullptr;
}
value_list.push_back(out);
}
*data = std::make_shared<ValueList>(value_list);
return true;
return std::make_shared<ValueList>(value_list);
}

bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) {
ValuePtr ConvertCellList(const py::object &obj, bool use_signature) {
MS_LOG(DEBUG) << "Converting cell list";
py::sequence list = obj;
std::vector<ValuePtr> value_list;
@@ -111,15 +181,14 @@ bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signa
ValuePtr out = nullptr;
bool success = ConvertData(list[it], &out, use_signature);
if (!success) {
return false;
return nullptr;
}
value_list.push_back(out);
}
*data = std::make_shared<ValueTuple>(value_list);
return true;
return std::make_shared<ValueTuple>(value_list);
}

bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
ValuePtr ConvertDict(const py::object &obj, bool use_signature) {
MS_LOG(DEBUG) << "Converting python dict";

auto dict_values = obj.cast<py::dict>();
@@ -127,36 +196,37 @@ bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
for (auto item : dict_values) {
if (!py::isinstance<py::str>(item.first)) {
MS_LOG(ERROR) << "The key of dict is only support str.";
return false;
return nullptr;
}
std::string key = py::str(item.first);
ValuePtr out = nullptr;
bool success = ConvertData(dict_values[item.first], &out, use_signature);
if (!success) {
return false;
return nullptr;
}
key_values.emplace_back(key, out);
}
*data = std::make_shared<ValueDictionary>(key_values);
return true;
return std::make_shared<ValueDictionary>(key_values);
}

void ConvertNameSpace(const py::object &obj, ValuePtr *const data) {
ValuePtr ConvertNameSpace(const py::object &obj) {
MS_LOG(DEBUG) << "Converting python module";
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
return converted;
}

void ConvertDataClass(py::object obj, ValuePtr *const data) {
ValuePtr ConvertDataClass(const py::object &obj) {
MS_LOG(DEBUG) << "Converting dataclass";
// Maybe the obj is dataclass define
auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
*data = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
auto converted = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
return converted;
}

bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) {
MS_LOG(DEBUG) << "Converting primitive object" << use_signature;

// need check the primitive is class type or instance
@@ -164,96 +234,81 @@ bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature =
if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
*data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
} else {
auto primitive = obj.cast<PrimitivePyPtr>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
return false;
}
if (py::hasattr(obj, "__setattr_flag__")) {
if (py::hasattr(obj, "_clone")) {
auto clone_fn = obj.attr("_clone");
py::object new_obj = clone_fn();
primitive = new_obj.cast<PrimitivePyPtr>();
}
}
if (use_signature) {
*data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
} else {
*data = primitive;
}
MS_LOG(DEBUG) << "Converting primitive object ok " << (*data)->ToString();
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
}
auto primitive = obj.cast<PrimitivePyPtr>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null";
return nullptr;
}
return true;
if (py::hasattr(obj, "__setattr_flag__") && py::hasattr(obj, "_clone")) {
auto clone_fn = obj.attr("_clone");
py::object new_obj = clone_fn();
primitive = new_obj.cast<PrimitivePyPtr>();
}
if (use_signature) {
return std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
}
return primitive;
}

bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) {
ValuePtr ConvertMetaFuncGraph(const py::object &obj, bool use_signature = false) {
MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
auto meta = obj.cast<MetaFuncGraphPtr>();
if (meta == nullptr) {
MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null";
return false;
return nullptr;
}
if (use_signature) {
*data = std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
} else {
*data = meta;
return std::make_shared<prim::DoSignaturePrimitive>(meta->name(), meta);
}
return true;
return meta;
}

bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) {
ValuePtr ConvertFuncGraph(const py::object &obj) {
MS_LOG(DEBUG) << "Converting FuncGraph object";
auto func_graph = obj.cast<FuncGraphPtr>();
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Resolve FuncGraph error, get ptr is null";
return false;
return nullptr;
}
auto new_fg = BasicClone(func_graph);
new_fg->set_attr("is_load", MakeValue(true));
*data = new_fg;
return true;
return new_fg;
}

bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
ValuePtr ConvertSlice(const py::object &obj) {
MS_LOG(DEBUG) << "Converting slice object";

auto slice_obj = obj.cast<py::slice>();
auto convert_func = [obj](std::string attr) -> ValuePtr {
auto convert_func = [obj](const std::string &attr) -> ValuePtr {
auto py_attr = py::getattr(obj, attr.c_str());
if (py::isinstance<py::none>(py_attr)) {
return kNone;
} else if (py::isinstance<py::int_>(py_attr)) {
int64_t value = py::cast<int64_t>(py_attr);
}
if (py::isinstance<py::int_>(py_attr)) {
auto value = py::cast<int64_t>(py_attr);
return MakeValue(value);
} else {
MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none";
}
MS_LOG(EXCEPTION) << "Slice should contain only int64_t or none";
};
ValuePtr start = convert_func("start");
ValuePtr stop = convert_func("stop");
ValuePtr step = convert_func("step");
*data = std::make_shared<ValueSlice>(start, stop, step);
return true;
return std::make_shared<ValueSlice>(start, stop, step);
}

bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
auto obj = py::cast(cell);
ValuePtr ConvertCellObjToFuncGraph(const py::object &obj) {
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error.";
return false;
return nullptr;
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
FuncGraphPtr bprop_graph = nullptr;
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
if (enable_bprop_debug) {
bprop_graph = ConvertToBpropCut(obj);
} else {
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
}
FuncGraphPtr bprop_graph =
enable_bprop_debug ? ConvertToBpropCut(obj) : ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
@@ -264,200 +319,183 @@ bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
func_graph->set_stage(stage);
}
*data = func_graph;
return true;
return func_graph;
}

bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
ValuePtr ConvertOtherObj(const py::object &obj) {
auto obj_type = data_converter::GetObjType(obj);
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
MS_LOG(DEBUG) << "Resolve the class type, need create class instance.";
std::string desc = py::str(obj);
// desc has format "<class xxxx>", strip the '<' and '>' by offset 1;
*data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
return true;
return std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
}
if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error.";
return false;
return nullptr;
}
*data = func_graph;
return true;
return func_graph;
}
if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
// Create the namespace for common class instance
// When the obj is Cell, default parse the 'construct'
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
return true;
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
}
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
return false;
return nullptr;
}

template <typename T>
bool ConvertNumberWithType(const T &obj, ValuePtr *const data, TypePtr dtype) {
ValuePtr ConvertNumberWithType(const T &obj, TypePtr dtype) {
ValuePtr data = nullptr;
auto int_dypte = dyn_cast<Int>(dtype);
if (int_dypte != nullptr) {
switch (int_dypte->nbits()) {
case 8:
*data = std::make_shared<Int8Imm>(obj);
data = std::make_shared<Int8Imm>(obj);
break;
case 16:
*data = std::make_shared<Int16Imm>(obj);
data = std::make_shared<Int16Imm>(obj);
break;
case 32:
*data = std::make_shared<Int32Imm>(obj);
data = std::make_shared<Int32Imm>(obj);
break;
case 64:
*data = std::make_shared<Int64Imm>(obj);
data = std::make_shared<Int64Imm>(obj);
break;
default:
*data = std::make_shared<Int64Imm>(obj);
data = std::make_shared<Int64Imm>(obj);
}
return true;
return data;
}

auto uint_dypte = dyn_cast<UInt>(dtype);
if (uint_dypte != nullptr) {
switch (uint_dypte->nbits()) {
case 8:
*data = std::make_shared<UInt8Imm>(obj);
data = std::make_shared<UInt8Imm>(obj);
break;
case 16:
*data = std::make_shared<UInt16Imm>(obj);
data = std::make_shared<UInt16Imm>(obj);
break;
case 32:
*data = std::make_shared<UInt32Imm>(obj);
data = std::make_shared<UInt32Imm>(obj);
break;
case 64:
*data = std::make_shared<UInt64Imm>(obj);
data = std::make_shared<UInt64Imm>(obj);
break;
default:
*data = std::make_shared<UInt32Imm>(obj);
data = std::make_shared<UInt32Imm>(obj);
}
return true;
return data;
}

auto float_dypte = dyn_cast<Float>(dtype);
if (float_dypte != nullptr) {
switch (float_dypte->nbits()) {
case 32:
*data = std::make_shared<FP32Imm>(obj);
data = std::make_shared<FP32Imm>(obj);
break;
case 64:
*data = std::make_shared<FP64Imm>(obj);
data = std::make_shared<FP64Imm>(obj);
break;
default:
*data = std::make_shared<FP32Imm>(obj);
data = std::make_shared<FP32Imm>(obj);
}
return true;
return data;
}

return false;
return nullptr;
}

bool ConvertIntegerWithType(const int64_t &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
ValuePtr ConvertIntegerWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
auto obj_int64 = py::cast<int64_t>(obj);
if (dtype == nullptr) {
*data = std::make_shared<Int64Imm>(obj);
return true;
return std::make_shared<Int64Imm>(obj_int64);
}

return ConvertNumberWithType<int64_t>(obj, data, dtype);
return ConvertNumberWithType<int64_t>(obj_int64, dtype);
}

bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype = nullptr) {
ValuePtr ConvertFloatWithType(const py::object &obj, const TypePtr &dtype = nullptr) {
auto obj_float64 = py::cast<float>(obj);
if (dtype == nullptr) {
*data = std::make_shared<FP32Imm>(obj);
return true;
return std::make_shared<FP32Imm>(obj_float64);
}
return ConvertNumberWithType<float>(obj_float64, dtype);
}

return ConvertNumberWithType<float>(obj, data, dtype);
template <typename T, typename U>
ValuePtr PyCast(const py::object &obj) {
return std::make_shared<T>(py::cast<U>(obj));
}
} // namespace

bool ConvertSingleData(const py::object &obj, ValuePtr *const data) {
MS_EXCEPTION_IF_NULL(data);
ValuePtr converted = nullptr;
if (py::isinstance<py::none>(obj)) {
converted = kNone;
} else if (py::isinstance<py::bool_>(obj)) {
converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
} else if (py::isinstance<py::str>(obj)) {
converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
} else if (py::isinstance<py::ellipsis>(obj)) {
converted = kEllipsis;
} else if (py::isinstance<py::module>(obj)) {
ConvertNameSpace(obj, &converted);
} else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
ConvertDataClass(obj, &converted);
} else if (py::isinstance<Type>(obj)) {
converted = obj.cast<TypePtr>();
} else if (py::isinstance<Tensor>(obj)) {
converted = obj.cast<TensorPtr>();
} else if (py::isinstance<MetaTensor>(obj)) {
converted = obj.cast<MetaTensorPtr>();
} else if (py::isinstance<UMonad>(obj)) {
converted = obj.cast<UMonadPtr>();
} else if (py::isinstance<IOMonad>(obj)) {
converted = obj.cast<IOMonadPtr>();
} else if (py::isinstance<EnvInstance>(obj)) {
auto env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env;
} else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
} else {
return false;
}
*data = converted;
return true;
template <typename T>
ValuePtr ObjCast(const py::object &obj) {
return obj.cast<T>();
}

std::vector<DataConverterPtr> GetDataConverters() {
static std::vector<DataConverterPtr> data_converters = {
// Convert data by python object type.
std::make_shared<ByTypeDataConverter<py::none>>(kNone),
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
std::make_shared<ByTypeDataConverter<py::str>>(PyCast<StringImm, string>),
std::make_shared<ByTypeDataConverter<py::ellipsis>>(kEllipsis),
std::make_shared<ByTypeDataConverter<py::module>>(ConvertNameSpace),
std::make_shared<ByAttrDataConverter>(PYTHON_DATACLASS_FIELDS, ConvertDataClass),
std::make_shared<ByTypeDataConverter<Type>>(ObjCast<TypePtr>),
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<UMonad>>(ObjCast<UMonadPtr>),
std::make_shared<ByTypeDataConverter<IOMonad>>(ObjCast<IOMonadPtr>),
std::make_shared<ByTypeDataConverter<EnvInstance>>(ObjCast<std::shared_ptr<EnvInstance>>),
std::make_shared<ByAttrDataConverter>(PYTHON_CLASS_MEMBER_NAMESPACE,
[](const py::object &obj) -> ValuePtr {
return std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER,
obj);
}),
std::make_shared<ByTypeDataConverter<py::int_>>(ConvertIntegerWithType),
std::make_shared<ByTypeDataConverter<py::float_>>(ConvertFloatWithType),
std::make_shared<ByTypeDataConverter<py::dict>>(ConvertDict),
std::make_shared<ByTypeDataConverter<py::slice>>(ConvertSlice),
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
std::make_shared<ByAttrDataConverter>(PYTHON_CELL_AS_LIST, ConvertCellList),
std::make_shared<ByTypeDataConverter<Cell>>(ConvertCellObjToFuncGraph),
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
std::make_shared<ByAttrDataConverter>(PYTHON_PRIMITIVE_FLAG, ConvertPrimitive),
std::make_shared<ByTypeDataConverter<MetaFuncGraph>>(ConvertMetaFuncGraph),
std::make_shared<ByTypeDataConverter<FuncGraph>>(ConvertFuncGraph),
};
return data_converters;
}
} // namespace

bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, const TypePtr &dtype) {
// check parameter valid
if (data == nullptr) {
MS_LOG(ERROR) << "Data is null pointer";
return false;
}

ValuePtr converted = nullptr;
bool ret = ConvertSingleData(obj, &converted);
if (ret) {
*data = converted;
return true;
bool matched = false;
auto &&converters = GetDataConverters();
for (auto &converter : converters) {
if (converter->Matched(obj)) {
converted = converter->ConvertPyObject(obj, use_signature, dtype);
matched = true;
break;
}
}
if (py::isinstance<py::int_>(obj)) {
ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
} else if (py::isinstance<py::float_>(obj)) {
ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
} else if (py::isinstance<py::dict>(obj)) {
ret = ConvertDict(obj, &converted, use_signature);
} else if (py::isinstance<py::slice>(obj)) {
ret = ConvertSlice(obj, &converted);
} else if (py::isinstance<py::tuple>(obj)) {
ret = ConvertTuple(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
ret = ConvertCellList(obj, &converted, use_signature);
} else if (py::isinstance<Cell>(obj)) {
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
} else if (py::isinstance<py::list>(obj)) {
ret = ConvertList(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
ret = ConvertPrimitive(obj, &converted, use_signature);
} else if (py::isinstance<MetaFuncGraph>(obj)) {
ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
} else if (py::isinstance<FuncGraph>(obj)) {
ret = ConvertFuncGraph(obj, &converted);
} else {
ret = ConvertOtherObj(obj, &converted);
if (!matched) {
converted = ConvertOtherObj(obj);
}
*data = converted;
return ret;
return converted != nullptr;
}

// convert data to graph
@@ -488,7 +526,6 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python
MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
data_converter::SetObjGraphValue(obj_key, func_graph);
}

return func_graph;
}
namespace data_converter {
@@ -549,13 +586,8 @@ bool IsCellInstance(const py::object &obj) {
// create the python class instance
py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object obj;
if (params.empty()) {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
} else {
obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
}
return obj;
return params.empty() ? python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type)
: python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
}

// Generate an appropriate name and set to graph debuginfo


+ 0
- 42
mindspore/ccsrc/pipeline/jit/parse/parse.cc View File

@@ -142,48 +142,6 @@ void Parser::CleanParserResource() {
ScopeManager::GetInstance().ClearScope();
}

AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
MS_EXCEPTION_IF_NULL(func_graph);
auto value = py::cast<tensor::MetaTensorPtr>(obj);
// Parameter object should not be none
if (value == nullptr || !value->is_parameter()) {
MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object.";
}

// Get the parameter name from parameter object
auto param_name = value->param_info()->name();

auto top_graph = func_graph;
// If the parameter node has been created , return it
AnfNodePtr para_node = nullptr;
for (const auto &param : top_graph->parameters()) {
auto param_node = dyn_cast<Parameter>(param);
if (param_node != nullptr && param_node->name() == param_name) {
para_node = param;
break;
}
}
if (para_node == nullptr) {
auto node = top_graph->AddWeightParameter(param_name);

node->set_default_param(value);
// set_abstract for parameter
auto abs = value->ToAbstract();
// Boarden value
abs = abs->Broaden();
node->set_abstract(abs);
para_node = node;
}
return para_node;
}

void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) {
auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>();
for (const auto &param : params) {
(void)AppendParameterObj(top_graph, param);
}
}

void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
// Check whether the functions referred by this function and itself are missing 'return' statement
auto mng = Manage(fn, false);


+ 1
- 1
mindspore/ccsrc/pipeline/jit/parse/parse_base.h View File

@@ -163,7 +163,7 @@ enum ClassInstanceTypeDef {
};

// Convert python object to ValuePtr
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, TypePtr dtype = nullptr);
bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false, const TypePtr &dtype = nullptr);

// Convert python obj to graph
FuncGraphPtr ConvertToFuncGraph(const py::object &obj,


Loading…
Cancel
Save