/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "crypto/crypto.h"

#include <cstdint>
#include <cstring>
#include <fstream>
#include <memory>
#include <regex>
#include <string>
#include <vector>

namespace mindspore {
namespace crypto {
namespace {
// Width in bytes of every serialized int32_t field (magic number and length prefixes).
constexpr int64_t kInt32Bytes = 4;

// Writes `n` into out[0..3] in little-endian byte order (the same layout intToByte produces)
// without allocating, so internal callers have no temporary buffer to leak.
void WriteInt32LE(Byte *out, int32_t n) {
  const auto value = static_cast<uint32_t>(n);
  out[0] = static_cast<Byte>(value & 0xFFu);
  out[1] = static_cast<Byte>((value >> 8) & 0xFFu);
  out[2] = static_cast<Byte>((value >> 16) & 0xFFu);
  out[3] = static_cast<Byte>((value >> 24) & 0xFFu);
}
}  // namespace

int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; }

// Serializes `n` as 4 little-endian bytes.
// The returned buffer is allocated with new[]; the caller owns it and must release it with delete[].
Byte *intToByte(const int32_t &n) {
  Byte *byte = new Byte[kInt32Bytes];
  WriteInt32LE(byte, n);
  return byte;
}

// Reads 4 little-endian bytes from `byteArray` and returns them as an int32_t.
int32_t ByteToint(const Byte *byteArray) {
  // Accumulate in uint32_t: shifting a promoted int left by 24 overflows (UB before C++20)
  // whenever the top byte is >= 0x80. Masking with 0xFF keeps this correct even if Byte is signed.
  uint32_t res = static_cast<uint32_t>(byteArray[0]) & 0xFFu;
  res |= (static_cast<uint32_t>(byteArray[1]) & 0xFFu) << 8;
  res |= (static_cast<uint32_t>(byteArray[2]) & 0xFFu) << 16;
  res |= (static_cast<uint32_t>(byteArray[3]) & 0xFFu) << 24;
  return static_cast<int32_t>(res);
}

// Returns true when the first 4 bytes of `file_path` decode to MAGIC_NUM, the marker Encrypt writes
// before every encrypted segment. Returns false (instead of terminating the process) when the file
// cannot be opened or is shorter than 4 bytes.
bool IsCipherFile(std::string file_path) {
  std::ifstream fid(file_path, std::ios::in | std::ios::binary);
  if (!fid) {
    MS_LOG(ERROR) << "Open file " << file_path << " failed";
    return false;
  }
  char int_buf[kInt32Bytes] = {0};
  if (!fid.read(int_buf, kInt32Bytes)) {
    // Too short to hold the magic number, so it cannot be a cipher file.
    return false;
  }
  return ByteToint(reinterpret_cast<const Byte *>(int_buf)) == MAGIC_NUM;
}

#if defined(_WIN32)
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
              const std::string &enc_mode) {
  MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}

Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
              const std::string &dec_mode) {
  MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
}
#else
namespace {
// Owns an EVP_CIPHER_CTX and frees it with EVP_CIPHER_CTX_free on every exit path.
using CipherCtxPtr = std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)>;
}  // namespace

// Splits one encrypted block body laid out as
//   iv_len(4) | iv(iv_len) | cipher_len(4) | cipher_data(cipher_len)
// into newly allocated `*iv` and `*cipher_data` buffers (caller owns both; release with delete[]).
// Every length is validated against `encrypt_len` before it is used as an offset, so a corrupted
// header cannot cause an out-of-bounds read. Returns false on malformed input.
bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
                      Byte **cipher_data, int32_t *cipher_len) {
  if (encrypt_data == nullptr || encrypt_len < 2 * kInt32Bytes) {
    MS_LOG(ERROR) << "Failed to parse encrypt data.";
    return false;
  }
  *iv_len = ByteToint(encrypt_data);
  // The cipher_len field lives at offset 4 + iv_len, so bound iv_len before reading it.
  if (*iv_len <= 0 || static_cast<int64_t>(*iv_len) > encrypt_len - 2 * kInt32Bytes) {
    MS_LOG(ERROR) << "Failed to parse encrypt data.";
    return false;
  }
  *cipher_len = ByteToint(encrypt_data + kInt32Bytes + *iv_len);
  // 64-bit arithmetic so iv_len + cipher_len cannot overflow.
  if (*cipher_len <= 0 ||
      static_cast<int64_t>(*iv_len) + static_cast<int64_t>(*cipher_len) + 2 * kInt32Bytes != encrypt_len) {
    MS_LOG(ERROR) << "Failed to parse encrypt data.";
    return false;
  }
  *iv = new Byte[*iv_len];
  memcpy_s(*iv, static_cast<size_t>(*iv_len), encrypt_data + kInt32Bytes, static_cast<size_t>(*iv_len));
  *cipher_data = new Byte[*cipher_len];
  memcpy_s(*cipher_data, static_cast<size_t>(*cipher_len), encrypt_data + *iv_len + 2 * kInt32Bytes,
           static_cast<size_t>(*cipher_len));
  return true;
}

