|
|
|
@@ -256,41 +256,84 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { |
|
|
|
return RunOp(args)[0]; |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExecInfoPtr &op_exec_info) { |
|
|
|
auto &out_args = op_exec_info->op_inputs; |
|
|
|
auto signature = prim->signatures(); |
|
|
|
std::vector<SignatureEnumDType> dtypes; |
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), |
|
|
|
[](const Signature &sig) { return sig.dtype; }); |
|
|
|
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); |
|
|
|
if (dtypes.empty() || static_cast<int>(dtypes.size()) == empty_dtype_count) { |
|
|
|
return; |
|
|
|
// Applies the mixed-precision cast requested on a tensor, if any.
//
// `obj` is converted to a tensor::TensorPtr and its cast_dtype() is read as the
// requested target type. A cast through DoAutoCast is performed only when
//   * a target type is set,
//   * the tensor's current dtype is known and is a sub type of kFloat, and
//   * the current dtype differs from the target type.
//
// Parameters:
//   is_cast - out flag; set to true when a cast is performed. It is never
//             reset to false here, so callers can share one flag across
//             several calls.
//   obj     - Python object holding the tensor to inspect.
//
// Returns the cast tensor when a cast was performed, otherwise `obj` itself.
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
  auto tensor = py::cast<tensor::TensorPtr>(obj);
  auto cast_type = tensor->cast_dtype();
  // Start from the input so that the "no cast needed" path hands back a valid
  // object instead of a default-constructed (null) py::object handle.
  py::object cast_output = obj;
  if (cast_type != nullptr) {
    auto source_element = tensor->Dtype();
    if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
      MS_LOG(DEBUG) << "cast to " << cast_type->ToString();
      cast_output = DoAutoCast(obj, cast_type->type_id());
      *is_cast = true;
    }
  }
  return cast_output;
}
|
|
|
|
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) { |
|
|
|
if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto it = dst_type.find(dtypes[i]); |
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) { |
|
|
|
continue; |
|
|
|
// Applies the mixed-precision cast to every tensor found in a (possibly nested) tuple.
//
// Each element is handled according to its Python type:
//   * tensor::MetaTensor - passed to DoParamMixPrecisionCast;
//   * py::tuple          - processed recursively;
//   * anything else      - copied through unchanged.
//
// Parameters:
//   is_cast - out flag forwarded to DoParamMixPrecisionCast; it ends up true
//             when at least one element (at any nesting depth) was cast.
//   tuple   - the tuple of inputs to process.
//
// Returns a new tuple of the same length in which every slot is populated.
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
  auto tuple_size = static_cast<int>(tuple.size());
  py::tuple result(tuple_size);

  for (int i = 0; i < tuple_size; i++) {
    py::object item = tuple[i];
    if (py::isinstance<tensor::MetaTensor>(item)) {
      MS_LOG(DEBUG) << "call cast for item " << i;
      py::object cast_item = DoParamMixPrecisionCast(is_cast, item);
      // An empty handle means no cast was produced for this element; keep the
      // original so the result tuple never contains a NULL slot.
      if (cast_item) {
        result[i] = cast_item;
      } else {
        result[i] = item;
      }
    } else if (py::isinstance<py::tuple>(item)) {
      result[i] = DoParamMixPrecisionCastTuple(is_cast, item);
    } else {
      // Non-tensor, non-tuple elements (scalars, None, ...) are passed through
      // untouched; a freshly sized py::tuple starts with unset slots.
      result[i] = item;
    }
  }
  return result;
}
|
|
|
|
|
|
|
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) { |
|
|
|
auto signature = prim->signatures(); |
|
|
|
bool has_sig_dtype = false; |
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(*dtypes), |
|
|
|
[&has_sig_dtype](const Signature &sig) { |
|
|
|
auto dtype = sig.dtype; |
|
|
|
if (dtype != SignatureEnumDType::kDTypeEmptyDefaultValue) { |
|
|
|
has_sig_dtype = true; |
|
|
|
} |
|
|
|
return dtype; |
|
|
|
}); |
|
|
|
return has_sig_dtype; |
|
|
|
} |
|
|
|
|
|
|
|
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type, |
|
|
|
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) { |
|
|
|
const auto &signature = prim->signatures(); |
|
|
|
auto &out_args = op_exec_info->op_inputs; |
|
|
|
bool has_dtype_sig = (dtypes.size() > 0); |
|
|
|
for (size_t i = 0; i < out_args.size(); ++i) { |
|
|
|
MS_LOG(DEBUG) << "check inputs " << i; |
|
|
|
auto obj = out_args[i]; |
|
|
|
auto sig = signature[i].rw; |
|
|
|
auto sig = SignatureEnumRW::kRWDefault; |
|
|
|
if (signature.size() > 0) { |
|
|
|
sig = signature[i].rw; |
|
|
|
} |
|
|
|
bool is_parameter = false; |
|
|
|
bool is_same_type = false; |
|
|
|
TypeId arg_type_id = kTypeUnknown; |
|
|
|
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite); |
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) { |
|
|
|
auto arg = py::cast<tensor::MetaTensorPtr>(obj); |
|
|
|
if (arg->is_parameter()) { |
|
|
|
is_parameter = true; |
|
|
|
MS_LOG(DEBUG) << "parameter is read " << i; |
|
|
|
} |
|
|
|
arg_type_id = arg->data_type(); |
|
|
|
} |
|
|
|
|
|
|
|
// No need to implicit cast if no dtype. |
|
|
|
if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto it = dst_type.find(dtypes[i]); |
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// implicit cast |
|
|
|
bool is_same_type = false; |
|
|
|
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite); |
|
|
|
if (arg_type_id != 0) { |
|
|
|
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second); |
|
|
|
} |
|
|
|
@@ -317,7 +360,6 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, const OpExe |
|
|
|
} |
|
|
|
py::object cast_output = DoAutoCast(out_args[i], it->second); |
|
|
|
out_args[i] = cast_output; |
|
|
|
ValuePtr input_value = PyAttrValue(cast_output); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -346,7 +388,6 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { |
|
|
|
op_exec_info->py_primitive = prim; |
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); |
|
|
|
op_exec_info->op_inputs = args[PY_INPUTS]; |
|
|
|
ConvertInputs(prim, args[PY_INPUTS], op_exec_info); |
|
|
|
return op_exec_info; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -697,11 +738,53 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
inputs.push_back(NewValueNode(prim)); |
|
|
|
|
|
|
|
size_t size = op_exec_info->op_inputs.size(); |
|
|
|
auto const_input_index = prim->get_const_input_indexes(); |
|
|
|
bool have_const_input = !const_input_index.empty(); |
|
|
|
bool is_const_prim = prim->is_const_prim(); |
|
|
|
// ignore signature for cast op |
|
|
|
bool is_cast_op = (op_exec_info->op_name == "Cast"); |
|
|
|
if (!is_cast_op) { |
|
|
|
const auto &signature = prim->signatures(); |
|
|
|
for (size_t i = 0; i < size; i++) { |
|
|
|
auto obj = op_exec_info->op_inputs[i]; |
|
|
|
auto sig = SignatureEnumRW::kRWDefault; |
|
|
|
if (signature.size() > 0) { |
|
|
|
sig = signature[i].rw; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "check mix precision " << op_exec_info->op_name << " input " << i << " " |
|
|
|
<< std::string(py::repr(obj)); |
|
|
|
// mix precision for non param |
|
|
|
bool is_cast = false; |
|
|
|
py::object cast_output; |
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) { |
|
|
|
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>(); |
|
|
|
if (meta_tensor && meta_tensor->is_parameter()) { |
|
|
|
if (sig != SignatureEnumRW::kRWRead) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
// redundant cast call if the tensor is a const Tensor. |
|
|
|
cast_output = DoParamMixPrecisionCast(&is_cast, obj); |
|
|
|
} else if (py::isinstance<py::tuple>(obj)) { |
|
|
|
// mix precision for tuple inputs |
|
|
|
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj); |
|
|
|
} |
|
|
|
if (is_cast) { |
|
|
|
op_exec_info->op_inputs[i] = cast_output; |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<SignatureEnumDType> dtypes; |
|
|
|
|
|
|
|
bool has_dtype_sig = GetSignatureType(prim, &dtypes); |
|
|
|
std::map<SignatureEnumDType, TypeId> dst_types; |
|
|
|
if (has_dtype_sig) { |
|
|
|
// fetch info for implicit cast |
|
|
|
auto type_indexes = GetTypeIndex(dtypes); |
|
|
|
dst_types = GetDstType(op_exec_info->op_inputs, type_indexes); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "do signature for " << op_exec_info->op_name; |
|
|
|
DoSignatrueCast(prim, dst_types, dtypes, op_exec_info); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "make cnode for " << op_exec_info->op_name; |
|
|
|
for (size_t i = 0; i < size; i++) { |
|
|
|
auto obj = op_exec_info->op_inputs[i]; |
|
|
|
const auto &obj = op_exec_info->op_inputs[i]; |
|
|
|
bool op_mask = false; |
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) { |
|
|
|
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>(); |
|
|
|
@@ -709,9 +792,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
op_mask = meta_tensor->is_parameter(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
(*op_masks).push_back(op_mask); |
|
|
|
MS_LOG(DEBUG) << "gen " << op_exec_info->op_name << " arg " << i << ": op mask " << op_mask << " grad_flag_ " |
|
|
|
MS_LOG(DEBUG) << "gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ " |
|
|
|
<< grad_flag_; |
|
|
|
|
|
|
|
AnfNodePtr node = nullptr; |
|
|
|
@@ -726,6 +808,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
if (node != nullptr && node->abstract() != nullptr) { |
|
|
|
abs = node->abstract(); |
|
|
|
} |
|
|
|
|
|
|
|
auto const_input_index = prim->get_const_input_indexes(); |
|
|
|
bool have_const_input = !const_input_index.empty(); |
|
|
|
bool is_const_prim = prim->is_const_prim(); |
|
|
|
MS_LOG(DEBUG) << prim->ToString() << " abs is nullptr " << (abs == nullptr) << " is_const_value " |
|
|
|
<< prim->is_const_prim(); |
|
|
|
bool is_const_input = have_const_input && std::count(const_input_index.begin(), const_input_index.end(), i); |
|
|
|
@@ -926,7 +1012,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple PynativeExecutor::RunOpInner(const py::args &args) { |
|
|
|
MS_LOG(DEBUG) << "RunOp start" << args.size(); |
|
|
|
MS_LOG(DEBUG) << "RunOp start " << args.size(); |
|
|
|
OpExecInfoPtr op_exec_info = nullptr; |
|
|
|
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); |
|
|
|
auto name = py::cast<std::string>(args[PY_NAME]); |
|
|
|
|