|
|
|
@@ -532,6 +532,7 @@ class Push(PrimitiveWithInfer): |
|
|
|
def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None): |
|
|
|
"""init Push""" |
|
|
|
self.add_prim_attr("primitive_target", "CPU") |
|
|
|
self.add_prim_attr("_side_effect", True) |
|
|
|
self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key']) |
|
|
|
|
|
|
|
def infer_shape(self, inputs, shapes): |
|
|
|
|