// Splits a mode string of the form "XXX-YYY" (three upper-case letters each) into the algorithm
// part (`*alg_mode`, e.g. "AES") and the work mode (`*work_mode`, e.g. "GCM"). Returns false if
// `mode` does not match that pattern.
bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
  // Compiled once; the pattern is constant.
  static const std::regex re("([A-Z]{3})-([A-Z]{3})");
  std::smatch results;
  if (!std::regex_match(mode, results, re)) {
    MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
    return false;
  }
  *alg_mode = results[1];
  *work_mode = results[2];
  return true;
}

// Creates a cipher context for AES in `work_mode` ("GCM" or "CBC"), with a 16/24/32-byte key.
// `flag` selects the direction: 0 = encrypt, 1 = decrypt. The caller owns the returned context and
// must free it with EVP_CIPHER_CTX_free. Returns nullptr on an invalid mode/flag or on an OpenSSL
// failure, and throws ValueError on an invalid key length. Validation happens before allocation,
// so no context leaks on any failure path.
EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len,
                                  const Byte *iv, int flag) {
  const EVP_CIPHER *(*funcPtr)() = nullptr;
  if (work_mode == "GCM") {
    switch (key_len) {
      case 16:
        funcPtr = EVP_aes_128_gcm;
        break;
      case 24:
        funcPtr = EVP_aes_192_gcm;
        break;
      case 32:
        funcPtr = EVP_aes_256_gcm;
        break;
      default:
        MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
    }
  } else if (work_mode == "CBC") {
    switch (key_len) {
      case 16:
        funcPtr = EVP_aes_128_cbc;
        break;
      case 24:
        funcPtr = EVP_aes_192_cbc;
        break;
      case 32:
        funcPtr = EVP_aes_256_cbc;
        break;
      default:
        MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
    }
  } else {
    MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
    return nullptr;
  }
  if (flag != 0 && flag != 1) {
    MS_LOG(ERROR) << "Cipher direction flag " << flag << " is invalid, it must be 0 (encrypt) or 1 (decrypt).";
    return nullptr;
  }

  EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "EVP_CIPHER_CTX_new failed";
    return nullptr;
  }
  const int ret = (flag == 0) ? EVP_EncryptInit_ex(ctx, funcPtr(), nullptr, key, iv)
                              : EVP_DecryptInit_ex(ctx, funcPtr(), nullptr, key, iv);
  if (ret != 1) {
    MS_LOG(ERROR) << ((flag == 0) ? "EVP_EncryptInit_ex" : "EVP_DecryptInit_ex") << " failed";
    EVP_CIPHER_CTX_free(ctx);
    return nullptr;
  }
  if (work_mode == "CBC") {
    EVP_CIPHER_CTX_set_padding(ctx, 1);
  }
  return ctx;
}

