|
|
|
@@ -314,6 +314,18 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat |
|
|
|
py::tuple err_ret(0); |
|
|
|
return std::move(err_ret); |
|
|
|
} |
|
|
|
if (op_exec_info->op_name == "stop_gradient" && py::isinstance<tensor::Tensor>(result)) { |
|
|
|
py::tuple tuple_result(1); |
|
|
|
auto tensor = py::cast<tensor::TensorPtr>(result); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr()); |
|
|
|
new_tensor->set_device_address(tensor->device_address()); |
|
|
|
new_tensor->set_sync_status(tensor->sync_status()); |
|
|
|
tuple_result[0] = new_tensor; |
|
|
|
*status = PYNATIVE_SUCCESS; |
|
|
|
MS_LOG(INFO) << "RunOpInVM end"; |
|
|
|
return std::move(tuple_result); |
|
|
|
} |
|
|
|
|
|
|
|
// execute op |
|
|
|
py::tuple tuple_result = py::make_tuple(result); |
|
|
|
|