|
|
|
@@ -58,7 +58,8 @@ using mindspore::tensor::TensorPy; |
|
|
|
|
|
|
|
const char SINGLE_OP_GRAPH[] = "single_op_graph"; |
|
|
|
// primitive unable to infer value for constant input in PyNative mode |
|
|
|
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"}; |
|
|
|
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient", |
|
|
|
"mixed_precision_cast"}; |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace pynative { |
|
|
|
@@ -346,7 +347,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat |
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive); |
|
|
|
|
|
|
|
auto &op_inputs = op_exec_info->op_inputs; |
|
|
|
if (op_exec_info->op_name == "HookBackward") { |
|
|
|
if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf") { |
|
|
|
py::tuple result(op_inputs.size()); |
|
|
|
for (size_t i = 0; i < op_inputs.size(); i++) { |
|
|
|
py::object input = op_inputs[i]; |
|
|
|
|