|
|
@@ -1544,9 +1544,9 @@ class ApplyMomentum(PrimitiveWithInfer): |
|
|
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, |
|
|
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, |
|
|
sig_dtype.T), |
|
|
sig_dtype.T), |
|
|
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, |
|
|
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, |
|
|
sig_dtype.T), |
|
|
|
|
|
|
|
|
sig_dtype.T1), |
|
|
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), |
|
|
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) |
|
|
|
|
|
|
|
|
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2) |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
@prim_attr_register |
|
|
@prim_attr_register |
|
|
|