|
|
|
@@ -161,13 +161,24 @@ public class SyncFLJob { |
|
|
|
private void updateDpNormClip(FLLiteClient client) { |
|
|
|
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel(); |
|
|
|
if (encryptLevel == EncryptLevel.DP_ENCRYPT) { |
|
|
|
int currentIter = client.getIteration(); |
|
|
|
Map<String, float[]> fedFeatureMap = getFeatureMap(); |
|
|
|
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap); |
|
|
|
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm)); |
|
|
|
client.dpNormClipAdapt = client.dpNormClipFactor*fedWeightUpdateNorm; |
|
|
|
float newNormCLip = (float) client.dpNormClipFactor * fedWeightUpdateNorm; |
|
|
|
if (currentIter == 1) { |
|
|
|
client.dpNormClipAdapt = newNormCLip; |
|
|
|
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); |
|
|
|
} else { |
|
|
|
if (newNormCLip < client.dpNormClipAdapt) { |
|
|
|
client.dpNormClipAdapt = newNormCLip; |
|
|
|
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); |
|
|
|
} |
|
|
|
} |
|
|
|
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
private void getOldFeatureMap(FLLiteClient client) { |
|
|
|
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel(); |
|
|
|
if (encryptLevel == EncryptLevel.DP_ENCRYPT) { |
|
|
|
@@ -175,17 +186,18 @@ public class SyncFLJob { |
|
|
|
oldFeatureMap = client.getOldMapCopy(featureMap); |
|
|
|
} |
|
|
|
} |
|
|
|
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData){ |
|
|
|
|
|
|
|
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) { |
|
|
|
float updateL2Norm = 0; |
|
|
|
for (String key: originalData.keySet()) { |
|
|
|
float[] data=originalData.get(key); |
|
|
|
for (String key : originalData.keySet()) { |
|
|
|
float[] data = originalData.get(key); |
|
|
|
float[] dataAfterUpdate = newData.get(key); |
|
|
|
for (int j = 0; j<data.length; j++) { |
|
|
|
for (int j = 0; j < data.length; j++) { |
|
|
|
float updateData = data[j] - dataAfterUpdate[j]; |
|
|
|
updateL2Norm += updateData*updateData; |
|
|
|
updateL2Norm += updateData * updateData; |
|
|
|
} |
|
|
|
} |
|
|
|
updateL2Norm = (float)Math.sqrt(updateL2Norm); |
|
|
|
updateL2Norm = (float) Math.sqrt(updateL2Norm); |
|
|
|
return updateL2Norm; |
|
|
|
} |
|
|
|
|
|
|
|
|