|
|
|
@@ -896,6 +896,16 @@ py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::obj |
|
|
|
// Update input for cast struct |
|
|
|
auto cast_struct = cast_struct_pair->second; |
|
|
|
cast_struct->op_inputs[0] = obj; |
|
|
|
auto grad = this->grad(); |
|
|
|
MS_EXCEPTION_IF_NULL(grad); |
|
|
|
if (grad->grad_flag()) { |
|
|
|
// Get forward op index |
|
|
|
if (!grad->cell_op_info_stack().empty()) { |
|
|
|
std::string &cell_op_info = grad->cell_op_info_stack().top(); |
|
|
|
cell_op_info += cast_struct->op_index; |
|
|
|
} |
|
|
|
grad->op_index_map()[cast_struct->op_name]++; |
|
|
|
} |
|
|
|
py::object ret = py::none(); |
|
|
|
RunOpInner(&ret, cast_struct); |
|
|
|
return ret; |
|
|
|
|