Browse Source

Fix load failed while loading oversize checkpoint file with encyption

pull/15996/head
liuluobin 4 years ago
parent
commit
0c93fa6bd7
3 changed files with 19 additions and 25 deletions
  1. +13
    -21
      mindspore/ccsrc/crypto/crypto.cc
  2. +1
    -1
      mindspore/ccsrc/crypto/crypto.h
  3. +5
    -3
      mindspore/train/serialization.py

+ 13
- 21
mindspore/ccsrc/crypto/crypto.cc View File

@@ -66,20 +66,20 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *


bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len, bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
Byte **cipher_data, int32_t *cipher_len) { Byte **cipher_data, int32_t *cipher_len) {
// Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data
// encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
Byte buf[4]; Byte buf[4];
memcpy(buf, encrypt_data, 4);
memcpy_s(buf, 4, encrypt_data, 4);
*iv_len = ByteToint(buf); *iv_len = ByteToint(buf);
memcpy(buf, encrypt_data + *iv_len + 4, 4);
memcpy_s(buf, 4, encrypt_data + *iv_len + 4, 4);
*cipher_len = ByteToint(buf); *cipher_len = ByteToint(buf);
if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) { if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
MS_LOG(ERROR) << "Failed to parse encrypt data."; MS_LOG(ERROR) << "Failed to parse encrypt data.";
return false; return false;
} }
*iv = new Byte[*iv_len]; *iv = new Byte[*iv_len];
memcpy(*iv, encrypt_data + 4, *iv_len);
memcpy_s(*iv, *iv_len, encrypt_data + 4, *iv_len);
*cipher_data = new Byte[*cipher_len]; *cipher_data = new Byte[*cipher_len];
memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len);
memcpy_s(*cipher_data, *cipher_len, encrypt_data + *iv_len + 8, *cipher_len);
return true; return true;
} }


@@ -152,18 +152,15 @@ EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key,


bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key, bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key,
const int32_t key_len, const std::string &enc_mode) { const int32_t key_len, const std::string &enc_mode) {
// Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length +
// iv length + iv + plain text length + cipher text length + cipher text"
int32_t cipher_len = 0; // cipher length
int32_t cipher_len = 0;


int32_t iv_len = AES_BLOCK_SIZE; int32_t iv_len = AES_BLOCK_SIZE;
Byte *iv = new Byte[iv_len]; Byte *iv = new Byte[iv_len];
RAND_bytes(iv, sizeof(Byte) * iv_len); RAND_bytes(iv, sizeof(Byte) * iv_len);


Byte *iv_cpy = new Byte[16]; Byte *iv_cpy = new Byte[16];
memcpy(iv_cpy, iv, 16);
memcpy_s(iv_cpy, 16, iv, 16);


// set the encryption length
int32_t ret = 0; int32_t ret = 0;
int32_t flen = 0; int32_t flen = 0;
std::string alg_mode; std::string alg_mode;
@@ -193,7 +190,7 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da
EVP_CIPHER_CTX_free(ctx); EVP_CIPHER_CTX_free(ctx);


int64_t cur = 0; int64_t cur = 0;
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接
*encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;


memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4); memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4);
cur += 4; cur += 4;
@@ -212,8 +209,6 @@ bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_da


bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key, bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
const int32_t key_len, const std::string &dec_mode) { const int32_t key_len, const std::string &dec_mode) {
// Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv +
// plain text data length + cipher text data length + cipher text data"
std::string alg_mode; std::string alg_mode;
std::string work_mode; std::string work_mode;


@@ -221,7 +216,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co
return false; return false;
} }


// 解析加密数据
int32_t iv_len = 0; int32_t iv_len = 0;
int32_t cipher_len = 0; int32_t cipher_len = 0;
Byte *iv = NULL; Byte *iv = NULL;
@@ -236,7 +230,6 @@ bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, co
return false; return false;
} }


// 解密密文
int ret = 0; int ret = 0;
int mlen = 0; int mlen = 0;


@@ -276,7 +269,7 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B
*encrypt_len = 0; *encrypt_len = 0;
while (cur_pos < plain_len) { while (cur_pos < plain_len) {
int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos); int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
memcpy(block_buf, plain_data + cur_pos, cur_block_size);
memcpy_s(block_buf, MAX_BLOCK_SIZE, plain_data + cur_pos, cur_block_size);


if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) { if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) {
delete[] block_buf; delete[] block_buf;
@@ -284,9 +277,9 @@ Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, B
delete[] encrypt_data; delete[] encrypt_data;
MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
} }
memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
*encrypt_len += sizeof(int32_t); *encrypt_len += sizeof(int32_t);
memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len);
memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, block_enc_buf, block_enc_len);
*encrypt_len += block_enc_len; *encrypt_len += block_enc_len;
cur_pos += cur_block_size; cur_pos += cur_block_size;
} }
@@ -300,7 +293,6 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
Byte *decrypt_data = nullptr; Byte *decrypt_data = nullptr;
char *block_buf = new char[MAX_BLOCK_SIZE * 2]; char *block_buf = new char[MAX_BLOCK_SIZE * 2];
char *int_buf = new char[4]; char *int_buf = new char[4];
// Byte *decrypt_block_buf = new Byte[100];
Byte *decrypt_block_buf = nullptr; Byte *decrypt_block_buf = nullptr;
int32_t decrypt_block_len; int32_t decrypt_block_len;


@@ -325,7 +317,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
} }
fid.read(int_buf, sizeof(int32_t)); fid.read(int_buf, sizeof(int32_t));


int64_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
fid.read(block_buf, sizeof(char) * block_size); fid.read(block_buf, sizeof(char) * block_size);
if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key, if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
key_len, dec_mode))) { key_len, dec_mode))) {
@@ -334,7 +326,7 @@ Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *
delete[] decrypt_data; delete[] decrypt_data;
MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
} }
memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len);
memcpy_s(decrypt_data + *decrypt_len, file_size - *decrypt_len, decrypt_block_buf, decrypt_block_len);
*decrypt_len += decrypt_block_len; *decrypt_len += decrypt_block_len;
} }
fid.close(); fid.close();


+ 1
- 1
mindspore/ccsrc/crypto/crypto.h View File

@@ -33,7 +33,7 @@ typedef unsigned char Byte;


namespace mindspore { namespace mindspore {
namespace crypto { namespace crypto {
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB
const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment, units is Byte
const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number


Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,


+ 5
- 3
mindspore/train/serialization.py View File

@@ -168,10 +168,12 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
f.write(checkpoint_list.SerializeToString()) f.write(checkpoint_list.SerializeToString())
else: else:
plain_data += checkpoint_list.SerializeToString() plain_data += checkpoint_list.SerializeToString()
while len(plain_data) >= SLICE_SIZE * 1024:
cipher_data += _encrypt(plain_data[0: SLICE_SIZE*1024], SLICE_SIZE*1024, enc_key,

max_block_size = SLICE_SIZE*1024
while len(plain_data) >= max_block_size:
cipher_data += _encrypt(plain_data[0: max_block_size], max_block_size, enc_key,
len(enc_key), enc_mode) len(enc_key), enc_mode)
plain_data = plain_data[SLICE_SIZE*1024:]
plain_data = plain_data[max_block_size:]


if enc_key is not None: if enc_key is not None:
if plain_data: if plain_data:


Loading…
Cancel
Save