// Encrypts `plain_len` bytes of `plain_data` with a fresh random IV and writes one block to `encrypt_data`:
//   block_len(4) | iv_len(4) | iv(iv_len) | cipher_len(4) | cipher_data(cipher_len)
// Here block_len counts every byte after its own 4-byte field. `*encrypt_data_len` receives the total
// bytes written (block_len + 4).
// `encrypt_data` must hold at least plain_len + 3 * 4 + 2 * AES_BLOCK_SIZE bytes.
// NOTE(review): in GCM mode neither EVP_EncryptFinal_ex nor an authentication tag is produced or stored,
// so decryption cannot detect tampering. This is kept as-is because fixing it changes the on-disk format.
bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key,
                   const int32_t key_len, const std::string &enc_mode) {
  std::string alg_mode;
  std::string work_mode;
  if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
    return false;
  }

  const int32_t iv_len = AES_BLOCK_SIZE;
  std::vector<Byte> iv(static_cast<size_t>(iv_len));
  if (RAND_bytes(iv.data(), iv_len) != 1) {
    MS_LOG(ERROR) << "RAND_bytes failed to generate the iv.";
    return false;
  }

  CipherCtxPtr ctx(GetEVP_CIPHER_CTX(work_mode, key, key_len, iv.data(), 0), &EVP_CIPHER_CTX_free);
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
    return false;
  }

  // Reserve one extra block for CBC padding.
  std::vector<Byte> cipher_data(static_cast<size_t>(plain_len) + AES_BLOCK_SIZE);
  int32_t cipher_len = 0;
  if (EVP_EncryptUpdate(ctx.get(), cipher_data.data(), &cipher_len, plain_data, static_cast<int>(plain_len)) != 1) {
    MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
    return false;
  }
  if (work_mode == "CBC") {
    int32_t flen = 0;
    if (EVP_EncryptFinal_ex(ctx.get(), cipher_data.data() + cipher_len, &flen) != 1) {
      MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed";
      return false;
    }
    cipher_len += flen;
  }

  const auto block_len = static_cast<int32_t>(2 * kInt32Bytes + iv_len + cipher_len);
  Byte *cur = encrypt_data;
  WriteInt32LE(cur, block_len);
  cur += kInt32Bytes;
  WriteInt32LE(cur, iv_len);
  cur += kInt32Bytes;
  std::memcpy(cur, iv.data(), static_cast<size_t>(iv_len));
  cur += iv_len;
  WriteInt32LE(cur, cipher_len);
  cur += kInt32Bytes;
  std::memcpy(cur, cipher_data.data(), static_cast<size_t>(cipher_len));
  *encrypt_data_len = block_len + kInt32Bytes;
  return true;
}

// Decrypts one block body (the bytes that follow the block_len field written by _BlockEncrypt).
// On success, `*plain_data` receives a new[]-allocated buffer (caller owns it; release with delete[])
// and `*plain_len` receives its length. On failure both outputs are left untouched and nothing is leaked.
bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
                   const int32_t key_len, const std::string &dec_mode) {
  std::string alg_mode;
  std::string work_mode;
  if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
    return false;
  }
  if (encrypt_len < 0 || encrypt_len > INT32_MAX) {
    MS_LOG(ERROR) << "Failed to parse encrypt data.";
    return false;
  }

  int32_t iv_len = 0;
  int32_t cipher_len = 0;
  Byte *iv_raw = nullptr;
  Byte *cipher_raw = nullptr;
  if (!ParseEncryptData(encrypt_data, static_cast<int32_t>(encrypt_len), &iv_raw, &iv_len, &cipher_raw,
                        &cipher_len)) {
    return false;
  }
  // Take ownership of the buffers ParseEncryptData allocated so every return path releases them.
  std::unique_ptr<Byte[]> iv(iv_raw);
  std::unique_ptr<Byte[]> cipher_data(cipher_raw);

  CipherCtxPtr ctx(GetEVP_CIPHER_CTX(work_mode, key, key_len, iv.get(), 1), &EVP_CIPHER_CTX_free);
  if (ctx == nullptr) {
    MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
    return false;
  }

  // Leave one extra block of room, as EVP_DecryptUpdate may require for block ciphers.
  std::unique_ptr<Byte[]> plain(new Byte[static_cast<size_t>(cipher_len) + AES_BLOCK_SIZE]);
  int32_t out_len = 0;
  if (EVP_DecryptUpdate(ctx.get(), plain.get(), &out_len, cipher_data.get(), cipher_len) != 1) {
    MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
    return false;
  }
  if (work_mode == "CBC") {
    int32_t mlen = 0;
    if (EVP_DecryptFinal_ex(ctx.get(), plain.get() + out_len, &mlen) != 1) {
      MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
      return false;
    }
    out_len += mlen;
  }
  *plain_data = plain.release();
  *plain_len = out_len;
  return true;
}

