| @@ -0,0 +1,578 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient; | |||
| import com.google.flatbuffers.FlatBufferBuilder; | |||
| import com.mindspore.flclient.cipher.AESEncrypt; | |||
| import com.mindspore.flclient.cipher.BaseUtil; | |||
| import com.mindspore.flclient.cipher.ClientListReq; | |||
| import com.mindspore.flclient.cipher.KEYAgreement; | |||
| import com.mindspore.flclient.cipher.Random; | |||
| import com.mindspore.flclient.cipher.ReconstructSecretReq; | |||
| import com.mindspore.flclient.cipher.ShareSecrets; | |||
| import com.mindspore.flclient.cipher.struct.ClientPublicKey; | |||
| import com.mindspore.flclient.cipher.struct.DecryptShareSecrets; | |||
| import com.mindspore.flclient.cipher.struct.EncryptShare; | |||
| import com.mindspore.flclient.cipher.struct.NewArray; | |||
| import com.mindspore.flclient.cipher.struct.ShareSecret; | |||
| import mindspore.schema.ClientShare; | |||
| import mindspore.schema.GetExchangeKeys; | |||
| import mindspore.schema.GetShareSecrets; | |||
| import mindspore.schema.RequestExchangeKeys; | |||
| import mindspore.schema.RequestShareSecrets; | |||
| import mindspore.schema.ResponseCode; | |||
| import mindspore.schema.ResponseExchangeKeys; | |||
| import mindspore.schema.ResponseShareSecrets; | |||
| import mindspore.schema.ReturnExchangeKeys; | |||
| import mindspore.schema.ReturnShareSecrets; | |||
| import java.io.UnsupportedEncodingException; | |||
| import java.math.BigInteger; | |||
| import java.nio.ByteBuffer; | |||
| import java.security.NoSuchAlgorithmException; | |||
| import java.security.spec.InvalidKeySpecException; | |||
| import java.time.LocalDateTime; | |||
| import java.util.ArrayList; | |||
| import java.util.HashMap; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.logging.Logger; | |||
| import static com.mindspore.flclient.FLParameter.SLEEP_TIME; | |||
| import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN; | |||
| import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE; | |||
| public class CipherClient { | |||
| private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString()); | |||
| private FLCommunication flCommunication; | |||
| private FLParameter flParameter = FLParameter.getInstance(); | |||
| private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); | |||
| private final int iteration; | |||
| private int featureSize; | |||
| private int t; | |||
| private List<byte[]> cKey = new ArrayList<>(); | |||
| private List<byte[]> sKey = new ArrayList<>(); | |||
| private byte[] bu; | |||
| private String nextRequestTime; | |||
| private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>(); | |||
| private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>(); | |||
| private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>(); | |||
| private List<EncryptShare> clientShareList = new ArrayList<>(); | |||
| private List<EncryptShare> returnShareList = new ArrayList<>(); | |||
| private float[] featureMask; | |||
| private List<String> u1ClientList = new ArrayList<>(); | |||
| private List<String> u2UClientList = new ArrayList<>(); | |||
| private List<String> u3ClientList = new ArrayList<>(); | |||
| private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>(); | |||
| private byte[] prime; | |||
| private KEYAgreement keyAgreement = new KEYAgreement(); | |||
| private Random random = new Random(); | |||
| private ClientListReq clientListReq = new ClientListReq(); | |||
| private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq(); | |||
| public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) { | |||
| flCommunication = FLCommunication.getInstance(); | |||
| this.iteration = iter; | |||
| this.featureSize = featureSize; | |||
| this.t = minSecretNum; | |||
| this.prime = prime; | |||
| this.featureMask = new float[this.featureSize]; | |||
| } | |||
| public void setNextRequestTime(String nextRequestTime) { | |||
| this.nextRequestTime = nextRequestTime; | |||
| } | |||
| public void setBU(byte[] bu) { | |||
| this.bu = bu; | |||
| } | |||
| public void setClientShareList(List<EncryptShare> clientShareList) { | |||
| this.clientShareList.clear(); | |||
| this.clientShareList = clientShareList; | |||
| } | |||
| public String getNextRequestTime() { | |||
| return nextRequestTime; | |||
| } | |||
| public void genDHKeyPairs() { | |||
| byte[] csk = keyAgreement.generatePrivateKey(); | |||
| byte[] cpk = keyAgreement.generatePublicKey(csk); | |||
| byte[] ssk = keyAgreement.generatePrivateKey(); | |||
| byte[] spk = keyAgreement.generatePublicKey(ssk); | |||
| this.cKey.add(cpk); | |||
| this.cKey.add(csk); | |||
| this.sKey.add(spk); | |||
| this.sKey.add(ssk); | |||
| } | |||
| public void genIndividualSecret() { | |||
| byte[] key = new byte[SEED_SIZE]; | |||
| random.getRandomBytes(key); | |||
| setBU(key); | |||
| } | |||
| public List<ShareSecret> genSecretShares(byte[] secret) throws UnsupportedEncodingException { | |||
| List<ShareSecret> shareSecretList = new ArrayList<>(); | |||
| int size = u1ClientList.size(); | |||
| ShareSecrets shamir = new ShareSecrets(t, size - 1); | |||
| ShareSecrets.SecretShare[] shares = shamir.split(secret, prime); | |||
| int j = 0; | |||
| for (int i = 0; i < size; i++) { | |||
| String vFlID = u1ClientList.get(i); | |||
| if (localFLParameter.getFlID().equals(vFlID)) { | |||
| continue; | |||
| } else { | |||
| ShareSecret shareSecret = new ShareSecret(); | |||
| NewArray<byte[]> array = new NewArray<>(); | |||
| int index = shares[j].getNum(); | |||
| BigInteger intShare = shares[j].getShare(); | |||
| byte[] share = BaseUtil.bigInteger2byteArray(intShare); | |||
| array.setSize(share.length); | |||
| array.setArray(share); | |||
| shareSecret.setFlID(vFlID); | |||
| shareSecret.setShare(array); | |||
| shareSecret.setIndex(index); | |||
| shareSecretList.add(shareSecret); | |||
| j += 1; | |||
| } | |||
| } | |||
| return shareSecretList; | |||
| } | |||
| public void genEncryptExchangedKeys() throws InvalidKeySpecException, NoSuchAlgorithmException { | |||
| cUVKeys.clear(); | |||
| for (String key : clientPublicKeyList.keySet()) { | |||
| ClientPublicKey curPublicKey = clientPublicKeyList.get(key); | |||
| String vFlID = curPublicKey.getFlID(); | |||
| if (localFLParameter.getFlID().equals(vFlID)) { | |||
| continue; | |||
| } else { | |||
| byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray()); | |||
| byte[] salt = new byte[0]; | |||
| byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt); | |||
| cUVKeys.put(vFlID, secret); | |||
| } | |||
| } | |||
| } | |||
| public void encryptShares() throws Exception { | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for RequestShareSecrets **************")); | |||
| List<EncryptShare> encryptShareList = new ArrayList<>(); | |||
| // connect sSkUv, bUV, sIndex, indexB and then Encrypt them | |||
| List<ShareSecret> sSkUv = genSecretShares(sKey.get(1)); | |||
| List<ShareSecret> bUV = genSecretShares(bu); | |||
| for (int i = 0; i < bUV.size(); i++) { | |||
| EncryptShare encryptShare = new EncryptShare(); | |||
| NewArray<byte[]> array = new NewArray<>(); | |||
| String vFlID = bUV.get(i).getFlID(); | |||
| byte[] sShare = sSkUv.get(i).getShare().getArray(); | |||
| byte[] bShare = bUV.get(i).getShare().getArray(); | |||
| byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex()); | |||
| byte[] bIndex = BaseUtil.integer2byteArray(bUV.get(i).getIndex()); | |||
| byte[] allSecret = new byte[sShare.length + bShare.length + sIndex.length + bIndex.length + 4]; | |||
| allSecret[0] = (byte) sShare.length; | |||
| allSecret[1] = (byte) bShare.length; | |||
| allSecret[2] = (byte) sIndex.length; | |||
| allSecret[3] = (byte) bIndex.length; | |||
| System.arraycopy(sIndex, 0, allSecret, 4, sIndex.length); | |||
| System.arraycopy(bIndex, 0, allSecret, 4 + sIndex.length, bIndex.length); | |||
| System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length); | |||
| System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length); | |||
| // encrypt: | |||
| byte[] iVecIn = new byte[IVEC_LEN]; | |||
| AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), iVecIn, "CBC"); | |||
| byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret); | |||
| array.setSize(encryptData.length); | |||
| array.setArray(encryptData); | |||
| encryptShare.setFlID(vFlID); | |||
| encryptShare.setShare(array); | |||
| encryptShareList.add(encryptShare); | |||
| } | |||
| setClientShareList(encryptShareList); | |||
| } | |||
| public float[] doubleMaskingWeight() throws Exception { | |||
| int size = u2UClientList.size(); | |||
| List<Float> noiseBu = new ArrayList<>(); | |||
| random.randomAESCTR(noiseBu, featureSize, bu); | |||
| float[] mask = new float[featureSize]; | |||
| for (int i = 0; i < size; i++) { | |||
| String vFlID = u2UClientList.get(i); | |||
| ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID); | |||
| if (localFLParameter.getFlID().equals(vFlID)) { | |||
| continue; | |||
| } else { | |||
| byte[] salt = new byte[0]; | |||
| byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray()); | |||
| byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt); | |||
| sUVKeys.put(vFlID, secret); | |||
| List<Float> noiseSuv = new ArrayList<>(); | |||
| random.randomAESCTR(noiseSuv, featureSize, secret); | |||
| int sign; | |||
| if (localFLParameter.getFlID().compareTo(vFlID) > 0) { | |||
| sign = 1; | |||
| } else { | |||
| sign = -1; | |||
| } | |||
| for (int j = 0; j < noiseSuv.size(); j++) { | |||
| mask[j] = mask[j] + sign * noiseSuv.get(j); | |||
| } | |||
| } | |||
| } | |||
| for (int j = 0; j < noiseBu.size(); j++) { | |||
| mask[j] = mask[j] + noiseBu.get(j); | |||
| } | |||
| return mask; | |||
| } | |||
| public NewArray<byte[]> byteToArray(ByteBuffer buf, int size) { | |||
| NewArray<byte[]> newArray = new NewArray<>(); | |||
| newArray.setSize(size); | |||
| byte[] array = new byte[size]; | |||
| for (int i = 0; i < size; i++) { | |||
| byte word = buf.get(); | |||
| array[i] = word; | |||
| } | |||
| newArray.setArray(array); | |||
| return newArray; | |||
| } | |||
| public FLClientStatus requestExchangeKeys() { | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============")); | |||
| String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==============requestExchangeKeys url: " + url + "==============")); | |||
| genDHKeyPairs(); | |||
| byte[] cPK = cKey.get(0); | |||
| byte[] sPK = sKey.get(0); | |||
| FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); | |||
| int id = fbBuilder.createString(localFLParameter.getFlID()); | |||
| int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK); | |||
| int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK); | |||
| String dateTime = LocalDateTime.now().toString(); | |||
| int time = fbBuilder.createString(dateTime); | |||
| int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time); | |||
| fbBuilder.finish(exchangeKeysRoot); | |||
| byte[] msg = fbBuilder.sizedByteArray(); | |||
| try { | |||
| byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg); | |||
| ByteBuffer buffer = ByteBuffer.wrap(responseData); | |||
| ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer); | |||
| FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys); | |||
| return status; | |||
| } catch (Exception e) { | |||
| e.printStackTrace(); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) { | |||
| int retcode = bufData.retcode(); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************")); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); | |||
| switch (retcode) { | |||
| case (ResponseCode.SUCCEED): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success")); | |||
| return FLClientStatus.SUCCESS; | |||
| case (ResponseCode.OutOfTime): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request startFLJob again")); | |||
| setNextRequestTime(bufData.nextReqTime()); | |||
| return FLClientStatus.RESTART; | |||
| case (ResponseCode.RequestError): | |||
| case (ResponseCode.SystemError): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys")); | |||
| return FLClientStatus.FAILED; | |||
| default: | |||
| LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseExchangeKeys is invalid: " + retcode)); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus getExchangeKeys() { | |||
| String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==============getExchangeKeys url: " + url + "==============")); | |||
| FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); | |||
| int id = fbBuilder.createString(localFLParameter.getFlID()); | |||
| String dateTime = LocalDateTime.now().toString(); | |||
| int time = fbBuilder.createString(dateTime); | |||
| int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time); | |||
| fbBuilder.finish(getExchangeKeysRoot); | |||
| byte[] msg = fbBuilder.sizedByteArray(); | |||
| try { | |||
| byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg); | |||
| ByteBuffer buffer = ByteBuffer.wrap(responseData); | |||
| ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer); | |||
| FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys); | |||
| return status; | |||
| } catch (Exception e) { | |||
| e.printStackTrace(); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) { | |||
| int retcode = bufData.retcode(); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************")); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); | |||
| switch (retcode) { | |||
| case (ResponseCode.SUCCEED): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys success")); | |||
| clientPublicKeyList.clear(); | |||
| u1ClientList.clear(); | |||
| int length = bufData.remotePublickeysLength(); | |||
| for (int i = 0; i < length; i++) { | |||
| ClientPublicKey publicKey = new ClientPublicKey(); | |||
| String srcFlId = bufData.remotePublickeys(i).flId(); | |||
| publicKey.setFlID(srcFlId); | |||
| ByteBuffer bufCpk = bufData.remotePublickeys(i).cPkAsByteBuffer(); | |||
| int sizeCpk = bufData.remotePublickeys(i).cPkLength(); | |||
| ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer(); | |||
| int sizeSpk = bufData.remotePublickeys(i).sPkLength(); | |||
| publicKey.setCPK(byteToArray(bufCpk, sizeCpk)); | |||
| publicKey.setSPK(byteToArray(bufSpk, sizeSpk)); | |||
| clientPublicKeyList.put(srcFlId, publicKey); | |||
| u1ClientList.add(srcFlId); | |||
| } | |||
| return FLClientStatus.SUCCESS; | |||
| case (ResponseCode.SucNotReady): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetExchangeKeys again!")); | |||
| return FLClientStatus.WAIT; | |||
| case (ResponseCode.OutOfTime): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request startFLJob again")); | |||
| setNextRequestTime(bufData.nextReqTime()); | |||
| return FLClientStatus.RESTART; | |||
| case (ResponseCode.RequestError): | |||
| case (ResponseCode.SystemError): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys")); | |||
| return FLClientStatus.FAILED; | |||
| default: | |||
| LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnExchangeKeys is invalid: " + retcode)); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus requestShareSecrets() throws Exception { | |||
| String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==============requestShareSecrets url: " + url + "==============")); | |||
| genIndividualSecret(); | |||
| genEncryptExchangedKeys(); | |||
| encryptShares(); | |||
| FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); | |||
| int id = fbBuilder.createString(localFLParameter.getFlID()); | |||
| String dateTime = LocalDateTime.now().toString(); | |||
| int time = fbBuilder.createString(dateTime); | |||
| int clientShareSize = clientShareList.size(); | |||
| if (clientShareSize <= 0) { | |||
| LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!")); | |||
| Common.sleep(SLEEP_TIME); | |||
| FLClientStatus status = requestShareSecrets(); | |||
| return status; | |||
| } else { | |||
| int[] add = new int[clientShareSize]; | |||
| for (int i = 0; i < clientShareSize; i++) { | |||
| int flID = fbBuilder.createString(clientShareList.get(i).getFlID()); | |||
| int shareSecretFbs = ClientShare.createShareVector(fbBuilder, clientShareList.get(i).getShare().getArray()); | |||
| ClientShare.startClientShare(fbBuilder); | |||
| ClientShare.addFlId(fbBuilder, flID); | |||
| ClientShare.addShare(fbBuilder, shareSecretFbs); | |||
| int clientShareRoot = ClientShare.endClientShare(fbBuilder); | |||
| add[i] = clientShareRoot; | |||
| } | |||
| int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add); | |||
| int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id, encryptedSharesFbs, iteration, time); | |||
| fbBuilder.finish(requestShareSecretsRoot); | |||
| byte[] msg = fbBuilder.sizedByteArray(); | |||
| try { | |||
| byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg); | |||
| ByteBuffer buffer = ByteBuffer.wrap(responseData); | |||
| ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer); | |||
| FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets); | |||
| return status; | |||
| } catch (Exception e) { | |||
| e.printStackTrace(); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| } | |||
| public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) { | |||
| int retcode = bufData.retcode(); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************")); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); | |||
| switch (retcode) { | |||
| case (ResponseCode.SUCCEED): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success")); | |||
| return FLClientStatus.SUCCESS; | |||
| case (ResponseCode.OutOfTime): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request startFLJob again")); | |||
| setNextRequestTime(bufData.nextReqTime()); | |||
| return FLClientStatus.RESTART; | |||
| case (ResponseCode.RequestError): | |||
| case (ResponseCode.SystemError): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets")); | |||
| return FLClientStatus.FAILED; | |||
| default: | |||
| LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ResponseShareSecrets is invalid: " + retcode)); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus getShareSecrets() { | |||
| String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==============getShareSecrets url: " + url + "==============")); | |||
| FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); | |||
| int id = fbBuilder.createString(localFLParameter.getFlID()); | |||
| String dateTime = LocalDateTime.now().toString(); | |||
| int time = fbBuilder.createString(dateTime); | |||
| int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time); | |||
| fbBuilder.finish(getShareSecrets); | |||
| byte[] msg = fbBuilder.sizedByteArray(); | |||
| try { | |||
| byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg); | |||
| ByteBuffer buffer = ByteBuffer.wrap(responseData); | |||
| ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer); | |||
| FLClientStatus status = judgeGetShareSecrets(returnShareSecrets); | |||
| return status; | |||
| } catch (Exception e) { | |||
| e.printStackTrace(); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) { | |||
| int retcode = bufData.retcode(); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************")); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] the size of encrypted shares: " + bufData.encryptedSharesLength())); | |||
| switch (retcode) { | |||
| case (ResponseCode.SUCCEED): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets success")); | |||
| returnShareList.clear(); | |||
| u2UClientList.clear(); | |||
| int length = bufData.encryptedSharesLength(); | |||
| for (int i = 0; i < length; i++) { | |||
| EncryptShare shareSecret = new EncryptShare(); | |||
| shareSecret.setFlID(bufData.encryptedShares(i).flId()); | |||
| ByteBuffer bufShare = bufData.encryptedShares(i).shareAsByteBuffer(); | |||
| int sizeShare = bufData.encryptedShares(i).shareLength(); | |||
| shareSecret.setShare(byteToArray(bufShare, sizeShare)); | |||
| returnShareList.add(shareSecret); | |||
| u2UClientList.add(bufData.encryptedShares(i).flId()); | |||
| } | |||
| return FLClientStatus.SUCCESS; | |||
| case (ResponseCode.SucNotReady): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetShareSecrets again!")); | |||
| return FLClientStatus.WAIT; | |||
| case (ResponseCode.OutOfTime): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request startFLJob again")); | |||
| setNextRequestTime(bufData.nextReqTime()); | |||
| return FLClientStatus.RESTART; | |||
| case (ResponseCode.RequestError): | |||
| case (ResponseCode.SystemError): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets")); | |||
| return FLClientStatus.FAILED; | |||
| default: | |||
| LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnShareSecrets is invalid: " + retcode)); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus exchangeKeys() { | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys ======================")); | |||
| FLClientStatus curStatus; | |||
| // RequestExchangeKeys | |||
| curStatus = requestExchangeKeys(); | |||
| while (curStatus == FLClientStatus.WAIT) { | |||
| Common.sleep(SLEEP_TIME); | |||
| curStatus = requestExchangeKeys(); | |||
| } | |||
| if (curStatus != FLClientStatus.SUCCESS) { | |||
| return curStatus; | |||
| } | |||
| // GetExchangeKeys | |||
| curStatus = getExchangeKeys(); | |||
| while (curStatus == FLClientStatus.WAIT) { | |||
| Common.sleep(SLEEP_TIME); | |||
| curStatus = getExchangeKeys(); | |||
| } | |||
| return curStatus; | |||
| } | |||
| public FLClientStatus shareSecrets() throws Exception { | |||
| LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets ======================"))); | |||
| FLClientStatus curStatus; | |||
| // RequestShareSecrets | |||
| curStatus = requestShareSecrets(); | |||
| while (curStatus == FLClientStatus.WAIT) { | |||
| Common.sleep(SLEEP_TIME); | |||
| curStatus = requestShareSecrets(); | |||
| } | |||
| if (curStatus != FLClientStatus.SUCCESS) { | |||
| return curStatus; | |||
| } | |||
| // GetShareSecrets | |||
| curStatus = getShareSecrets(); | |||
| while (curStatus == FLClientStatus.WAIT) { | |||
| Common.sleep(SLEEP_TIME); | |||
| curStatus = getShareSecrets(); | |||
| } | |||
| return curStatus; | |||
| } | |||
| public FLClientStatus reconstructSecrets() { | |||
| LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret ========================")); | |||
| FLClientStatus curStatus; | |||
| // GetClientList | |||
| curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys); | |||
| while (curStatus == FLClientStatus.WAIT) { | |||
| Common.sleep(SLEEP_TIME); | |||
| curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys); | |||
| } | |||
| if (curStatus == FLClientStatus.RESTART) { | |||
| nextRequestTime = clientListReq.getNextRequestTime(); | |||
| } | |||
| if (curStatus != FLClientStatus.SUCCESS) { | |||
| return curStatus; | |||
| } | |||
| // SendReconstructSecret | |||
| curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration); | |||
| while (curStatus == FLClientStatus.WAIT) { | |||
| Common.sleep(SLEEP_TIME); | |||
| curStatus = reconstructSecretReq.sendReconstructSecret(decryptShareSecretsList, u3ClientList, iteration); | |||
| } | |||
| if (curStatus == FLClientStatus.RESTART) { | |||
| nextRequestTime = reconstructSecretReq.getNextRequestTime(); | |||
| } | |||
| return curStatus; | |||
| } | |||
| } | |||
| @@ -0,0 +1,125 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient; | |||
| import java.io.File; | |||
| import java.util.ArrayList; | |||
| import java.util.Arrays; | |||
| import java.util.Date; | |||
| import java.util.List; | |||
| import java.util.Random; | |||
| import java.util.logging.Logger; | |||
| public class Common { | |||
| public static final String LOG_TITLE = "<FLClient> "; | |||
| private static final Logger LOGGER = Logger.getLogger(Common.class.toString()); | |||
| private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "adbert")); | |||
| public static String generateUrl(boolean useElb, String ip, int port, int serverNum) { | |||
| String url; | |||
| if (useElb) { | |||
| Random rand = new Random(); | |||
| int randomNum = rand.nextInt(100000) % serverNum + port; | |||
| url = ip + String.valueOf(randomNum); | |||
| } else { | |||
| url = ip + String.valueOf(port); | |||
| } | |||
| return url; | |||
| } | |||
| public static void setClassifierWeightName(List<String> classifierWeightName) { | |||
| classifierWeightName.add("albert.pooler.weight"); | |||
| classifierWeightName.add("albert.pooler.bias"); | |||
| classifierWeightName.add("classifier.weight"); | |||
| classifierWeightName.add("classifier.bias"); | |||
| LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size())); | |||
| } | |||
| public static void setAlbertWeightName(List<String> albertWeightName) { | |||
| albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight"); | |||
| albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.weight"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.query.bias"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.weight"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.key.bias"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.weight"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention.value.bias"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.weight"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.dense.bias"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.gamma"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.attention.output.layernorm.beta"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.weight"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn.bias"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.gamma"); | |||
| albertWeightName.add("albert.encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.beta"); | |||
| LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size())); | |||
| } | |||
| public static boolean checkFLName(String flName) { | |||
| return (flNameTrustList.contains(flName)); | |||
| } | |||
| public static void sleep(long millis) { | |||
| try { | |||
| Thread.sleep(millis); //1000 milliseconds is one second. | |||
| } catch (InterruptedException ex) { | |||
| LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage())); | |||
| Thread.currentThread().interrupt(); | |||
| } | |||
| } | |||
| public static long getWaitTime(String nextRequestTime) { | |||
| Date date = new Date(); | |||
| long currentTime = date.getTime(); | |||
| long waitTime = 0; | |||
| if (!nextRequestTime.equals("")) { | |||
| waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime); | |||
| } | |||
| LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime)); | |||
| LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime)); | |||
| return waitTime; | |||
| } | |||
| public static long startTime(String tag) { | |||
| Date startDate = new Date(); | |||
| long startTime = startDate.getTime(); | |||
| LOGGER.info(addTag("[start time] <" + tag + "> start time: " + startTime)); | |||
| return startTime; | |||
| } | |||
| public static void endTime(long start, String tag) { | |||
| Date endDate = new Date(); | |||
| long endTime = endDate.getTime(); | |||
| LOGGER.info(addTag("[end time] <" + tag + "> end time: " + endTime)); | |||
| LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start))); | |||
| } | |||
| public static String addTag(String message) { | |||
| return LOG_TITLE + message; | |||
| } | |||
| public static boolean isAutoscaling(byte[] message, String autoscalingTag) { | |||
| return (new String(message)).contains(autoscalingTag); | |||
| } | |||
| public static boolean checkPath(String path) { | |||
| File file = new File(path); | |||
| return file.exists(); | |||
| } | |||
| } | |||
| @@ -1,12 +1,12 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -1,12 +1,12 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -0,0 +1,163 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient; | |||
| import okhttp3.Call; | |||
| import okhttp3.Callback; | |||
| import okhttp3.MediaType; | |||
| import okhttp3.OkHttpClient; | |||
| import okhttp3.Request; | |||
| import okhttp3.RequestBody; | |||
| import okhttp3.Response; | |||
| import javax.net.ssl.HostnameVerifier; | |||
| import javax.net.ssl.SSLContext; | |||
| import javax.net.ssl.SSLSession; | |||
| import javax.net.ssl.TrustManager; | |||
| import javax.net.ssl.X509TrustManager; | |||
| import java.io.IOException; | |||
| import java.security.KeyManagementException; | |||
| import java.security.NoSuchAlgorithmException; | |||
| import java.security.cert.CertificateException; | |||
| import java.security.cert.X509Certificate; | |||
| import java.util.concurrent.TimeUnit; | |||
| import java.util.concurrent.TimeoutException; | |||
| import java.util.logging.Logger; | |||
| import static com.mindspore.flclient.FLParameter.TIME_OUT; | |||
| public class FLCommunication implements IFLCommunication { | |||
| private static int timeOut; | |||
| private static boolean ssl = false; | |||
| private FLParameter flParameter = FLParameter.getInstance(); | |||
| private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8"); | |||
| private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString()); | |||
| private OkHttpClient client; | |||
| private static FLCommunication communication; | |||
| private FLCommunication() { | |||
| if (flParameter.getTimeOut() != 0) { | |||
| timeOut = flParameter.getTimeOut(); | |||
| } else { | |||
| timeOut = TIME_OUT; | |||
| } | |||
| ssl = flParameter.isUseSSL(); | |||
| client = getUnsafeOkHttpClient(); | |||
| } | |||
| private static OkHttpClient getUnsafeOkHttpClient() { | |||
| X509TrustManager trustManager = new X509TrustManager() { | |||
| @Override | |||
| public X509Certificate[] getAcceptedIssuers() { | |||
| return new X509Certificate[]{}; | |||
| } | |||
| @Override | |||
| public void checkServerTrusted(X509Certificate[] arg0, String arg1) throws CertificateException { | |||
| } | |||
| @Override | |||
| public void checkClientTrusted(X509Certificate[] arg0, String arg1) throws CertificateException { | |||
| } | |||
| }; | |||
| final TrustManager[] trustAllCerts = new TrustManager[]{trustManager}; | |||
| try { | |||
| LOGGER.info(Common.addTag("the set timeOut in OkHttpClient: " + timeOut)); | |||
| OkHttpClient.Builder builder = new OkHttpClient.Builder(); | |||
| builder.connectTimeout(timeOut, TimeUnit.SECONDS); | |||
| builder.writeTimeout(timeOut, TimeUnit.SECONDS); | |||
| builder.readTimeout(3 * timeOut, TimeUnit.SECONDS); | |||
| if (ssl) { | |||
| builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), SSLSocketFactoryTools.getInstance().getmTrustManager()); | |||
| builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier()); | |||
| } else { | |||
| final SSLContext sslContext = SSLContext.getInstance("TLS"); | |||
| sslContext.init(null, trustAllCerts, new java.security.SecureRandom()); | |||
| final javax.net.ssl.SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); | |||
| builder.sslSocketFactory(sslSocketFactory, trustManager); | |||
| builder.hostnameVerifier(new HostnameVerifier() { | |||
| @Override | |||
| public boolean verify(String arg0, SSLSession arg1) { | |||
| return true; | |||
| } | |||
| }); | |||
| } | |||
| return builder.build(); | |||
| } catch (NoSuchAlgorithmException | KeyManagementException e) { | |||
| LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + e.getMessage())); | |||
| throw new RuntimeException(e); | |||
| } | |||
| } | |||
| public static FLCommunication getInstance() { | |||
| if (communication == null) { | |||
| synchronized (FLCommunication.class) { | |||
| if (communication == null) { | |||
| communication = new FLCommunication(); | |||
| } | |||
| } | |||
| } | |||
| return communication; | |||
| } | |||
| @Override | |||
| public void setTimeOut(int timeout) throws TimeoutException { | |||
| } | |||
| @Override | |||
| public byte[] syncRequest(String url, byte[] msg) throws IOException { | |||
| Request request = new Request.Builder() | |||
| .url(url) | |||
| .post(RequestBody.create(MEDIA_TYPE_JSON, msg)).build(); | |||
| Response response = this.client.newCall(request).execute(); | |||
| if (!response.isSuccessful()) { | |||
| throw new IOException("Unexpected code " + response); | |||
| } | |||
| return response.body().bytes(); | |||
| } | |||
| @Override | |||
| public void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception { | |||
| Request request = new Request.Builder() | |||
| .url(url) | |||
| .header("Accept", "application/proto") | |||
| .header("Content-Type", "application/proto; charset=utf-8") | |||
| .post(RequestBody.create(MEDIA_TYPE_JSON, msg)).build(); | |||
| client.newCall(request).enqueue(new Callback() { | |||
| IAsyncCallBack asyncCallBack = callBack; | |||
| @Override | |||
| public void onResponse(Call call, Response response) throws IOException { | |||
| asyncCallBack.onResponse(response.body().bytes()); | |||
| call.cancel(); | |||
| } | |||
| @Override | |||
| public void onFailure(Call call, IOException e) { | |||
| asyncCallBack.onFailure(e); | |||
| call.cancel(); | |||
| } | |||
| }); | |||
| } | |||
| } | |||
| @@ -1,12 +1,12 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -17,13 +17,15 @@ package com.mindspore.flclient; | |||
| import java.util.logging.Logger; | |||
| public class FLJobResultCallback implements IFLJobResultCallback{ | |||
| private static final Logger logger = Logger.getLogger(FLJobResultCallback.class.toString()); | |||
| public class FLJobResultCallback implements IFLJobResultCallback { | |||
| private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString()); | |||
| public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) { | |||
| logger.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode)); | |||
| LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode)); | |||
| } | |||
| public void onFlJobFinished(String modelName, int iterationCount, int resultCode) { | |||
| logger.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode)); | |||
| LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode)); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,12 +1,12 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -1,12 +1,12 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -15,10 +15,18 @@ | |||
| */ | |||
| package com.mindspore.flclient; | |||
| import javax.net.ssl.*; | |||
| import javax.net.ssl.HostnameVerifier; | |||
| import javax.net.ssl.SSLContext; | |||
| import javax.net.ssl.SSLSession; | |||
| import javax.net.ssl.SSLSocketFactory; | |||
| import javax.net.ssl.TrustManager; | |||
| import javax.net.ssl.X509TrustManager; | |||
| import java.io.FileInputStream; | |||
| import java.io.InputStream; | |||
| import java.security.*; | |||
| import java.security.InvalidKeyException; | |||
| import java.security.NoSuchAlgorithmException; | |||
| import java.security.NoSuchProviderException; | |||
| import java.security.SignatureException; | |||
| import java.security.cert.CertificateException; | |||
| import java.security.cert.CertificateFactory; | |||
| import java.security.cert.X509Certificate; | |||
| @@ -32,11 +40,12 @@ public class SSLSocketFactoryTools { | |||
| private SSLContext sslContext; | |||
| private MyTrustManager myTrustManager; | |||
| private static SSLSocketFactoryTools instance; | |||
| private SSLSocketFactoryTools() { | |||
| initSslSocketFactory(); | |||
| } | |||
| private void initSslSocketFactory(){ | |||
| private void initSslSocketFactory() { | |||
| try { | |||
| sslContext = SSLContext.getInstance("TLS"); | |||
| x509Certificate = readCert(flParameter.getCertPath()); | |||
| @@ -51,16 +60,14 @@ public class SSLSocketFactoryTools { | |||
| } | |||
| } | |||
| public static SSLSocketFactoryTools getInstance() { | |||
| if (instance == null) { | |||
| instance=new SSLSocketFactoryTools(); | |||
| instance = new SSLSocketFactoryTools(); | |||
| } | |||
| return instance; | |||
| } | |||
| public X509Certificate readCert(String assetName) { | |||
| public X509Certificate readCert(String assetName) { | |||
| InputStream inputStream = null; | |||
| try { | |||
| inputStream = new FileInputStream(assetName); | |||
| @@ -110,7 +117,6 @@ public class SSLSocketFactoryTools { | |||
| public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { | |||
| } | |||
| @Override | |||
| public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { | |||
| for (X509Certificate cert : chain) { | |||
| @@ -130,6 +136,7 @@ public class SSLSocketFactoryTools { | |||
| } catch (SignatureException e) { | |||
| logger.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage())); | |||
| } | |||
| logger.info(Common.addTag("checkServerTrusted success!")); | |||
| } | |||
| } | |||
| @@ -0,0 +1,323 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient; | |||
| import com.google.flatbuffers.FlatBufferBuilder; | |||
| import com.mindspore.flclient.model.AdTrainBert; | |||
| import com.mindspore.flclient.model.SessionUtil; | |||
| import com.mindspore.flclient.model.TrainLenet; | |||
| import mindspore.schema.FeatureMap; | |||
| import java.security.SecureRandom; | |||
| import java.util.ArrayList; | |||
| import java.util.HashMap; | |||
| import java.util.Map; | |||
| import java.util.Random; | |||
| import java.util.logging.Logger; | |||
| public class SecureProtocol { | |||
| private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString()); | |||
| private FLParameter flParameter = FLParameter.getInstance(); | |||
| private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); | |||
| private int iteration; | |||
| private CipherClient cipher; | |||
| private FLClientStatus status; | |||
| private float[] featureMask; | |||
| private double dpEps; | |||
| private double dpDelta; | |||
| private double dpNormClip; | |||
| private static double deltaError = 1e-6; | |||
| private static Map<String, float[]> modelMap; | |||
| private ArrayList<String> encryptFeatureName = new ArrayList<String>(); | |||
| public FLClientStatus getStatus() { | |||
| return status; | |||
| } | |||
| public float[] getFeatureMask() { | |||
| return featureMask; | |||
| } | |||
| public SecureProtocol() { | |||
| } | |||
| public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) { | |||
| this.iteration = iter; | |||
| this.cipher = new CipherClient(iteration, minSecretNum, prime, featureSize); | |||
| } | |||
| public FLClientStatus setDPParameter(int iter, double diffEps, | |||
| double diffDelta, double diffNorm, Map<String, float[]> map) { | |||
| try { | |||
| this.iteration = iter; | |||
| this.dpEps = diffEps; | |||
| this.dpDelta = diffDelta; | |||
| this.dpNormClip = diffNorm; | |||
| this.modelMap = map; | |||
| status = FLClientStatus.SUCCESS; | |||
| } catch (Exception e) { | |||
| LOGGER.severe(Common.addTag("[DPEncrypt] catch Exception in setDPParameter: " + e.getMessage())); | |||
| status = FLClientStatus.FAILED; | |||
| } | |||
| return status; | |||
| } | |||
| public ArrayList<String> getEncryptFeatureName() { | |||
| return encryptFeatureName; | |||
| } | |||
| public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) { | |||
| this.encryptFeatureName = encryptFeatureName; | |||
| } | |||
| public String getNextRequestTime() { | |||
| return cipher.getNextRequestTime(); | |||
| } | |||
| public FLClientStatus pwCreateMask() { | |||
| LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="); | |||
| // round 0 | |||
| status = cipher.exchangeKeys(); | |||
| LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============"); | |||
| if (status != FLClientStatus.SUCCESS) { | |||
| return status; | |||
| } | |||
| // round 1 | |||
| try { | |||
| status = cipher.shareSecrets(); | |||
| LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "============="); | |||
| } catch (Exception e) { | |||
| LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask"); | |||
| status = FLClientStatus.FAILED; | |||
| } | |||
| if (status != FLClientStatus.SUCCESS) { | |||
| return status; | |||
| } | |||
| // round2 | |||
| try { | |||
| featureMask = cipher.doubleMaskingWeight(); | |||
| LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS============="); | |||
| } catch (Exception e) { | |||
| LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask"); | |||
| status = FLClientStatus.FAILED; | |||
| } | |||
| return status; | |||
| } | |||
| public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) { | |||
| LOGGER.info("[Encrypt] feature mask size: " + featureMask.length); | |||
| // get feature map | |||
| Map<String, float[]> map = new HashMap<String, float[]>(); | |||
| if (flParameter.getFlName().equals("adbert")) { | |||
| AdTrainBert adTrainBert = AdTrainBert.getInstance(); | |||
| map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); | |||
| } else if (flParameter.getFlName().equals("lenet")) { | |||
| TrainLenet trainLenet = TrainLenet.getInstance(); | |||
| map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); | |||
| } | |||
| int featureSize = encryptFeatureName.size(); | |||
| int[] featuresMap = new int[featureSize]; | |||
| int maskIndex = 0; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = encryptFeatureName.get(i); | |||
| float[] data = map.get(key); | |||
| LOGGER.info("[Encrypt] feature name: " + key + " feature size: " + data.length); | |||
| for (int j = 0; j < data.length; j++) { | |||
| float rawData = data[j]; | |||
| float maskData = rawData * trainDataSize + featureMask[maskIndex]; | |||
| maskIndex += 1; | |||
| data[j] = maskData; | |||
| } | |||
| int featureName = builder.createString(key); | |||
| int weight = FeatureMap.createDataVector(builder, data); | |||
| int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); | |||
| featuresMap[i] = featureMap; | |||
| } | |||
| return featuresMap; | |||
| } | |||
| public FLClientStatus pwUnmasking() { | |||
| status = cipher.reconstructSecrets(); // round3 | |||
| LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "============="); | |||
| return status; | |||
| } | |||
| private static float calculateErf(double x) { | |||
| double result = 0; | |||
| int segmentNum = 10000; | |||
| double deltaX = x / segmentNum; | |||
| result += 1; | |||
| for (int i = 1; i < segmentNum; i++) { | |||
| result += 2 * Math.exp(-Math.pow(deltaX * i, 2)); | |||
| } | |||
| result += Math.exp(-Math.pow(deltaX * segmentNum, 2)); | |||
| return (float) (result * deltaX / Math.pow(Math.PI, 0.5)); | |||
| } | |||
| private static double calculatePhi(double t) { | |||
| return 0.5 * (1.0 + calculateErf((t / Math.sqrt(2.0)))); | |||
| } | |||
| private static double calculateBPositive(double eps, double s) { | |||
| return calculatePhi(Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0))); | |||
| } | |||
| private static double calculateBNegative(double eps, double s) { | |||
| return calculatePhi(-Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0))); | |||
| } | |||
| private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) { | |||
| double deltaSup = calculateBPositive(eps, sSup); | |||
| while (deltaSup <= targetDelta) { | |||
| sInf = sSup; | |||
| sSup = 2 * sInf; | |||
| deltaSup = calculateBPositive(eps, sSup); | |||
| } | |||
| double sMid = sInf + (sSup - sInf) / 2.0; | |||
| int iterMax = 1000; | |||
| int iters = 0; | |||
| while (true) { | |||
| double b = calculateBPositive(eps, sMid); | |||
| if (b <= targetDelta) { | |||
| if (targetDelta - b <= deltaError) { | |||
| break; | |||
| } else { | |||
| sInf = sMid; | |||
| } | |||
| } else { | |||
| sSup = sMid; | |||
| } | |||
| sMid = sInf + (sSup - sInf) / 2.0; | |||
| iters += 1; | |||
| if (iters > iterMax) { | |||
| break; | |||
| } | |||
| } | |||
| return sMid; | |||
| } | |||
| private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) { | |||
| double deltaSup = calculateBNegative(eps, sSup); | |||
| while (deltaSup > targetDelta) { | |||
| sInf = sSup; | |||
| sSup = 2 * sInf; | |||
| deltaSup = calculateBNegative(eps, sSup); | |||
| } | |||
| double sMid = sInf + (sSup - sInf) / 2.0; | |||
| int iterMax = 1000; | |||
| int iters = 0; | |||
| while (true) { | |||
| double b = calculateBNegative(eps, sMid); | |||
| if (b <= targetDelta) { | |||
| if (targetDelta - b <= deltaError) { | |||
| break; | |||
| } else { | |||
| sSup = sMid; | |||
| } | |||
| } else { | |||
| sInf = sMid; | |||
| } | |||
| sMid = sInf + (sSup - sInf) / 2.0; | |||
| iters += 1; | |||
| if (iters > iterMax) { | |||
| break; | |||
| } | |||
| } | |||
| return sMid; | |||
| } | |||
| private static double calculateSigma(double clipNorm, double eps, double targetDelta) { | |||
| double deltaZero = calculateBPositive(eps, 0); | |||
| double alpha = 1; | |||
| if (targetDelta > deltaZero) { | |||
| double s = calculateSPositive(eps, targetDelta, 0, 1); | |||
| alpha = Math.sqrt(1.0 + s / 2.0) - Math.sqrt(s / 2.0); | |||
| } else if (targetDelta < deltaZero) { | |||
| double s = calculateSNegative(eps, targetDelta, 0, 1); | |||
| alpha = Math.sqrt(1.0 + s / 2.0) + Math.sqrt(s / 2.0); | |||
| } | |||
| return alpha * clipNorm / Math.sqrt(2.0 * eps); | |||
| } | |||
| public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) { | |||
| // get feature map | |||
| Map<String, float[]> map = new HashMap<String, float[]>(); | |||
| if (flParameter.getFlName().equals("adbert")) { | |||
| AdTrainBert adTrainBert = AdTrainBert.getInstance(); | |||
| map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(adTrainBert.getTrainSession())); | |||
| } else if (flParameter.getFlName().equals("lenet")) { | |||
| TrainLenet trainLenet = TrainLenet.getInstance(); | |||
| map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); | |||
| } | |||
| Map<String, float[]> mapBeforeTrain = modelMap; | |||
| int featureSize = encryptFeatureName.size(); | |||
| // calculate sigma | |||
| double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta); | |||
| LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "=============")); | |||
| // prepare gaussian noise | |||
| SecureRandom random = new SecureRandom(); | |||
| int randomInt = random.nextInt(); | |||
| Random r = new Random(randomInt); | |||
| // calculate l2-norm of all layers' update array | |||
| double updateL2Norm = 0; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = encryptFeatureName.get(i); | |||
| float[] data = map.get(key); | |||
| float[] dataBeforeTrain = mapBeforeTrain.get(key); | |||
| for (int j = 0; j < data.length; j++) { | |||
| float rawData = data[j]; | |||
| float rawDataBeforeTrain = dataBeforeTrain[j]; | |||
| float updateData = rawData - rawDataBeforeTrain; | |||
| updateL2Norm += updateData * updateData; | |||
| } | |||
| } | |||
| updateL2Norm = Math.sqrt(updateL2Norm); | |||
| double clipFactor = Math.min(1.0, dpNormClip / updateL2Norm); | |||
| // clip and add noise | |||
| int[] featuresMap = new int[featureSize]; | |||
| for (int i = 0; i < featureSize; i++) { | |||
| String key = encryptFeatureName.get(i); | |||
| float[] data = map.get(key); | |||
| float[] data2 = new float[data.length]; | |||
| float[] dataBeforeTrain = mapBeforeTrain.get(key); | |||
| for (int j = 0; j < data.length; j++) { | |||
| float rawData = data[j]; | |||
| float rawDataBeforeTrain = dataBeforeTrain[j]; | |||
| float updateData = rawData - rawDataBeforeTrain; | |||
| // clip | |||
| updateData *= clipFactor; | |||
| // add noise | |||
| double gaussianNoise = r.nextGaussian() * gaussianSigma; | |||
| updateData += gaussianNoise; | |||
| data2[j] = rawDataBeforeTrain + updateData; | |||
| data2[j] = data2[j] * trainDataSize; | |||
| } | |||
| int featureName = builder.createString(key); | |||
| int weight = FeatureMap.createDataVector(builder, data2); | |||
| int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight); | |||
| featuresMap[i] = featureMap; | |||
| } | |||
| return featuresMap; | |||
| } | |||
| } | |||
| @@ -1,12 +1,12 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| @@ -0,0 +1,104 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher; | |||
| import com.mindspore.flclient.Common; | |||
| import javax.crypto.Cipher; | |||
| import javax.crypto.spec.IvParameterSpec; | |||
| import javax.crypto.spec.SecretKeySpec; | |||
| import java.io.UnsupportedEncodingException; | |||
| import java.util.logging.Logger; | |||
| public class AESEncrypt { | |||
| private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString()); | |||
| /** | |||
| * 128, 192 or 256 | |||
| */ | |||
| private static final int KEY_SIZE = 256; | |||
| private static final int I_VEC_LEN = 16; | |||
| /** | |||
| * encrypt/decrypt algorithm name | |||
| */ | |||
| private static final String ALGORITHM = "AES"; | |||
| /** | |||
| * algorithm/Mode/padding mode | |||
| */ | |||
| private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding"; | |||
| private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING"; | |||
| private String CIPHER_MODE; | |||
| private static final int RANDOM_LEN = KEY_SIZE / 8; | |||
| private String iVecS = "1111111111111111"; | |||
| private byte[] iVec = iVecS.getBytes("utf-8"); | |||
| public AESEncrypt(byte[] key, byte[] iVecIn, String mode) throws UnsupportedEncodingException { | |||
| if (key == null) { | |||
| LOGGER.severe(Common.addTag("Key is null")); | |||
| return; | |||
| } | |||
| if (key.length != KEY_SIZE / 8) { | |||
| LOGGER.severe(Common.addTag("the length of key is not correct")); | |||
| return; | |||
| } | |||
| if (mode.contains("CBC")) { | |||
| CIPHER_MODE = CIPHER_MODE_CBC; | |||
| } else if (mode.contains("CTR")) { | |||
| CIPHER_MODE = CIPHER_MODE_CTR; | |||
| } else { | |||
| return; | |||
| } | |||
| if (iVecIn == null || iVecIn.length != I_VEC_LEN) { | |||
| return; | |||
| } | |||
| iVec = iVecIn; | |||
| } | |||
| public byte[] encrypt(byte[] key, byte[] data) throws Exception { | |||
| SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM); | |||
| Cipher cipher = Cipher.getInstance(CIPHER_MODE); | |||
| IvParameterSpec iv = new IvParameterSpec(iVec); | |||
| cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv); | |||
| byte[] encrypted = cipher.doFinal(data); | |||
| String encryptResultStr = BaseUtil.byte2HexString(encrypted); | |||
| return encrypted; | |||
| } | |||
| public byte[] encryptCTR(byte[] key, byte[] data) throws Exception { | |||
| SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM); | |||
| Cipher cipher = Cipher.getInstance(CIPHER_MODE); | |||
| IvParameterSpec iv = new IvParameterSpec(iVec); | |||
| cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv); | |||
| byte[] encrypted = cipher.doFinal(data); | |||
| return encrypted; | |||
| } | |||
| public byte[] decrypt(byte[] key, byte[] encryptData) throws Exception { | |||
| SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM); | |||
| Cipher cipher = Cipher.getInstance(CIPHER_MODE); | |||
| IvParameterSpec iv = new IvParameterSpec(iVec); | |||
| cipher.init(Cipher.DECRYPT_MODE, skeySpec, iv); | |||
| byte[] origin = cipher.doFinal(encryptData); | |||
| return origin; | |||
| } | |||
| } | |||
| @@ -0,0 +1,143 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * <p> | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * <p> | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * <p> | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher; | |||
| import java.io.UnsupportedEncodingException; | |||
| import java.math.BigInteger; | |||
| import java.nio.charset.Charset; | |||
| import java.util.ArrayList; | |||
| import java.util.List; | |||
| public class BaseUtil { | |||
| private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'}; | |||
| public BaseUtil() { | |||
| } | |||
| public static String byte2HexString(byte[] bytes) { | |||
| if (null == bytes) { | |||
| return null; | |||
| } else if (bytes.length == 0) { | |||
| return ""; | |||
| } else { | |||
| char[] chars = new char[bytes.length * 2]; | |||
| for (int i = 0; i < bytes.length; ++i) { | |||
| int b = bytes[i]; | |||
| chars[i * 2] = HEX_DIGITS[(b & 240) >> 4]; | |||
| chars[i * 2 + 1] = HEX_DIGITS[b & 15]; | |||
| } | |||
| return new String(chars); | |||
| } | |||
| } | |||
| public static byte[] hexString2ByteArray(String str) { | |||
| int length = str.length() / 2; | |||
| byte[] bytes = new byte[length]; | |||
| byte[] source = str.getBytes(Charset.forName("UTF-8")); | |||
| for (int i = 0; i < bytes.length; ++i) { | |||
| byte bh = Byte.decode("0x" + new String(new byte[]{source[i * 2]}, Charset.forName("UTF-8"))); | |||
| bh = (byte) (bh << 4); | |||
| byte bl = Byte.decode("0x" + new String(new byte[]{source[i * 2 + 1]}, Charset.forName("UTF-8"))); | |||
| bytes[i] = (byte) (bh ^ bl); | |||
| } | |||
| return bytes; | |||
| } | |||
| public static BigInteger byteArray2BigInteger(byte[] bytes) { | |||
| BigInteger bigInteger = BigInteger.ZERO; | |||
| for (int i = 0; i < bytes.length; ++i) { | |||
| int intI = bytes[i]; | |||
| if (intI < 0) { | |||
| intI = intI + 256; | |||
| } | |||
| BigInteger bi = new BigInteger(String.valueOf(intI)); | |||
| bigInteger = bigInteger.multiply(BigInteger.valueOf(256)).add(bi); | |||
| } | |||
| return bigInteger; | |||
| } | |||
| public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException { | |||
| StringBuilder res = new StringBuilder(); | |||
| byte[] bytes = String.valueOf(str).getBytes("UTF-8"); | |||
| BigInteger bigInteger = BigInteger.ZERO; | |||
| for (int i = 0; i < str.length(); ++i) { | |||
| BigInteger bi = new BigInteger(String.valueOf(bytes[i])); | |||
| bigInteger = bigInteger.multiply(BigInteger.valueOf(256)).add(bi); | |||
| } | |||
| return bigInteger; | |||
| } | |||
| public static String bigInteger2String(BigInteger bigInteger) throws UnsupportedEncodingException { | |||
| StringBuilder res = new StringBuilder(); | |||
| List<Integer> lists = new ArrayList<>(); | |||
| BigInteger bi = bigInteger; | |||
| BigInteger DIV = BigInteger.valueOf(256); | |||
| while (bi.compareTo(BigInteger.ZERO) > 0) { | |||
| lists.add(bi.mod(DIV).intValue()); | |||
| bi = bi.divide(DIV); | |||
| } | |||
| for (int i = lists.size() - 1; i >= 0; --i) { | |||
| res.append((char) (int) (lists.get(i))); | |||
| } | |||
| return res.toString(); | |||
| } | |||
| public static byte[] bigInteger2byteArray(BigInteger bigInteger) throws UnsupportedEncodingException { | |||
| List<Integer> lists = new ArrayList<>(); | |||
| BigInteger bi = bigInteger; | |||
| BigInteger DIV = BigInteger.valueOf(256); | |||
| while (bi.compareTo(BigInteger.ZERO) > 0) { | |||
| lists.add(bi.mod(DIV).intValue()); | |||
| bi = bi.divide(DIV); | |||
| } | |||
| byte[] res = new byte[lists.size()]; | |||
| for (int i = lists.size() - 1; i >= 0; --i) { | |||
| res[lists.size() - i - 1] = ((byte) (int) (lists.get(i))); | |||
| } | |||
| return res; | |||
| } | |||
| public static byte[] integer2byteArray(Integer num) { | |||
| List<Integer> lists = new ArrayList<>(); | |||
| Integer bi = num; | |||
| Integer DIV = 256; | |||
| while (bi > 0) { | |||
| lists.add(bi % DIV); | |||
| bi = bi / DIV; | |||
| } | |||
| byte[] res = new byte[lists.size()]; | |||
| for (int i = lists.size() - 1; i >= 0; --i) { | |||
| res[lists.size() - i - 1] = ((byte) (int) (lists.get(i))); | |||
| } | |||
| return res; | |||
| } | |||
| public static Integer byteArray2Integer(byte[] bytes) { | |||
| Integer num = 0; | |||
| for (int i = 0; i < bytes.length; ++i) { | |||
| int intI = bytes[i]; | |||
| if (intI < 0) { | |||
| intI = intI + 256; | |||
| } | |||
| num = num * 256 + intI; | |||
| } | |||
| return num; | |||
| } | |||
| } | |||
| @@ -0,0 +1,158 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher; | |||
| import com.google.flatbuffers.FlatBufferBuilder; | |||
| import com.mindspore.flclient.Common; | |||
| import com.mindspore.flclient.FLClientStatus; | |||
| import com.mindspore.flclient.FLCommunication; | |||
| import com.mindspore.flclient.FLParameter; | |||
| import com.mindspore.flclient.LocalFLParameter; | |||
| import com.mindspore.flclient.cipher.struct.DecryptShareSecrets; | |||
| import com.mindspore.flclient.cipher.struct.EncryptShare; | |||
| import com.mindspore.flclient.cipher.struct.NewArray; | |||
| import mindspore.schema.GetClientList; | |||
| import mindspore.schema.ResponseCode; | |||
| import mindspore.schema.ReturnClientList; | |||
| import java.nio.ByteBuffer; | |||
| import java.time.LocalDateTime; | |||
| import java.util.Arrays; | |||
| import java.util.List; | |||
| import java.util.Map; | |||
| import java.util.logging.Logger; | |||
| import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN; | |||
| public class ClientListReq { | |||
| private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString()); | |||
| private FLCommunication flCommunication; | |||
| private String nextRequestTime; | |||
| private FLParameter flParameter = FLParameter.getInstance(); | |||
| private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); | |||
| public ClientListReq() { | |||
| flCommunication = FLCommunication.getInstance(); | |||
| } | |||
| public String getNextRequestTime() { | |||
| return nextRequestTime; | |||
| } | |||
| public void setNextRequestTime(String nextRequestTime) { | |||
| this.nextRequestTime = nextRequestTime; | |||
| } | |||
| public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) { | |||
| String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==============getClientList url: " + url + "==============")); | |||
| FlatBufferBuilder builder = new FlatBufferBuilder(); | |||
| int id = builder.createString(localFLParameter.getFlID()); | |||
| String dateTime = LocalDateTime.now().toString(); | |||
| int time = builder.createString(dateTime); | |||
| int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time); | |||
| builder.finish(clientListRoot); | |||
| byte[] msg = builder.sizedByteArray(); | |||
| try { | |||
| byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg); | |||
| ByteBuffer buffer = ByteBuffer.wrap(responseData); | |||
| ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer); | |||
| FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys); | |||
| return status; | |||
| } catch (Exception e) { | |||
| e.printStackTrace(); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) { | |||
| int retcode = bufData.retcode(); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************")); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] the size of clients: " + bufData.clientsLength())); | |||
| FLClientStatus status; | |||
| switch (retcode) { | |||
| case (ResponseCode.SUCCEED): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] GetClientList success")); | |||
| u3ClientList.clear(); | |||
| int clientSize = bufData.clientsLength(); | |||
| for (int i = 0; i < clientSize; i++) { | |||
| String curFlId = bufData.clients(i); | |||
| u3ClientList.add(curFlId); | |||
| } | |||
| try { | |||
| decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys); | |||
| } catch (Exception e) { | |||
| e.printStackTrace(); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| return FLClientStatus.SUCCESS; | |||
| case (ResponseCode.SucNotReady): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetClientList again!")); | |||
| return FLClientStatus.WAIT; | |||
| case (ResponseCode.OutOfTime): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob again")); | |||
| setNextRequestTime(bufData.nextReqTime()); | |||
| return FLClientStatus.RESTART; | |||
| case (ResponseCode.RequestError): | |||
| case (ResponseCode.SystemError): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList")); | |||
| return FLClientStatus.FAILED; | |||
| default: | |||
| LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReturnClientList is invalid: " + retcode)); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| public void decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) throws Exception { | |||
| decryptSecretsList.clear(); | |||
| int size = returnShareList.size(); | |||
| for (int i = 0; i < size; i++) { | |||
| DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets(); | |||
| EncryptShare encryptShare = returnShareList.get(i); | |||
| String vFlID = encryptShare.getFlID(); | |||
| byte[] share = encryptShare.getShare().getArray(); | |||
| byte[] iVecIn = new byte[IVEC_LEN]; | |||
| AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), iVecIn, "CBC"); | |||
| byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share); | |||
| int sSize = (int) decryptShare[0]; | |||
| int bSize = (int) decryptShare[1]; | |||
| int sIndexLen = (int) decryptShare[2]; | |||
| int bIndexLen = (int) decryptShare[3]; | |||
| int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen)); | |||
| int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen, 4 + sIndexLen + bIndexLen)); | |||
| byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen, 4 + sIndexLen + bIndexLen + sSize); | |||
| byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize, 4 + sIndexLen + bIndexLen + sSize + bSize); | |||
| NewArray<byte[]> sSkVu = new NewArray<>(); | |||
| sSkVu.setSize(sSize); | |||
| sSkVu.setArray(sSkUv); | |||
| NewArray bVu = new NewArray(); | |||
| bVu.setSize(bSize); | |||
| bVu.setArray(bUv); | |||
| decryptShareSecrets.setFlID(vFlID); | |||
| decryptShareSecrets.setSSkVu(sSkVu); | |||
| decryptShareSecrets.setBVu(bVu); | |||
| decryptShareSecrets.setSIndex(sIndex); | |||
| decryptShareSecrets.setIndexB(bIndex); | |||
| decryptSecretsList.add(decryptShareSecrets); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher; | |||
| import org.bouncycastle.crypto.digests.SHA256Digest; | |||
| import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator; | |||
| import org.bouncycastle.crypto.params.KeyParameter; | |||
| import org.bouncycastle.math.ec.rfc7748.X25519; | |||
| import java.security.SecureRandom; | |||
| import java.util.logging.Logger; | |||
| public class KEYAgreement { | |||
| private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString()); | |||
| private static final int PBKDF2_ITERATIONS = 10000; | |||
| private static final int SALT_SIZE = 32; | |||
| private static final int HASH_BIT_SIZE = 256; | |||
| private static final int KEY_LEN = X25519.SCALAR_SIZE; | |||
| private SecureRandom random = new SecureRandom(); | |||
| public byte[] generatePrivateKey() { | |||
| byte[] privateKey = new byte[KEY_LEN]; | |||
| X25519.generatePrivateKey(random, privateKey); | |||
| return privateKey; | |||
| } | |||
| public byte[] generatePublicKey(byte[] privatekey) { | |||
| byte[] publicKey = new byte[KEY_LEN]; | |||
| X25519.generatePublicKey(privatekey, 0, publicKey, 0); | |||
| return publicKey; | |||
| } | |||
| public byte[] keyAgreement(byte[] privatekey, byte[] publicKey) { | |||
| byte[] secret = new byte[KEY_LEN]; | |||
| X25519.calculateAgreement(privatekey, 0, publicKey, 0, secret, 0); | |||
| return secret; | |||
| } | |||
| public byte[] getEncryptedPassword(byte[] password, byte[] salt) { | |||
| byte[] saltB = new byte[SALT_SIZE]; | |||
| PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest()); | |||
| gen.init(password, saltB, PBKDF2_ITERATIONS); | |||
| byte[] dk = ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey(); | |||
| return dk; | |||
| } | |||
| } | |||
| @@ -0,0 +1,82 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher; | |||
| import java.security.NoSuchAlgorithmException; | |||
| import java.security.SecureRandom; | |||
| import java.util.List; | |||
| import java.util.logging.Logger; | |||
| public class Random { | |||
| /** | |||
| * random generate RNG algorithm name | |||
| */ | |||
| private static final Logger LOGGER = Logger.getLogger(Random.class.toString()); | |||
| private static final String RNG_ALGORITHM = "SHA1PRNG"; | |||
| private static final int RANDOM_LEN = 128 / 8; | |||
| public void getRandomBytes(byte[] secret) { | |||
| try { | |||
| SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG"); | |||
| secureRandom.nextBytes(secret); | |||
| } catch (NoSuchAlgorithmException e) { | |||
| e.printStackTrace(); | |||
| } | |||
| } | |||
| public void randomAESCTR(List<Float> noise, int length, byte[] seed) throws Exception { | |||
| int intV = Integer.SIZE / 8; | |||
| int size = length * intV; | |||
| byte[] data = new byte[size]; | |||
| for (int i = 0; i < size; i++) { | |||
| data[i] = 0; | |||
| } | |||
| byte[] ivec = new byte[RANDOM_LEN]; | |||
| AESEncrypt aesEncrypt = new AESEncrypt(seed, ivec, "CTR"); | |||
| byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data); | |||
| for (int i = 0; i < length; i++) { | |||
| int[] sub = new int[intV]; | |||
| for (int j = 0; j < 4; j++) { | |||
| sub[j] = (int) encryptCtr[i * intV + j] & 0xff; | |||
| } | |||
| int subI = byte2int(sub, 4); | |||
| Float f = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE); | |||
| noise.add(f); | |||
| } | |||
| } | |||
| public static int byte2int(int[] data, int n) { | |||
| switch (n) { | |||
| case 1: | |||
| return (int) data[0]; | |||
| case 2: | |||
| return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00); | |||
| case 3: | |||
| return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000); | |||
| case 4: | |||
| return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000) | |||
| | (data[3] << 24 & 0xff000000); | |||
| default: | |||
| return 0; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,125 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher; | |||
| import com.google.flatbuffers.FlatBufferBuilder; | |||
| import com.mindspore.flclient.Common; | |||
| import com.mindspore.flclient.FLClientStatus; | |||
| import com.mindspore.flclient.FLCommunication; | |||
| import com.mindspore.flclient.FLParameter; | |||
| import com.mindspore.flclient.LocalFLParameter; | |||
| import com.mindspore.flclient.cipher.struct.DecryptShareSecrets; | |||
| import mindspore.schema.ClientShare; | |||
| import mindspore.schema.ResponseCode; | |||
| import java.nio.ByteBuffer; | |||
| import java.time.LocalDateTime; | |||
| import java.util.List; | |||
| import java.util.logging.Logger; | |||
| public class ReconstructSecretReq { | |||
| private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString()); | |||
| private FLCommunication flCommunication; | |||
| private String nextRequestTime; | |||
| private FLParameter flParameter = FLParameter.getInstance(); | |||
| private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); | |||
| public String getNextRequestTime() { | |||
| return nextRequestTime; | |||
| } | |||
| public void setNextRequestTime(String nextRequestTime) { | |||
| this.nextRequestTime = nextRequestTime; | |||
| } | |||
| public ReconstructSecretReq() { | |||
| flCommunication = FLCommunication.getInstance(); | |||
| } | |||
| public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) { | |||
| String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ==============sendReconstructSecret url: " + url + "==============")); | |||
| FlatBufferBuilder builder = new FlatBufferBuilder(); | |||
| int desFlId = builder.createString(localFLParameter.getFlID()); | |||
| String dateTime = LocalDateTime.now().toString(); | |||
| int time = builder.createString(dateTime); | |||
| int shareSecretsSize = decryptShareSecretsList.size(); | |||
| if (shareSecretsSize <= 0) { | |||
| LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please waite.")); | |||
| return FLClientStatus.FAILED; | |||
| } else { | |||
| int[] decryptShareList = new int[shareSecretsSize]; | |||
| for (int i = 0; i < shareSecretsSize; i++) { | |||
| DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i); | |||
| String srcFlId = decryptShareSecrets.getFlID(); | |||
| byte[] share; | |||
| int index; | |||
| if (u3ClientList.contains(srcFlId)) { | |||
| share = decryptShareSecrets.getBVu().getArray(); | |||
| index = decryptShareSecrets.getIndexB(); | |||
| } else { | |||
| share = decryptShareSecrets.getSSkVu().getArray(); | |||
| index = decryptShareSecrets.getSIndex(); | |||
| } | |||
| int fbsSrcFlId = builder.createString(srcFlId); | |||
| int fbsShare = ClientShare.createShareVector(builder, share); | |||
| int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index); | |||
| decryptShareList[i] = clientShare; | |||
| } | |||
| int reconstructShareSecrets = mindspore.schema.SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList); | |||
| int reconstructSecretRoot = mindspore.schema.SendReconstructSecret.createSendReconstructSecret(builder, desFlId, reconstructShareSecrets, iteration, time); | |||
| builder.finish(reconstructSecretRoot); | |||
| byte[] msg = builder.sizedByteArray(); | |||
| try { | |||
| byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg); | |||
| ByteBuffer buffer = ByteBuffer.wrap(responseData); | |||
| mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer); | |||
| FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp); | |||
| return status; | |||
| } catch (Exception e) { | |||
| LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct")); | |||
| e.printStackTrace(); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| } | |||
| public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) { | |||
| int retcode = bufData.retcode(); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************")); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retcode)); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] reason: " + bufData.reason())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] current iteration in server: " + bufData.iteration())); | |||
| LOGGER.info(Common.addTag("[PairWiseMask] next request time: " + bufData.nextReqTime())); | |||
| switch (retcode) { | |||
| case (ResponseCode.SUCCEED): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success")); | |||
| return FLClientStatus.SUCCESS; | |||
| case (ResponseCode.OutOfTime): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request startFLJob again")); | |||
| setNextRequestTime(bufData.nextReqTime()); | |||
| return FLClientStatus.RESTART; | |||
| case (ResponseCode.RequestError): | |||
| case (ResponseCode.SystemError): | |||
| LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets")); | |||
| return FLClientStatus.FAILED; | |||
| default: | |||
| LOGGER.severe(Common.addTag("[PairWiseMask] the return <retcode> from server in ReconstructSecret is invalid: " + retcode)); | |||
| return FLClientStatus.FAILED; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,136 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher; | |||
| import com.mindspore.flclient.Common; | |||
| import java.math.BigInteger; | |||
| import java.util.Random; | |||
| import java.util.logging.Logger; | |||
| public class ShareSecrets { | |||
| private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString()); | |||
| public final class SecretShare { | |||
| public SecretShare(final int num, final BigInteger share) { | |||
| this.num = num; | |||
| this.share = share; | |||
| } | |||
| public int getNum() { | |||
| return num; | |||
| } | |||
| public BigInteger getShare() { | |||
| return share; | |||
| } | |||
| @Override | |||
| public String toString() { | |||
| return "SecretShare [num=" + num + ", share=" + share + "]"; | |||
| } | |||
| private final int num; | |||
| private final BigInteger share; | |||
| } | |||
| public ShareSecrets(final int k, final int n) { | |||
| this.k = k; | |||
| this.n = n; | |||
| random = new Random(); | |||
| } | |||
| public SecretShare[] split(final byte[] bytes, byte[] primeByte) { | |||
| BigInteger secret = BaseUtil.byteArray2BigInteger(bytes); | |||
| final int modLength = secret.bitLength() + 1; | |||
| prime = BaseUtil.byteArray2BigInteger(primeByte); | |||
| final BigInteger[] coeff = new BigInteger[k - 1]; | |||
| LOGGER.info(Common.addTag("Prime Number: " + prime)); | |||
| for (int i = 0; i < k - 1; i++) { | |||
| coeff[i] = randomZp(prime); | |||
| LOGGER.info(Common.addTag("a" + (i + 1) + ": " + coeff[i])); | |||
| } | |||
| final SecretShare[] shares = new SecretShare[n]; | |||
| for (int i = 1; i <= n; i++) { | |||
| BigInteger accum = secret; | |||
| for (int j = 1; j < k; j++) { | |||
| final BigInteger t1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime); | |||
| final BigInteger t2 = coeff[j - 1].multiply(t1).mod(prime); | |||
| accum = accum.add(t2).mod(prime); | |||
| } | |||
| shares[i - 1] = new SecretShare(i, accum); | |||
| LOGGER.info(Common.addTag("Share " + shares[i - 1])); | |||
| } | |||
| return shares; | |||
| } | |||
| public BigInteger getPrime() { | |||
| return prime; | |||
| } | |||
| public BigInteger combine(final SecretShare[] shares, final byte[] primeByte) { | |||
| BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte); | |||
| BigInteger accum = BigInteger.ZERO; | |||
| for (int j = 0; j < k; j++) { | |||
| BigInteger num = BigInteger.ONE; | |||
| BigInteger den = BigInteger.ONE; | |||
| BigInteger tmp; | |||
| for (int m = 0; m < k; m++) { | |||
| if (j != m) { | |||
| num = num.multiply(BigInteger.valueOf(shares[m].getNum())).mod(primeNum); | |||
| tmp = BigInteger.valueOf(shares[j].getNum()).multiply(BigInteger.valueOf(-1)); | |||
| tmp = BigInteger.valueOf(shares[m].getNum()).add(tmp).mod(primeNum); | |||
| den = den.multiply(tmp).mod(primeNum); | |||
| } | |||
| } | |||
| final BigInteger value = shares[j].getShare(); | |||
| tmp = den.modInverse(primeNum); | |||
| tmp = tmp.multiply(num).mod(primeNum); | |||
| tmp = tmp.multiply(value).mod(primeNum); | |||
| accum = accum.add(tmp).mod(primeNum); | |||
| LOGGER.info(Common.addTag("value: " + value + ", tmp: " + tmp + ", accum: " + accum)); | |||
| } | |||
| LOGGER.info(Common.addTag("The secret is: " + accum)); | |||
| return accum; | |||
| } | |||
| private BigInteger randomZp(final BigInteger p) { | |||
| while (true) { | |||
| final BigInteger r = new BigInteger(p.bitLength(), random); | |||
| if (r.compareTo(BigInteger.ZERO) > 0 && r.compareTo(p) < 0) { | |||
| return r; | |||
| } | |||
| } | |||
| } | |||
| private BigInteger prime; | |||
| private final int k; | |||
| private final int n; | |||
| private final Random random; | |||
| private final int SECRET_MAX_LEN = 32; | |||
| } | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher.struct; | |||
| public class ClientPublicKey { | |||
| private String flID; | |||
| private NewArray<byte[]> cPK; | |||
| private NewArray<byte[]> sPk; | |||
| public String getFlID() { | |||
| return flID; | |||
| } | |||
| public void setFlID(String flID) { | |||
| this.flID = flID; | |||
| } | |||
| public NewArray<byte[]> getCPK() { | |||
| return cPK; | |||
| } | |||
| public void setCPK(NewArray<byte[]> cPK) { | |||
| this.cPK = cPK; | |||
| } | |||
| public NewArray<byte[]> getSPK() { | |||
| return sPk; | |||
| } | |||
| public void setSPK(NewArray<byte[]> sPk) { | |||
| this.sPk = sPk; | |||
| } | |||
| } | |||
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher.struct; | |||
| public class DecryptShareSecrets { | |||
| private String flID; | |||
| private NewArray<byte[]> sSkVu; | |||
| private NewArray<byte[]> bVu; | |||
| private int sIndex; | |||
| private int indexB; | |||
| public String getFlID() { | |||
| return flID; | |||
| } | |||
| public void setFlID(String flID) { | |||
| this.flID = flID; | |||
| } | |||
| public NewArray<byte[]> getSSkVu() { | |||
| return sSkVu; | |||
| } | |||
| public void setSSkVu(NewArray<byte[]> sSkVu) { | |||
| this.sSkVu = sSkVu; | |||
| } | |||
| public NewArray<byte[]> getBVu() { | |||
| return bVu; | |||
| } | |||
| public void setBVu(NewArray<byte[]> bVu) { | |||
| this.bVu = bVu; | |||
| } | |||
| public int getSIndex() { | |||
| return sIndex; | |||
| } | |||
| public void setSIndex(int sIndex) { | |||
| this.sIndex = sIndex; | |||
| } | |||
| public int getIndexB() { | |||
| return indexB; | |||
| } | |||
| public void setIndexB(int indexB) { | |||
| this.indexB = indexB; | |||
| } | |||
| } | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher.struct; | |||
| public class EncryptShare { | |||
| private String flID; | |||
| private NewArray<byte[]> share; | |||
| public String getFlID() { | |||
| return flID; | |||
| } | |||
| public void setFlID(String flID) { | |||
| this.flID = flID; | |||
| } | |||
| public NewArray<byte[]> getShare() { | |||
| return share; | |||
| } | |||
| public void setShare(NewArray<byte[]> share) { | |||
| this.share = share; | |||
| } | |||
| } | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher.struct; | |||
| public class NewArray<T> { | |||
| private int size; | |||
| private T array; | |||
| public int getSize() { | |||
| return size; | |||
| } | |||
| public void setSize(int size) { | |||
| this.size = size; | |||
| } | |||
| public T getArray() { | |||
| return array; | |||
| } | |||
| public void setArray(T array) { | |||
| this.array = array; | |||
| } | |||
| } | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| package com.mindspore.flclient.cipher.struct; | |||
| public class ShareSecret { | |||
| private String flID; | |||
| private NewArray<byte[]> share; | |||
| private int index; | |||
| public String getFlID() { | |||
| return flID; | |||
| } | |||
| public void setFlID(String flID) { | |||
| this.flID = flID; | |||
| } | |||
| public NewArray<byte[]> getShare() { | |||
| return share; | |||
| } | |||
| public void setShare(NewArray<byte[]> share) { | |||
| this.share = share; | |||
| } | |||
| public int getIndex() { | |||
| return index; | |||
| } | |||
| public void setIndex(int index) { | |||
| this.index = index; | |||
| } | |||
| } | |||