| @@ -5,10 +5,14 @@ else() | |||
| set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz") | |||
| set(MD5 "bdd51a68ad74618dd2519da8e0bcc759") | |||
| endif() | |||
| mindspore_add_pkg(openssl | |||
| VER 1.1.0 | |||
| LIBS ssl crypto | |||
| URL ${REQ_URL} | |||
| MD5 ${MD5} | |||
| CONFIGURE_COMMAND ./config no-zlib no-shared) | |||
| include_directories(${openssl_INC}) | |||
| if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") | |||
| mindspore_add_pkg(openssl | |||
| VER 1.1.0 | |||
| LIBS ssl crypto | |||
| URL ${REQ_URL} | |||
| MD5 ${MD5} | |||
| CONFIGURE_COMMAND ./config no-zlib no-shared) | |||
| include_directories(${openssl_INC}) | |||
| add_library(mindspore::ssl ALIAS openssl::ssl) | |||
| add_library(mindspore::crypto ALIAS openssl::crypto) | |||
| endif() | |||
| @@ -226,6 +226,7 @@ set(SUB_COMP | |||
| pipeline/jit | |||
| pipeline/pynative | |||
| common debug pybind_api utils vm profiler ps | |||
| crypto | |||
| ) | |||
| foreach(_comp ${SUB_COMP}) | |||
| @@ -0,0 +1,6 @@ | |||
| file(GLOB_RECURSE _CRYPTO_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||
| add_library(_mindspore_crypto_obj OBJECT ${_CRYPTO_SRC_FILES}) | |||
| if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") | |||
| target_link_libraries(_mindspore_crypto_obj mindspore::crypto) | |||
| endif() | |||
| @@ -0,0 +1,347 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "crypto/crypto.h" | |||
| namespace mindspore { | |||
| namespace crypto { | |||
| int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; } | |||
| Byte *intToByte(const int32_t &n) { | |||
| Byte *byte = new Byte[4]; | |||
| memset(byte, 0, sizeof(Byte) * 4); | |||
| byte[0] = (Byte)(0xFF & n); | |||
| byte[1] = (Byte)((0xFF00 & n) >> 8); | |||
| byte[2] = (Byte)((0xFF0000 & n) >> 16); | |||
| byte[3] = (Byte)((0xFF000000 & n) >> 24); | |||
| return byte; | |||
| } | |||
| int32_t ByteToint(const Byte *byteArray) { | |||
| int32_t res = byteArray[0] & 0xFF; | |||
| res |= ((byteArray[1] << 8) & 0xFF00); | |||
| res |= ((byteArray[2] << 16) & 0xFF0000); | |||
| res += ((byteArray[3] << 24) & 0xFF000000); | |||
| return res; | |||
| } | |||
| bool IsCipherFile(std::string file_path) { | |||
| char *int_buf = new char[4]; | |||
| int flag = 0; | |||
| std::ifstream fid(file_path, std::ios::in | std::ios::binary); | |||
| if (!fid) { | |||
| MS_LOG(ERROR) << "Open file failed"; | |||
| exit(-1); | |||
| } | |||
| fid.read(int_buf, sizeof(int32_t)); | |||
| fid.close(); | |||
| flag = ByteToint(reinterpret_cast<Byte *>(int_buf)); | |||
| delete[] int_buf; | |||
| return flag == MAGIC_NUM; | |||
| } | |||
| #if defined(_WIN32) | |||
| Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, | |||
| const std::string &enc_mode) { | |||
| MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform."; | |||
| } | |||
| Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, | |||
| const std::string &dec_mode) { | |||
| MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform."; | |||
| } | |||
| #else | |||
| bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len, | |||
| Byte **cipher_data, int32_t *cipher_len) { | |||
| // Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data | |||
| Byte buf[4]; | |||
| memcpy(buf, encrypt_data, 4); | |||
| *iv_len = ByteToint(buf); | |||
| memcpy(buf, encrypt_data + *iv_len + 4, 4); | |||
| *cipher_len = ByteToint(buf); | |||
| if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) { | |||
| MS_LOG(ERROR) << "Failed to parse encrypt data."; | |||
| return false; | |||
| } | |||
| *iv = new Byte[*iv_len]; | |||
| memcpy(*iv, encrypt_data + 4, *iv_len); | |||
| *cipher_data = new Byte[*cipher_len]; | |||
| memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len); | |||
| return true; | |||
| } | |||
| bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) { | |||
| std::smatch results; | |||
| std::regex re("([A-Z]{3})-([A-Z]{3})"); | |||
| if (!std::regex_match(mode.c_str(), re)) { | |||
| MS_LOG(ERROR) << "Mode " << mode << " is invalid."; | |||
| return false; | |||
| } | |||
| std::regex_search(mode, results, re); | |||
| *alg_mode = results[1]; | |||
| *work_mode = results[2]; | |||
| return true; | |||
| } | |||
| EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv, | |||
| int flag) { | |||
| int ret = 0; | |||
| EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); | |||
| if (work_mode != "GCM" && work_mode != "CBC") { | |||
| MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid."; | |||
| return nullptr; | |||
| } | |||
| const EVP_CIPHER *(*funcPtr)() = nullptr; | |||
| if (work_mode == "GCM") { | |||
| switch (key_len) { | |||
| case 16: | |||
| funcPtr = EVP_aes_128_gcm; | |||
| break; | |||
| case 24: | |||
| funcPtr = EVP_aes_192_gcm; | |||
| break; | |||
| case 32: | |||
| funcPtr = EVP_aes_256_gcm; | |||
| break; | |||
| default: | |||
| MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << "."; | |||
| } | |||
| } else if (work_mode == "CBC") { | |||
| switch (key_len) { | |||
| case 16: | |||
| funcPtr = EVP_aes_128_cbc; | |||
| break; | |||
| case 24: | |||
| funcPtr = EVP_aes_192_cbc; | |||
| break; | |||
| case 32: | |||
| funcPtr = EVP_aes_256_cbc; | |||
| break; | |||
| default: | |||
| MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << "."; | |||
| } | |||
| } | |||
| if (flag == 0) { | |||
| ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv); | |||
| } else if (flag == 1) { | |||
| ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv); | |||
| } | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_EncryptInit_ex failed"; | |||
| return nullptr; | |||
| } | |||
| if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1); | |||
| return ctx; | |||
| } | |||
| bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key, | |||
| const int32_t key_len, const std::string &enc_mode) { | |||
| // Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length + | |||
| // iv length + iv + plain text length + cipher text length + cipher text" | |||
| int32_t cipher_len = 0; // cipher length | |||
| int32_t iv_len = AES_BLOCK_SIZE; | |||
| Byte *iv = new Byte[iv_len]; | |||
| RAND_bytes(iv, sizeof(Byte) * iv_len); | |||
| Byte *iv_cpy = new Byte[16]; | |||
| memcpy(iv_cpy, iv, 16); | |||
| // set the encryption length | |||
| int32_t ret = 0; | |||
| int32_t flen = 0; | |||
| std::string alg_mode; | |||
| std::string work_mode; | |||
| if (!ParseMode(enc_mode, &alg_mode, &work_mode)) { | |||
| return false; | |||
| } | |||
| auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0); | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; | |||
| return false; | |||
| } | |||
| Byte *cipher_data; | |||
| cipher_data = new Byte[plain_len + 16]; | |||
| ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_EncryptUpdate failed"; | |||
| delete[] cipher_data; | |||
| return false; | |||
| } | |||
| if (work_mode == "CBC") { | |||
| EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen); | |||
| cipher_len += flen; | |||
| } | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| int64_t cur = 0; | |||
| *encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接 | |||
| memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4); | |||
| cur += 4; | |||
| memcpy(encrypt_data + cur, intToByte(iv_len), 4); | |||
| cur += 4; | |||
| memcpy(encrypt_data + cur, iv_cpy, iv_len); | |||
| cur += iv_len; | |||
| memcpy(encrypt_data + cur, intToByte(cipher_len), 4); | |||
| cur += 4; | |||
| memcpy(encrypt_data + cur, cipher_data, cipher_len); | |||
| *encrypt_data_len += 4; | |||
| delete[] cipher_data; | |||
| return true; | |||
| } | |||
| bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key, | |||
| const int32_t key_len, const std::string &dec_mode) { | |||
| // Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv + | |||
| // plain text data length + cipher text data length + cipher text data" | |||
| std::string alg_mode; | |||
| std::string work_mode; | |||
| if (!ParseMode(dec_mode, &alg_mode, &work_mode)) { | |||
| return false; | |||
| } | |||
| // 解析加密数据 | |||
| int32_t iv_len = 0; | |||
| int32_t cipher_len = 0; | |||
| Byte *iv = NULL; | |||
| Byte *cipher_data = NULL; | |||
| if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) { | |||
| return false; | |||
| } | |||
| *plain_data = new Byte[cipher_len + 16]; | |||
| if (*plain_data == NULL) { | |||
| MS_LOG(ERROR) << "Unable to allocate memory for decrypt_string."; | |||
| return false; | |||
| } | |||
| // 解密密文 | |||
| int ret = 0; | |||
| int mlen = 0; | |||
| auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1); | |||
| if (ctx == nullptr) { | |||
| MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; | |||
| return false; | |||
| } | |||
| ret = EVP_DecryptUpdate(ctx, *plain_data, plain_len, cipher_data, cipher_len); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_DecryptUpdate failed"; | |||
| return false; | |||
| } | |||
| if (work_mode == "CBC") { | |||
| ret = EVP_DecryptFinal_ex(ctx, *plain_data + *plain_len, &mlen); | |||
| if (ret != 1) { | |||
| MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed"; | |||
| return false; | |||
| } | |||
| *plain_len += mlen; | |||
| } | |||
| delete[] iv; | |||
| delete[] cipher_data; | |||
| EVP_CIPHER_CTX_free(ctx); | |||
| return true; | |||
| } | |||
| Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, | |||
| const std::string &enc_mode) { | |||
| int64_t cur_pos = 0; | |||
| int64_t block_enc_len = 0; | |||
| int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100; | |||
| Byte *encrypt_data = new Byte[encrypt_buf_len]; | |||
| Byte *block_buf = new Byte[MAX_BLOCK_SIZE]; | |||
| Byte *block_enc_buf = new Byte[MAX_BLOCK_SIZE + 100]; | |||
| *encrypt_len = 0; | |||
| while (cur_pos < plain_len) { | |||
| int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos); | |||
| memcpy(block_buf, plain_data + cur_pos, cur_block_size); | |||
| if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) { | |||
| delete[] block_buf; | |||
| delete[] block_enc_buf; | |||
| delete[] encrypt_data; | |||
| MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; | |||
| } | |||
| memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t)); | |||
| *encrypt_len += sizeof(int32_t); | |||
| memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len); | |||
| *encrypt_len += block_enc_len; | |||
| cur_pos += cur_block_size; | |||
| } | |||
| delete[] block_buf; | |||
| delete[] block_enc_buf; | |||
| return encrypt_data; | |||
| } | |||
| Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, | |||
| const std::string &dec_mode) { | |||
| Byte *decrypt_data = nullptr; | |||
| char *block_buf = new char[MAX_BLOCK_SIZE * 2]; | |||
| char *int_buf = new char[4]; | |||
| // Byte *decrypt_block_buf = new Byte[100]; | |||
| Byte *decrypt_block_buf = nullptr; | |||
| int32_t decrypt_block_len; | |||
| std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary); | |||
| if (!fid) { | |||
| MS_LOG(ERROR) << "Open file failed"; | |||
| exit(-1); | |||
| } | |||
| fid.seekg(0, std::ios_base::end); | |||
| int64_t file_size = fid.tellg(); | |||
| fid.clear(); | |||
| fid.seekg(0); | |||
| decrypt_data = new Byte[file_size]; | |||
| *decrypt_len = 0; | |||
| while (fid.tellg() < file_size) { | |||
| fid.read(int_buf, sizeof(int32_t)); | |||
| int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf)); | |||
| if (cipher_flag != MAGIC_NUM) { | |||
| MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path | |||
| << "\"is not an encrypted file and cannot be decrypted"; | |||
| } | |||
| fid.read(int_buf, sizeof(int32_t)); | |||
| int64_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf)); | |||
| fid.read(block_buf, sizeof(char) * block_size); | |||
| if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key, | |||
| key_len, dec_mode))) { | |||
| delete[] block_buf; | |||
| delete[] int_buf; | |||
| delete[] decrypt_data; | |||
| MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; | |||
| } | |||
| memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len); | |||
| *decrypt_len += decrypt_block_len; | |||
| } | |||
| fid.close(); | |||
| delete[] block_buf; | |||
| delete[] int_buf; | |||
| return decrypt_data; | |||
| } | |||
| #endif | |||
| } // namespace crypto | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_H | |||
| #define MINDSPORE_CCSRC_CRYPTO_CRYPTO_H | |||
| #if not defined(_WIN32) | |||
| #include <openssl/aes.h> | |||
| #include <openssl/evp.h> | |||
| #include <openssl/rand.h> | |||
| #endif | |||
| #include <stdio.h> | |||
| #include <fstream> | |||
| #include <string> | |||
| #include <regex> | |||
| #include "utils/log_adapter.h" | |||
| typedef unsigned char Byte; | |||
| namespace mindspore { | |||
| namespace crypto { | |||
| const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB | |||
| const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number | |||
| Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, | |||
| const std::string &enc_mode); | |||
| Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, | |||
| const std::string &dec_mode); | |||
| bool IsCipherFile(const std::string file_path); | |||
| } // namespace crypto | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "crypto/crypto_pybind.h" | |||
| namespace mindspore { | |||
| namespace crypto { | |||
| py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode) { | |||
| int64_t encrypt_len; | |||
| char *encrypt_data; | |||
| encrypt_data = reinterpret_cast<char *>(Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len, | |||
| reinterpret_cast<Byte *>(key), key_len, enc_mode)); | |||
| return py::bytes(encrypt_data, encrypt_len); | |||
| } | |||
| py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode) { | |||
| int64_t decrypt_len; | |||
| char *decrypt_data; | |||
| decrypt_data = reinterpret_cast<char *>( | |||
| Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode)); | |||
| return py::bytes(decrypt_data, decrypt_len); | |||
| } | |||
| bool PyIsCipherFile(std::string file_path) { return IsCipherFile(file_path); } | |||
| } // namespace crypto | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H | |||
| #define MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H | |||
| #include "crypto/crypto.h" | |||
| #include <pybind11/pybind11.h> | |||
| #include <string> | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace crypto { | |||
| py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode); | |||
| py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode); | |||
| bool PyIsCipherFile(std::string file_path); | |||
| } // namespace crypto | |||
| } // namespace mindspore | |||
| #endif | |||
| @@ -28,6 +28,7 @@ | |||
| #include "utils/mpi/mpi_config.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "frontend/parallel/costmodel_context.h" | |||
| #include "crypto/crypto_pybind.h" | |||
| #ifdef ENABLE_GPU_COLLECTIVE | |||
| #include "runtime/device/gpu/distribution/collective_init.h" | |||
| #else | |||
| @@ -330,4 +331,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | |||
| .def(py::init()) | |||
| .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info."); | |||
| (void)m.def("_encrypt", &mindspore::crypto::PyEncrypt, "Encrypt the data."); | |||
| (void)m.def("_decrypt", &mindspore::crypto::PyDecrypt, "Decrypt the data."); | |||
| (void)m.def("_is_cipher_file", &mindspore::crypto::PyIsCipherFile, "Determine whether the file is encrypted"); | |||
| } | |||
| @@ -82,6 +82,10 @@ class CheckpointConfig: | |||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False. | |||
| saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation | |||
| with the network in training, the initial value of saved_network will be saved. Default: None. | |||
| enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption | |||
| is not required. Default: None. | |||
| enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption | |||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||
| Raises: | |||
| ValueError: If the input_param is None or 0. | |||
| @@ -126,7 +130,9 @@ class CheckpointConfig: | |||
| keep_checkpoint_per_n_minutes=0, | |||
| integrated_save=True, | |||
| async_save=False, | |||
| saved_network=None): | |||
| saved_network=None, | |||
| enc_key=None, | |||
| enc_mode='AES-GCM'): | |||
| if save_checkpoint_steps is not None: | |||
| save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps) | |||
| @@ -160,6 +166,8 @@ class CheckpointConfig: | |||
| self._integrated_save = Validator.check_bool(integrated_save) | |||
| self._async_save = Validator.check_bool(async_save) | |||
| self._saved_network = saved_network | |||
| self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) | |||
| self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) | |||
| @property | |||
| def save_checkpoint_steps(self): | |||
| @@ -196,6 +204,16 @@ class CheckpointConfig: | |||
| """Get the value of _saved_network""" | |||
| return self._saved_network | |||
| @property | |||
| def enc_key(self): | |||
| """Get the value of _enc_key""" | |||
| return self._enc_key | |||
| @property | |||
| def enc_mode(self): | |||
| """Get the value of _enc_mode""" | |||
| return self._enc_mode | |||
| def get_checkpoint_policy(self): | |||
| """Get the policy of checkpoint.""" | |||
| checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps, | |||
| @@ -355,7 +373,7 @@ class ModelCheckpoint(Callback): | |||
| network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network | |||
| save_checkpoint(network, cur_file, self._config.integrated_save, | |||
| self._config.async_save) | |||
| self._config.async_save, self._config.enc_key, self._config.enc_mode) | |||
| self._latest_ckpt_file_name = cur_file | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Model and parameters serialization.""" | |||
| import os | |||
| import sys | |||
| import stat | |||
| import math | |||
| @@ -40,7 +41,7 @@ from mindspore._checkparam import check_input_data, Validator | |||
| from mindspore.compression.export import quant_export | |||
| from mindspore.parallel._tensor import _load_tensor | |||
| from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices | |||
| from .._c_expression import load_mindir | |||
| from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file | |||
| tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | |||
| @@ -120,14 +121,19 @@ def _update_param(param, new_param): | |||
| param.set_data(type(param.data)(new_param.data)) | |||
| def _exec_save(ckpt_file_name, data_list): | |||
| def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): | |||
| """Execute the process of saving checkpoint into file.""" | |||
| try: | |||
| MAX_BLOCK_SIZE = 1024*1024*512 | |||
| with _ckpt_mutex: | |||
| if os.path.exists(ckpt_file_name): | |||
| os.remove(ckpt_file_name) | |||
| with open(ckpt_file_name, "ab") as f: | |||
| if enc_key is not None: | |||
| plain_data = bytes(0) | |||
| cipher_data = bytes(0) | |||
| for name, value in data_list.items(): | |||
| data_size = value[2].nbytes / 1024 | |||
| if data_size > SLICE_SIZE: | |||
| @@ -145,7 +151,19 @@ def _exec_save(ckpt_file_name, data_list): | |||
| param_tensor.tensor_type = value[1] | |||
| param_tensor.tensor_content = param_slice.tobytes() | |||
| f.write(checkpoint_list.SerializeToString()) | |||
| if enc_key is None: | |||
| f.write(checkpoint_list.SerializeToString()) | |||
| else: | |||
| plain_data += checkpoint_list.SerializeToString() | |||
| while len(plain_data) >= MAX_BLOCK_SIZE: | |||
| cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key, | |||
| len(enc_key), enc_mode) | |||
| plain_data = plain_data[MAX_BLOCK_SIZE:] | |||
| if enc_key is not None: | |||
| if plain_data: | |||
| cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode) | |||
| f.write(cipher_data) | |||
| os.chmod(ckpt_file_name, stat.S_IRUSR) | |||
| @@ -154,7 +172,7 @@ def _exec_save(ckpt_file_name, data_list): | |||
| raise e | |||
| def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False): | |||
| def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"): | |||
| """ | |||
| Saves checkpoint info to a specified file. | |||
| @@ -166,6 +184,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | |||
| integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | |||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | |||
| enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption | |||
| is not required. Default: None. | |||
| enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption | |||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||
| Raises: | |||
| TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter | |||
| @@ -176,6 +198,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | |||
| integrated_save = Validator.check_bool(integrated_save) | |||
| async_save = Validator.check_bool(async_save) | |||
| enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) | |||
| enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) | |||
| logger.info("Execute the process of saving checkpoint files.") | |||
| @@ -218,10 +242,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| data_list[key].append(data) | |||
| if async_save: | |||
| thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt") | |||
| thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt") | |||
| thr.start() | |||
| else: | |||
| _exec_save(ckpt_file_name, data_list) | |||
| _exec_save(ckpt_file_name, data_list, enc_key, enc_mode) | |||
| logger.info("Saving checkpoint process is finished.") | |||
| @@ -278,7 +302,7 @@ def load(file_name): | |||
| return graph | |||
| def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None): | |||
| def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"): | |||
| """ | |||
| Loads checkpoint info from a specified file. | |||
| @@ -289,6 +313,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||
| in the param_dict into net with the same suffix. Default: False | |||
| filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix | |||
| will not be loaded. Default: None. | |||
| dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption | |||
| is not required. Default: None. | |||
| dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption | |||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||
| Returns: | |||
| Dict, key is parameter name, value is a Parameter. | |||
| @@ -303,15 +331,25 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||
| >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | |||
| """ | |||
| ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix) | |||
| dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) | |||
| dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) | |||
| logger.info("Execute the process of loading checkpoint files.") | |||
| checkpoint_list = Checkpoint() | |||
| try: | |||
| with open(ckpt_file_name, "rb") as f: | |||
| pb_content = f.read() | |||
| if dec_key is None: | |||
| with open(ckpt_file_name, "rb") as f: | |||
| pb_content = f.read() | |||
| else: | |||
| pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) | |||
| checkpoint_list.ParseFromString(pb_content) | |||
| except BaseException as e: | |||
| logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name) | |||
| if _is_cipher_file(ckpt_file_name): | |||
| logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the " | |||
| "dec_key.", ckpt_file_name) | |||
| else: | |||
| logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", \ | |||
| ckpt_file_name) | |||
| raise ValueError(e.__str__()) | |||
| parameter_dict = {} | |||
| @@ -1075,7 +1113,7 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||
| return merged_parameter | |||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None): | |||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, dec_key=None, dec_mode='AES-GCM'): | |||
| """ | |||
| Load checkpoint into net for distributed predication. | |||
| @@ -1088,6 +1126,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||
| elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | |||
| it means that the predication process just uses single device. | |||
| Default: None. | |||
| dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption | |||
| is not required. Default: None. | |||
| dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption | |||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||
| Raises: | |||
| TypeError: The type of inputs do not match the requirements. | |||
| @@ -1106,6 +1148,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||
| f"dev_matrix (list[int]), tensor_map (list[int]), " | |||
| f"param_split_shape (list[int]) and field_size (zero).") | |||
| dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) | |||
| dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) | |||
| train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") | |||
| _train_strategy = build_searched_strategy(train_strategy_filename) | |||
| train_strategy = _convert_to_list(_train_strategy) | |||
| @@ -1128,7 +1173,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||
| param_rank = rank_list[param.name][0] | |||
| skip_merge_split = rank_list[param.name][1] | |||
| for rank in param_rank: | |||
| sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name] | |||
| sliced_param = load_checkpoint(checkpoint_filenames[rank], dec_key=dec_key, dec_mode=dec_mode)[param.name] | |||
| sliced_params.append(sliced_param) | |||
| if skip_merge_split: | |||
| split_param = sliced_params[0] | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """test callback function.""" | |||
| import os | |||
| import platform | |||
| import stat | |||
| from unittest import mock | |||
| @@ -246,6 +247,43 @@ def test_checkpoint_save_ckpt_seconds(): | |||
| ckpt_cb2.step_end(run_context) | |||
| def test_checkpoint_save_ckpt_with_encryption(): | |||
| """Test checkpoint save ckpt with encryption.""" | |||
| train_config = CheckpointConfig( | |||
| save_checkpoint_steps=16, | |||
| save_checkpoint_seconds=0, | |||
| keep_checkpoint_max=5, | |||
| keep_checkpoint_per_n_minutes=0, | |||
| enc_key=os.urandom(16), | |||
| enc_mode="AES-GCM") | |||
| ckpt_cb = ModelCheckpoint(config=train_config) | |||
| cb_params = _InternalCallbackParam() | |||
| net = Net() | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| network_ = WithLossCell(net, loss) | |||
| _train_network = TrainOneStepCell(network_, optim) | |||
| cb_params.train_network = _train_network | |||
| cb_params.epoch_num = 10 | |||
| cb_params.cur_epoch_num = 5 | |||
| cb_params.cur_step_num = 160 | |||
| cb_params.batch_num = 32 | |||
| run_context = RunContext(cb_params) | |||
| ckpt_cb.begin(run_context) | |||
| ckpt_cb.step_end(run_context) | |||
| ckpt_cb2 = ModelCheckpoint(config=train_config) | |||
| cb_params.cur_epoch_num = 1 | |||
| cb_params.cur_step_num = 15 | |||
| if platform.system().lower() == "windows": | |||
| with pytest.raises(NotImplementedError): | |||
| ckpt_cb2.begin(run_context) | |||
| ckpt_cb2.step_end(run_context) | |||
| else: | |||
| ckpt_cb2.begin(run_context) | |||
| ckpt_cb2.step_end(run_context) | |||
| def test_CallbackManager(): | |||
| """TestCallbackManager.""" | |||
| ck_obj = ModelCheckpoint() | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """ut for model serialize(save/load)""" | |||
| import os | |||
| import platform | |||
| import stat | |||
| import time | |||
| @@ -299,6 +300,30 @@ def test_load_checkpoint_empty_file(): | |||
| load_checkpoint("empty.ckpt") | |||
| def test_save_and_load_checkpoint_for_network_with_encryption(): | |||
| """ test save and checkpoint for network with encryption""" | |||
| net = Net() | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True) | |||
| opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024) | |||
| loss_net = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(loss_net, opt) | |||
| key = os.urandom(16) | |||
| mode = "AES-GCM" | |||
| ckpt_path = "./encrypt_ckpt.ckpt" | |||
| if platform.system().lower() == "windows": | |||
| with pytest.raises(NotImplementedError): | |||
| save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode) | |||
| param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM") | |||
| load_param_into_net(net, param_dict) | |||
| else: | |||
| save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode) | |||
| param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM") | |||
| load_param_into_net(net, param_dict) | |||
| if os.path.exists(ckpt_path): | |||
| os.remove(ckpt_path) | |||
| class MYNET(nn.Cell): | |||
| """ NET definition """ | |||