|
|
|
@@ -101,7 +101,6 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra |
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Tensor", "Bool", "Bool", "Tensor") |
|
|
|
|
|
|
|
def _tensors_allreduce_post(degree, mean, allreduce_filter, grad): |
|
|
|
""" |
|
|
|
Apply allreduce on gradient in PyNative mode. |
|
|
|
@@ -125,7 +124,6 @@ def _tensors_allreduce_post(degree, mean, allreduce_filter, grad): |
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor", "Bool") |
|
|
|
|
|
|
|
def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): |
|
|
|
""" |
|
|
|
Apply allreduce on gradient. |
|
|
|
@@ -154,7 +152,6 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, |
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor") |
|
|
|
|
|
|
|
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): |
|
|
|
""" |
|
|
|
Apply allgather on gradient instead of allreduce for sparse feature. |
|
|
|
@@ -181,7 +178,6 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce |
|
|
|
|
|
|
|
|
|
|
|
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool") |
|
|
|
|
|
|
|
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): |
|
|
|
""" |
|
|
|
Apply allgather on gradient instead of allreduce for sparse feature. |
|
|
|
@@ -215,7 +211,6 @@ _get_datatype = C.MultitypeFuncGraph("_get_datatype") |
|
|
|
|
|
|
|
|
|
|
|
@_get_datatype.register("Tensor") |
|
|
|
|
|
|
|
def _tensors_get_datatype(grad): |
|
|
|
""" |
|
|
|
Acquire gradient datatype. |
|
|
|
@@ -230,7 +225,6 @@ def _tensors_get_datatype(grad): |
|
|
|
|
|
|
|
|
|
|
|
@_get_datatype.register("RowTensor") |
|
|
|
|
|
|
|
def _tensors_get_datatype_with_sparse(grad): |
|
|
|
""" |
|
|
|
Acquire gradient datatype. |
|
|
|
@@ -248,7 +242,6 @@ _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") |
|
|
|
|
|
|
|
|
|
|
|
@_cast_datatype.register("TypeType", "Tensor") |
|
|
|
|
|
|
|
def _tensors_cast_datatype(datatype, grad): |
|
|
|
""" |
|
|
|
Cast gradient to datatype. |
|
|
|
@@ -264,7 +257,6 @@ def _tensors_cast_datatype(datatype, grad): |
|
|
|
|
|
|
|
|
|
|
|
@_cast_datatype.register("TypeType", "RowTensor") |
|
|
|
|
|
|
|
def _tensors_cast_datatype_with_sparse(datatype, grad): |
|
|
|
""" |
|
|
|
Cast gradient to datatype. |
|
|
|
|