Browse Source

Fix a bug of flclient

tags/v1.3.0
jin-xiulang 4 years ago
parent
commit
ab2cece910
2 changed files with 20 additions and 8 deletions
  1. +1
    -1
      mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java
  2. +19
    -7
      mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java

+ 1
- 1
mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/FLLiteClient.java View File

@@ -58,7 +58,7 @@ public class FLLiteClient {
private double dpEps = 100;
private double dpDelta = 0.01;
public double dpNormClipFactor = 1.0;
public double dpNormClipAdapt = 0.5;
public double dpNormClipAdapt = 0.05;

private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();


+ 19
- 7
mindspore/lite/java/java/fl_client/src/main/java/com/mindspore/flclient/SyncFLJob.java View File

@@ -161,13 +161,24 @@ public class SyncFLJob {
private void updateDpNormClip(FLLiteClient client) {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
int currentIter = client.getIteration();
Map<String, float[]> fedFeatureMap = getFeatureMap();
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
client.dpNormClipAdapt = client.dpNormClipFactor*fedWeightUpdateNorm;
float newNormCLip = (float) client.dpNormClipFactor * fedWeightUpdateNorm;
if (currentIter == 1) {
client.dpNormClipAdapt = newNormCLip;
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
} else {
if (newNormCLip < client.dpNormClipAdapt) {
client.dpNormClipAdapt = newNormCLip;
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
}
}
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt));
}
}

private void getOldFeatureMap(FLLiteClient client) {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
@@ -175,17 +186,18 @@ public class SyncFLJob {
oldFeatureMap = client.getOldMapCopy(featureMap);
}
}
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData){

private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
float updateL2Norm = 0;
for (String key: originalData.keySet()) {
float[] data=originalData.get(key);
for (String key : originalData.keySet()) {
float[] data = originalData.get(key);
float[] dataAfterUpdate = newData.get(key);
for (int j = 0; j<data.length; j++) {
for (int j = 0; j < data.length; j++) {
float updateData = data[j] - dataAfterUpdate[j];
updateL2Norm += updateData*updateData;
updateL2Norm += updateData * updateData;
}
}
updateL2Norm = (float)Math.sqrt(updateL2Norm);
updateL2Norm = (float) Math.sqrt(updateL2Norm);
return updateL2Norm;
}



Loading…
Cancel
Save