Browse Source

!4283 add InsertGradientOf operator support in pynative mode

Merge pull request !4283 from wangqiuliang/add-insert-gradient-of-operator-support
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
e203503b80
2 changed files with 3 additions and 6 deletions
  1. +3
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +0
    -4
      mindspore/ops/operations/debug_ops.py

+ 3
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -58,7 +58,8 @@ using mindspore::tensor::TensorPy;

const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
"mixed_precision_cast"};

namespace mindspore {
namespace pynative {
@@ -346,7 +347,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);

auto &op_inputs = op_exec_info->op_inputs;
if (op_exec_info->op_name == "HookBackward") {
if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf") {
py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i];


+ 0
- 4
mindspore/ops/operations/debug_ops.py View File

@@ -238,10 +238,6 @@ class InsertGradientOf(PrimitiveWithInfer):
def __init__(self, f):
self.f = f

def __call__(self, x):
"""run in PyNative mode."""
return x

def infer_shape(self, x_shape):
return x_shape



Loading…
Cancel
Save