diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 0b65936f87..87a0d53169 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -690,6 +690,7 @@ class Cell(Cell_): Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter. """ replace = dict() + def _updata(param): if param in replace: return replace[param] @@ -1078,6 +1079,10 @@ class GraphKernel(Cell): A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when enable_graph_kernel in context is set to True. + Args: + auto_prefix (bool): Recursively generate namespaces. Default: True. + flags (dict) : Set graph flags. Default: None. + Examples: >>> class Relu(nn.GraphKernel): ... def __init__(self): @@ -1088,8 +1093,8 @@ class GraphKernel(Cell): ... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x) """ - def __init__(self, auto_prefix=True, pips=None): - super(GraphKernel, self).__init__(auto_prefix, pips) + def __init__(self, auto_prefix=True, flags=None): + super(GraphKernel, self).__init__(auto_prefix, flags) class_name = self.__class__.__name__ self.add_flags(graph_kernel=class_name)