// Encrypts `plain_len` bytes of `plain_data` in segments of at most MAX_BLOCK_SIZE plaintext bytes.
// Each segment is written as MAGIC_NUM(4) followed by the block produced by _BlockEncrypt.
// Returns a new[]-allocated buffer (caller owns it; release with delete[]) and sets `*encrypt_len`
// to the number of valid bytes. Throws ValueError if any segment fails to encrypt.
Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
              const std::string &enc_mode) {
  const auto max_block = static_cast<int64_t>(MAX_BLOCK_SIZE);
  // Per segment, framing adds 4 length/magic fields, one IV and at most one block of CBC padding;
  // 100 bytes per segment leaves headroom for that.
  const int64_t encrypt_buf_len = plain_len + (plain_len / max_block + 1) * 100;
  std::unique_ptr<Byte[]> encrypt_data(new Byte[encrypt_buf_len]);
  std::vector<Byte> block_enc_buf(static_cast<size_t>(max_block) + 100);

  *encrypt_len = 0;
  int64_t cur_pos = 0;
  while (cur_pos < plain_len) {
    const int64_t cur_block_size = Min(max_block, plain_len - cur_pos);
    int64_t block_enc_len = 0;
    // _BlockEncrypt only reads its plaintext, so pass the segment in place instead of copying it.
    if (!_BlockEncrypt(block_enc_buf.data(), &block_enc_len, plain_data + cur_pos, cur_block_size, key, key_len,
                       enc_mode)) {
      MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
    }
    WriteInt32LE(encrypt_data.get() + *encrypt_len, MAGIC_NUM);
    *encrypt_len += kInt32Bytes;
    // The buffer size above guarantees room.
    // NOTE(review): std::memcpy is used because memcpy_s rejects destMax above its implementation limit
    // (2 GiB in libboundscheck), which large models can exceed. Confirm against the securec build in use.
    std::memcpy(encrypt_data.get() + *encrypt_len, block_enc_buf.data(), static_cast<size_t>(block_enc_len));
    *encrypt_len += block_enc_len;
    cur_pos += cur_block_size;
  }
  return encrypt_data.release();
}

// Reads the file at `encrypt_data_path`, which must be a sequence of segments
//   MAGIC_NUM(4) | block_size(4) | block(block_size)
// as written by Encrypt. Each block is decrypted and the plaintexts are concatenated into a new[]-allocated
// buffer (caller owns it; release with delete[]). `*decrypt_len` is set to the plaintext length.
// Throws ValueError if the file cannot be opened, is truncated or malformed, or a block fails to decrypt.
// All sizes read from the file are bounds-checked before use.
Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
              const std::string &dec_mode) {
  std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
  if (!fid) {
    MS_EXCEPTION(ValueError) << "Open file \"" << encrypt_data_path << "\" failed";
  }
  fid.seekg(0, std::ios_base::end);
  const auto file_size = static_cast<int64_t>(fid.tellg());
  fid.clear();
  fid.seekg(0);
  if (file_size < 0) {
    MS_EXCEPTION(ValueError) << "Failed to get the size of file \"" << encrypt_data_path << "\"";
  }

  // Plaintext is never longer than the ciphertext it came from, so file_size bytes are sufficient.
  std::unique_ptr<Byte[]> decrypt_data(new Byte[file_size]);
  std::vector<Byte> block_buf;
  char int_buf[kInt32Bytes] = {0};
  *decrypt_len = 0;
  int64_t pos = 0;
  while (pos < file_size) {
    if (!fid.read(int_buf, kInt32Bytes) || ByteToint(reinterpret_cast<const Byte *>(int_buf)) != MAGIC_NUM) {
      MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
                               << "\" is not an encrypted file and cannot be decrypted";
    }
    if (!fid.read(int_buf, kInt32Bytes)) {
      MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path << "\" is truncated and cannot be decrypted";
    }
    pos += 2 * kInt32Bytes;
    const int32_t block_size = ByteToint(reinterpret_cast<const Byte *>(int_buf));
    // block_size comes from the file; bound it by the bytes actually left before allocating or reading.
    if (block_size <= 0 || block_size > file_size - pos) {
      MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path << "\" has an invalid block size " << block_size
                               << " and cannot be decrypted";
    }
    block_buf.resize(static_cast<size_t>(block_size));
    if (!fid.read(reinterpret_cast<char *>(block_buf.data()), block_size)) {
      MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path << "\" is truncated and cannot be decrypted";
    }
    pos += block_size;

    Byte *decrypt_block_raw = nullptr;
    int32_t decrypt_block_len = 0;
    if (!_BlockDecrypt(&decrypt_block_raw, &decrypt_block_len, block_buf.data(), block_size, key, key_len,
                       dec_mode)) {
      MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
    }
    // Owns the per-block plaintext so it is released every iteration (and on throw).
    std::unique_ptr<Byte[]> decrypt_block(decrypt_block_raw);
    if (decrypt_block_len < 0 || decrypt_block_len > file_size - *decrypt_len) {
      MS_EXCEPTION(ValueError) << "Decrypted data of file \"" << encrypt_data_path << "\" exceeds the file size";
    }
    // Bounds were checked just above.
    // NOTE(review): std::memcpy is used because memcpy_s rejects destMax above its implementation limit,
    // which large files can exceed.
    std::memcpy(decrypt_data.get() + *decrypt_len, decrypt_block.get(), static_cast<size_t>(decrypt_block_len));
    *decrypt_len += decrypt_block_len;
  }
  return decrypt_data.release();
}
#endif
}  // namespace crypto
}  // namespace mindspore