Merge pull request !4370 from zhangbuxue/support_kw_and_kwargs_for_cell_in_pynative (tag: v0.7.0-beta)
| @@ -291,14 +291,14 @@ class _PynativeExecutor: | |||
| def __init__(self): | |||
| self._executor = PynativeExecutor_.get_instance() | |||
| def new_graph(self, obj, *args): | |||
| self._executor.new_graph(obj, *args) | |||
| def new_graph(self, obj, *args, **kwargs): | |||
| self._executor.new_graph(obj, *args, *(kwargs.values())) | |||
| def end_graph(self, obj, output, *args): | |||
| self._executor.end_graph(obj, output, *args) | |||
| def end_graph(self, obj, output, *args, **kwargs): | |||
| self._executor.end_graph(obj, output, *args, *(kwargs.values())) | |||
| def grad(self, grad, obj, weights, *args): | |||
| self._executor.grad_net(grad, obj, weights, *args) | |||
| def grad(self, grad, obj, weights, *args, **kwargs): | |||
| self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values())) | |||
| def clear(self, flag=""): | |||
| self._executor.clear(flag) | |||
| @@ -306,7 +306,8 @@ class _PynativeExecutor: | |||
| def set_grad_flag(self, flag): | |||
| self._executor.set_grad_flag(flag) | |||
| def __call__(self, *args): | |||
| def __call__(self, *args, **kwargs): | |||
| args = args + tuple(kwargs.values()) | |||
| return self._executor(args, "") | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """cell""" | |||
| import inspect | |||
| import time | |||
| import gc | |||
| from collections import OrderedDict | |||
| @@ -222,19 +223,27 @@ class Cell: | |||
| else: | |||
| object.__delattr__(self, name) | |||
| def __call__(self, *inputs): | |||
| def __call__(self, *inputs, **kwargs): | |||
| if context.get_context("mode") == context.GRAPH_MODE: | |||
| if kwargs: | |||
| raise ValueError("For 'graph' mode, the outermost network does not support passing " | |||
| "key-value pair parameters and variable key-value pair parameters.") | |||
| out = self.compile_and_run(*inputs) | |||
| return out | |||
| if kwargs: | |||
| bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) | |||
| inputs = bound_args.args | |||
| kwargs = bound_args.kwargs | |||
| for item in inputs: | |||
| if isinstance(item, numpy.ndarray): | |||
| raise TypeError("cell inputs should not be numpy array.") | |||
| orign_grad = [] | |||
| origin_grad = [] | |||
| if self.requires_grad is True: | |||
| _pynative_exec.set_grad_flag(True) | |||
| _pynative_exec.new_graph(self, *inputs) | |||
| _pynative_exec.new_graph(self, *inputs, **kwargs) | |||
| for cell in self.cells(): | |||
| orign_grad.append(cell.requires_grad) | |||
| origin_grad.append(cell.requires_grad) | |||
| cell.set_grad(True) | |||
| else: | |||
| _pynative_exec.set_grad_flag(False) | |||
| @@ -251,15 +260,15 @@ class Cell: | |||
| else: | |||
| cast_inputs = inputs | |||
| if self.enable_hook: | |||
| output = self._hook_construct(*cast_inputs) | |||
| output = self._hook_construct(*cast_inputs, **kwargs) | |||
| else: | |||
| output = self.construct(*cast_inputs) | |||
| output = self.construct(*cast_inputs, **kwargs) | |||
| if isinstance(output, Parameter): | |||
| output = output.data | |||
| if self.requires_grad is True: | |||
| _pynative_exec.end_graph(self, output, *inputs) | |||
| _pynative_exec.end_graph(self, output, *inputs, **kwargs) | |||
| for i, cell in enumerate(self.cells()): | |||
| cell.set_grad(orign_grad[i]) | |||
| cell.set_grad(origin_grad[i]) | |||
| self._already_run = True | |||
| return output | |||
| @@ -400,7 +409,6 @@ class Cell: | |||
| def _get_construct_inputs_number_and_name(self): | |||
| """Compute self._construct_inputs_names and self._construct_inputs_num""" | |||
| import inspect | |||
| from mindspore._extends.parse.parser import get_parse_method_of_class | |||
| fn = get_parse_method_of_class(self) | |||
| @@ -517,7 +525,7 @@ class Cell: | |||
| raise TypeError("Child cell type is incorrect.") | |||
| self._cells[child_name] = child | |||
| def construct(self, *inputs): | |||
| def construct(self, *inputs, **kwargs): | |||
| """ | |||
| Defines the computation to be performed. | |||
| @@ -878,7 +886,7 @@ class Cell: | |||
| self.add_flags(auto_parallel=True) | |||
| self._get_construct_inputs_number_and_name() | |||
| def _hook_construct(self, *inputs): | |||
| def _hook_construct(self, *inputs, **kwargs): | |||
| """Hook construct method to replace original construct method when hook function enabled.""" | |||
| inputs = self._backward_hook(*inputs) | |||
| inputs = self.construct(inputs) | |||
| @@ -116,7 +116,7 @@ class GradOperation(GradOperation_): | |||
| self.fn = None | |||
| self.need_forward = False | |||
| def _pynative_forward_run(self, args, fn): | |||
| def _pynative_forward_run(self, args, kwargs, fn): | |||
| """ Pynative forward run to build grad graph. """ | |||
| if self.sens_param: | |||
| args = args[:-1] | |||
| @@ -125,9 +125,9 @@ class GradOperation(GradOperation_): | |||
| raise TypeError("grad inputs should be tensor in pynative mode") | |||
| if isinstance(fn, FunctionType): | |||
| _pynative_exec.set_grad_flag(True) | |||
| _pynative_exec.new_graph(fn, *args) | |||
| output = fn(*args) | |||
| _pynative_exec.end_graph(fn, output, *args) | |||
| _pynative_exec.new_graph(fn, *args, **kwargs) | |||
| output = fn(*args, **kwargs) | |||
| _pynative_exec.end_graph(fn, output, *args, **kwargs) | |||
| else: | |||
| if fn.already_run and not fn.requires_grad: | |||
| raise ValueError("obj must set_grad.") | |||
| @@ -135,7 +135,7 @@ class GradOperation(GradOperation_): | |||
| self.need_forward = True | |||
| if self.need_forward: | |||
| fn.set_grad() | |||
| fn(*args) | |||
| fn(*args, **kwargs) | |||
| fn.already_run = False | |||
| def __call__(self, fn, weights=None): | |||
| @@ -152,10 +152,10 @@ class GradOperation(GradOperation_): | |||
| return grad_(fn)(*args) | |||
| else: | |||
| @_wrap_func | |||
| def after_grad(*args): | |||
| self._pynative_forward_run(args, fn) | |||
| _pynative_exec.grad(grad_, fn, weights, *args) | |||
| out = _pynative_exec(*args) | |||
| def after_grad(*args, **kwargs): | |||
| self._pynative_forward_run(args, kwargs, fn) | |||
| _pynative_exec.grad(grad_, fn, weights, *args, **kwargs) | |||
| out = _pynative_exec(*args, **kwargs) | |||
| _pynative_exec.clear() | |||
| return out | |||
| self.grad_fn = after_grad | |||
| @@ -30,6 +30,7 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_list_equal(): | |||
| class Net(nn.Cell): | |||
| def __init__(self, z: list): | |||
| @@ -156,8 +157,10 @@ def test_class_member_not_defined(): | |||
| z = [[1, 2], 3] | |||
| net = Net(z) | |||
| x = Tensor(np.ones([6, 8, 10], np.int32)) | |||
| y = Tensor(np.zeros([3, 4, 5], np.int32)) | |||
| with pytest.raises(TypeError) as ex: | |||
| net() | |||
| net(x, y) | |||
| assert "'self.x' was not defined in the class '__init__' function." in str(ex.value) | |||
| @@ -181,7 +184,7 @@ def test_change_list_element(): | |||
| class ListOperate(nn.Cell): | |||
| def __init__(self,): | |||
| def __init__(self): | |||
| super(ListOperate, self).__init__() | |||
| def construct(self, t, l): | |||
| @@ -201,7 +204,7 @@ class ListOperate(nn.Cell): | |||
| class InListNet(nn.Cell): | |||
| def __init__(self,): | |||
| def __init__(self): | |||
| super(InListNet, self).__init__() | |||
| self.list_ = [1, 2, 3, 4, 5, "ok"] | |||
| @@ -0,0 +1,139 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test dtype and shape as attr""" | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore import dtype as mstype | |||
| from mindspore.ops.composite import base as C | |||
| def test_kw_nested(): | |||
| class NetKeyValueArg(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def construct(self, x, y, *arg, w, **kwargs): | |||
| return x + y + arg[0] + w + kwargs['c'] | |||
| class NetOut(nn.Cell): | |||
| def __init__(self, net): | |||
| super().__init__() | |||
| self.in_net = net | |||
| def construct(self, x, y, z): | |||
| ret = self.in_net(x, y, z, w=x, a=x, b=y, c=z) + x | |||
| return ret | |||
| in_net = NetKeyValueArg() | |||
| out_net = NetOut(in_net) | |||
| x = Tensor(np.ones([3, 4, 5], np.float32)) | |||
| y = Tensor(np.zeros([3, 4, 5], np.int32)) | |||
| z = Tensor(np.ones([3, 4, 5], np.float64)) | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| ret = out_net(x, y, z) | |||
| assert ret.dtype == mstype.float64 | |||
| assert ret.shape == (3, 4, 5) | |||
| assert (ret.asnumpy() == np.ones([3, 4, 5], np.float64) * 5).all() | |||
| def test_kw_grad(): | |||
| class KwNet(nn.Cell): | |||
| def __init__(self): | |||
| super(KwNet, self).__init__() | |||
| def construct(self, x, y, *arg, **kwargs): | |||
| return 2 * x + 3 * y + 4 * arg[0] + 5 * kwargs['v'] | |||
| class GradKwNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradKwNet, self).__init__() | |||
| self.net = net | |||
| self.grad_all_wit_sense = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True) | |||
| def construct(self, x, y, *arg, **kwargs): | |||
| return self.grad_all_wit_sense(self.net)(x, y, *arg, **kwargs) | |||
| kw_net = KwNet() | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.float32)) | |||
| z = Tensor(np.ones([1, 2, 3], np.float64)) | |||
| u = Tensor(np.ones([1, 2, 3], np.float16)) | |||
| v = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| w = Tensor(np.ones([1, 2, 3], np.float64)) | |||
| sens = Tensor(np.ones([1, 2, 3], np.float64)) | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| kw_net.set_grad(True) | |||
| ret = kw_net(x, y, z, u=u, v=v, w=w) | |||
| assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 14).all() | |||
| grad_kw_net = GradKwNet(kw_net) | |||
| ret_grad = grad_kw_net(x, y, z, u=u, v=v, w=w, sens=sens) | |||
| assert len(ret_grad) == 6 | |||
| assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all() | |||
| assert ret_grad[0].dtype == mstype.int32 | |||
| assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all() | |||
| assert ret_grad[1].dtype == mstype.float32 | |||
| assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all() | |||
| assert ret_grad[2].dtype == mstype.float64 | |||
| assert (ret_grad[3].asnumpy() == np.zeros([1, 2, 3])).all() | |||
| assert ret_grad[3].dtype == mstype.float16 | |||
| assert (ret_grad[4].asnumpy() == np.ones([1, 2, 3]) * 5).all() | |||
| assert ret_grad[4].dtype == mstype.int32 | |||
| assert (ret_grad[5].asnumpy() == np.zeros([1, 2, 3])).all() | |||
| assert ret_grad[5].dtype == mstype.float64 | |||
| def test_grad(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x, y, z): | |||
| return 2 * x + 3 * y + 4 * z | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.grad_all_wit_sense = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True) | |||
| def construct(self, x, y, z, sens): | |||
| return self.grad_all_wit_sense(self.net)(x, y, z, sens) | |||
| net = Net() | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.float32)) | |||
| z = Tensor(np.ones([1, 2, 3], np.float16)) | |||
| sens = Tensor(np.ones([1, 2, 3], np.float32)) | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| net.set_grad(True) | |||
| ret = net(x, y, z) | |||
| assert (ret.asnumpy() == np.ones([1, 2, 3], np.float64) * 9).all() | |||
| grad_net = GradNet(net) | |||
| ret_grad = grad_net(x, y, z, sens) | |||
| assert len(ret_grad) == 3 | |||
| assert (ret_grad[0].asnumpy() == np.ones([1, 2, 3]) * 2).all() | |||
| assert ret_grad[0].dtype == mstype.int32 | |||
| assert (ret_grad[1].asnumpy() == np.ones([1, 2, 3]) * 3).all() | |||
| assert ret_grad[1].dtype == mstype.float32 | |||
| assert (ret_grad[2].asnumpy() == np.ones([1, 2, 3]) * 4).all() | |||
| assert ret_grad[2].dtype == mstype.float16 | |||