|
|
|
@@ -84,7 +84,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay") |
|
|
|
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): |
|
|
|
"""Get grad with weight_decay.""" |
|
|
|
if if_apply: |
|
|
|
return op_add((gradient, weight * F.scalar_to_array(weight_decay))) |
|
|
|
return op_add((gradient, weight * weight_decay)) |
|
|
|
return gradient |
|
|
|
|
|
|
|
|
|
|
|
|