Merge pull request !3042 from wangqiuliang/fix-sens-tensor-check-issuetags/v0.6.0-beta
| @@ -351,13 +351,13 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat | |||||
| for (size_t i = 0; i < op_inputs.size(); i++) { | for (size_t i = 0; i < op_inputs.size(); i++) { | ||||
| py::object input = op_inputs[i]; | py::object input = op_inputs[i]; | ||||
| if (py::hasattr(input, "__parameter__")) { | if (py::hasattr(input, "__parameter__")) { | ||||
| result[i] = py::getattr(input, "data"); | |||||
| } else { | |||||
| auto tensor = py::cast<tensor::TensorPtr>(input); | |||||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr()); | |||||
| new_tensor->set_device_address(tensor->device_address()); | |||||
| new_tensor->set_dirty(tensor->is_dirty()); | |||||
| result[i] = new_tensor; | |||||
| input = py::getattr(input, "data"); | |||||
| } | |||||
| auto tensor = py::cast<tensor::TensorPtr>(input); | |||||
| auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr()); | |||||
| new_tensor->set_device_address(tensor->device_address()); | |||||
| new_tensor->set_dirty(tensor->is_dirty()); | |||||
| result[i] = new_tensor; | |||||
| } | } | ||||
| } | } | ||||
| *status = PYNATIVE_SUCCESS; | *status = PYNATIVE_SUCCESS; | ||||
| @@ -120,6 +120,9 @@ class GradOperation(GradOperation_): | |||||
| """ Pynative forward run to build grad graph. """ | """ Pynative forward run to build grad graph. """ | ||||
| if self.sens_param: | if self.sens_param: | ||||
| args = args[:-1] | args = args[:-1] | ||||
| for arg in args: | |||||
| if not isinstance(arg, Tensor): | |||||
| raise TypeError("grad inputs should be tensor in pynative mode") | |||||
| if isinstance(fn, FunctionType): | if isinstance(fn, FunctionType): | ||||
| _pynative_exec.set_grad_flag(True) | _pynative_exec.set_grad_flag(True) | ||||
| _pynative_exec.new_graph(fn, *args) | _pynative_exec.new_graph(fn, *args) | ||||
| @@ -150,9 +153,6 @@ class GradOperation(GradOperation_): | |||||
| else: | else: | ||||
| @_wrap_func | @_wrap_func | ||||
| def after_grad(*args): | def after_grad(*args): | ||||
| for arg in args: | |||||
| if not isinstance(arg, Tensor): | |||||
| raise TypeError("grad inputs should be tensor in pynative mode") | |||||
| self._pynative_forward_run(args, fn) | self._pynative_forward_run(args, fn) | ||||
| _pynative_exec.grad(grad_, fn, weights, *args) | _pynative_exec.grad(grad_, fn, weights, *args) | ||||
| out = _pynative_exec(*args) | out = _pynative_exec(*args) | ||||