|
|
@@ -690,6 +690,7 @@ class Cell(Cell_): |
|
|
Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter. |
|
|
Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter. |
|
|
""" |
|
|
""" |
|
|
replace = dict() |
|
|
replace = dict() |
|
|
|
|
|
|
|
|
def _updata(param): |
|
|
def _updata(param): |
|
|
if param in replace: |
|
|
if param in replace: |
|
|
return replace[param] |
|
|
return replace[param] |
|
|
@@ -1078,6 +1079,10 @@ class GraphKernel(Cell): |
|
|
A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when |
|
|
A `GraphKernel` a composite of basic primitives and can be compiled into a fused kernel automatically when |
|
|
enable_graph_kernel in context is set to True. |
|
|
enable_graph_kernel in context is set to True. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
auto_prefix (bool): Recursively generate namespaces. Default: True. |
|
|
|
|
|
flags (dict) : Set graph flags. Default: None. |
|
|
|
|
|
|
|
|
Examples: |
|
|
Examples: |
|
|
>>> class Relu(nn.GraphKernel): |
|
|
>>> class Relu(nn.GraphKernel): |
|
|
... def __init__(self): |
|
|
... def __init__(self): |
|
|
@@ -1088,8 +1093,8 @@ class GraphKernel(Cell): |
|
|
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x) |
|
|
... return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x) |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, auto_prefix=True, pips=None): |
|
|
|
|
|
super(GraphKernel, self).__init__(auto_prefix, pips) |
|
|
|
|
|
|
|
|
def __init__(self, auto_prefix=True, flags=None): |
|
|
|
|
|
super(GraphKernel, self).__init__(auto_prefix, flags) |
|
|
class_name = self.__class__.__name__ |
|
|
class_name = self.__class__.__name__ |
|
|
self.add_flags(graph_kernel=class_name) |
|
|
self.add_flags(graph_kernel=class_name) |
|
|
|
|
|
|
|
|
|