|
|
|
@@ -279,16 +279,34 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe |
|
|
|
} |
|
|
|
|
|
|
|
auto obj = out_args[i]; |
|
|
|
if (py::isinstance<tensor::Tensor>(obj)) { |
|
|
|
auto arg = py::cast<tensor::TensorPtr>(obj); |
|
|
|
TypeId arg_type_id = arg->data_type(); |
|
|
|
if (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second) { |
|
|
|
continue; |
|
|
|
auto sig = signature[i].rw; |
|
|
|
bool is_parameter = false; |
|
|
|
bool is_same_type = false; |
|
|
|
TypeId arg_type_id = kTypeUnknown; |
|
|
|
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite); |
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) { |
|
|
|
auto arg = py::cast<tensor::MetaTensorPtr>(obj); |
|
|
|
if (arg->is_parameter()) { |
|
|
|
is_parameter = true; |
|
|
|
} |
|
|
|
if (signature[i].rw == SignatureEnumRW::kRWWrite) { |
|
|
|
prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id), |
|
|
|
TypeIdToMsTypeStr(it->second)); |
|
|
|
arg_type_id = arg->data_type(); |
|
|
|
} |
|
|
|
if (arg_type_id != 0) { |
|
|
|
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second); |
|
|
|
} |
|
|
|
if (is_sig_write) { |
|
|
|
if (!is_parameter) { |
|
|
|
prim::RaiseExceptionForCheckParameter(prim->name(), i, "not"); |
|
|
|
} |
|
|
|
if (arg_type_id != 0) { |
|
|
|
if (!is_same_type) { |
|
|
|
prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id), |
|
|
|
TypeIdToMsTypeStr(it->second)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (is_same_type) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) { |
|
|
|
|