Merge pull request !8039 from zhangbuxue/support_key_ward_way_to_pass_arg_for_outermost_net_in_graph_modetags/v1.1.0
| @@ -282,19 +282,19 @@ class Cell(Cell_): | |||||
| return tuple(res) | return tuple(res) | ||||
| def __call__(self, *inputs, **kwargs): | def __call__(self, *inputs, **kwargs): | ||||
| if kwargs: | |||||
| bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) | |||||
| inputs = bound_args.args | |||||
| kwargs = bound_args.kwargs | |||||
| if context.get_context("mode") == context.GRAPH_MODE: | if context.get_context("mode") == context.GRAPH_MODE: | ||||
| if kwargs: | if kwargs: | ||||
| raise ValueError("For 'graph' mode, the outermost network does not support passing " | raise ValueError("For 'graph' mode, the outermost network does not support passing " | ||||
| "key-value pair parameters and variable key-value pair parameters.") | |||||
| "variable key-value pair parameters.") | |||||
| if self.enable_hook: | if self.enable_hook: | ||||
| raise ValueError("The graph mode does not support hook function.") | raise ValueError("The graph mode does not support hook function.") | ||||
| out = self.compile_and_run(*inputs) | out = self.compile_and_run(*inputs) | ||||
| return out | return out | ||||
| if kwargs: | |||||
| bound_args = inspect.signature(self.construct).bind(*inputs, **kwargs) | |||||
| inputs = bound_args.args | |||||
| kwargs = bound_args.kwargs | |||||
| for item in inputs: | for item in inputs: | ||||
| if isinstance(item, numpy.ndarray): | if isinstance(item, numpy.ndarray): | ||||
| raise TypeError("cell inputs should not be numpy array.") | raise TypeError("cell inputs should not be numpy array.") | ||||
| @@ -22,7 +22,6 @@ from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | ||||
| grad_all = C.GradOperation(get_all=True) | grad_all = C.GradOperation(get_all=True) | ||||
| grad_all_with_sens = C.GradOperation(sens_param=True) | grad_all_with_sens = C.GradOperation(sens_param=True) | ||||
| @@ -285,3 +284,31 @@ def test_mixed_precision_const_parameter(): | |||||
| y = Tensor(np.ones((1, 3, 14, 14), np.float32)) | y = Tensor(np.ones((1, 3, 14, 14), np.float32)) | ||||
| z = Tensor(np.ones((1, 3, 28, 28), np.float32)) | z = Tensor(np.ones((1, 3, 28, 28), np.float32)) | ||||
| _ = net(x, y, z) | _ = net(x, y, z) | ||||
| def test_pass_args_by_key_ward_way(): | |||||
| class KeyWardNet(Cell): | |||||
| def __init__(self): | |||||
| super(KeyWardNet, self).__init__() | |||||
| def construct(self, x, y, z): | |||||
| return x + y - z | |||||
| class GradNet(Cell): | |||||
| def __init__(self, net): | |||||
| super(GradNet, self).__init__() | |||||
| self.grad = C.GradOperation(get_all=True, sens_param=True) | |||||
| self.net = net | |||||
| self.sens = Tensor(np.ones((3, 3, 4), np.float32)) | |||||
| def construct(self, x, y, z, sens): | |||||
| return self.grad(self.net)(x, y, z, sens) | |||||
| x = Tensor(np.ones((1, 3, 4), np.float32)) | |||||
| y = Tensor(np.ones((1, 3, 4), np.float32)) | |||||
| z = Tensor(np.ones((3, 3, 4), np.float32)) | |||||
| net = KeyWardNet() | |||||
| net(x, z=z, y=y) | |||||
| grad_net = GradNet(net) | |||||
| sens = Tensor(np.ones((3, 3, 4), np.float32)) | |||||
| grad_net(x, y=y, z=z, sens=sens) | |||||