| @@ -5,10 +5,14 @@ else() | |||||
| set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz") | set(REQ_URL "https://github.com/openssl/openssl/archive/refs/tags/OpenSSL_1_1_1k.tar.gz") | ||||
| set(MD5 "bdd51a68ad74618dd2519da8e0bcc759") | set(MD5 "bdd51a68ad74618dd2519da8e0bcc759") | ||||
| endif() | endif() | ||||
| mindspore_add_pkg(openssl | |||||
| VER 1.1.0 | |||||
| LIBS ssl crypto | |||||
| URL ${REQ_URL} | |||||
| MD5 ${MD5} | |||||
| CONFIGURE_COMMAND ./config no-zlib no-shared) | |||||
| include_directories(${openssl_INC}) | |||||
| if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") | |||||
| mindspore_add_pkg(openssl | |||||
| VER 1.1.0 | |||||
| LIBS ssl crypto | |||||
| URL ${REQ_URL} | |||||
| MD5 ${MD5} | |||||
| CONFIGURE_COMMAND ./config no-zlib no-shared) | |||||
| include_directories(${openssl_INC}) | |||||
| add_library(mindspore::ssl ALIAS openssl::ssl) | |||||
| add_library(mindspore::crypto ALIAS openssl::crypto) | |||||
| endif() | |||||
| @@ -226,6 +226,7 @@ set(SUB_COMP | |||||
| pipeline/jit | pipeline/jit | ||||
| pipeline/pynative | pipeline/pynative | ||||
| common debug pybind_api utils vm profiler ps | common debug pybind_api utils vm profiler ps | ||||
| crypto | |||||
| ) | ) | ||||
| foreach(_comp ${SUB_COMP}) | foreach(_comp ${SUB_COMP}) | ||||
| @@ -0,0 +1,6 @@ | |||||
| file(GLOB_RECURSE _CRYPTO_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | |||||
| add_library(_mindspore_crypto_obj OBJECT ${_CRYPTO_SRC_FILES}) | |||||
| if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") | |||||
| target_link_libraries(_mindspore_crypto_obj mindspore::crypto) | |||||
| endif() | |||||
| @@ -0,0 +1,347 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "crypto/crypto.h" | |||||
| namespace mindspore { | |||||
| namespace crypto { | |||||
| int64_t Min(int64_t a, int64_t b) { return a < b ? a : b; } | |||||
| Byte *intToByte(const int32_t &n) { | |||||
| Byte *byte = new Byte[4]; | |||||
| memset(byte, 0, sizeof(Byte) * 4); | |||||
| byte[0] = (Byte)(0xFF & n); | |||||
| byte[1] = (Byte)((0xFF00 & n) >> 8); | |||||
| byte[2] = (Byte)((0xFF0000 & n) >> 16); | |||||
| byte[3] = (Byte)((0xFF000000 & n) >> 24); | |||||
| return byte; | |||||
| } | |||||
| int32_t ByteToint(const Byte *byteArray) { | |||||
| int32_t res = byteArray[0] & 0xFF; | |||||
| res |= ((byteArray[1] << 8) & 0xFF00); | |||||
| res |= ((byteArray[2] << 16) & 0xFF0000); | |||||
| res += ((byteArray[3] << 24) & 0xFF000000); | |||||
| return res; | |||||
| } | |||||
| bool IsCipherFile(std::string file_path) { | |||||
| char *int_buf = new char[4]; | |||||
| int flag = 0; | |||||
| std::ifstream fid(file_path, std::ios::in | std::ios::binary); | |||||
| if (!fid) { | |||||
| MS_LOG(ERROR) << "Open file failed"; | |||||
| exit(-1); | |||||
| } | |||||
| fid.read(int_buf, sizeof(int32_t)); | |||||
| fid.close(); | |||||
| flag = ByteToint(reinterpret_cast<Byte *>(int_buf)); | |||||
| delete[] int_buf; | |||||
| return flag == MAGIC_NUM; | |||||
| } | |||||
| #if defined(_WIN32) | |||||
| Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, | |||||
| const std::string &enc_mode) { | |||||
| MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform."; | |||||
| } | |||||
| Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, | |||||
| const std::string &dec_mode) { | |||||
| MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform."; | |||||
| } | |||||
| #else | |||||
| bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len, | |||||
| Byte **cipher_data, int32_t *cipher_len) { | |||||
| // Encrypt data is organized in order to iv_len, iv, cipher_len, cipher_data | |||||
| Byte buf[4]; | |||||
| memcpy(buf, encrypt_data, 4); | |||||
| *iv_len = ByteToint(buf); | |||||
| memcpy(buf, encrypt_data + *iv_len + 4, 4); | |||||
| *cipher_len = ByteToint(buf); | |||||
| if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) { | |||||
| MS_LOG(ERROR) << "Failed to parse encrypt data."; | |||||
| return false; | |||||
| } | |||||
| *iv = new Byte[*iv_len]; | |||||
| memcpy(*iv, encrypt_data + 4, *iv_len); | |||||
| *cipher_data = new Byte[*cipher_len]; | |||||
| memcpy(*cipher_data, encrypt_data + *iv_len + 8, *cipher_len); | |||||
| return true; | |||||
| } | |||||
| bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) { | |||||
| std::smatch results; | |||||
| std::regex re("([A-Z]{3})-([A-Z]{3})"); | |||||
| if (!std::regex_match(mode.c_str(), re)) { | |||||
| MS_LOG(ERROR) << "Mode " << mode << " is invalid."; | |||||
| return false; | |||||
| } | |||||
| std::regex_search(mode, results, re); | |||||
| *alg_mode = results[1]; | |||||
| *work_mode = results[2]; | |||||
| return true; | |||||
| } | |||||
| EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv, | |||||
| int flag) { | |||||
| int ret = 0; | |||||
| EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); | |||||
| if (work_mode != "GCM" && work_mode != "CBC") { | |||||
| MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid."; | |||||
| return nullptr; | |||||
| } | |||||
| const EVP_CIPHER *(*funcPtr)() = nullptr; | |||||
| if (work_mode == "GCM") { | |||||
| switch (key_len) { | |||||
| case 16: | |||||
| funcPtr = EVP_aes_128_gcm; | |||||
| break; | |||||
| case 24: | |||||
| funcPtr = EVP_aes_192_gcm; | |||||
| break; | |||||
| case 32: | |||||
| funcPtr = EVP_aes_256_gcm; | |||||
| break; | |||||
| default: | |||||
| MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << "."; | |||||
| } | |||||
| } else if (work_mode == "CBC") { | |||||
| switch (key_len) { | |||||
| case 16: | |||||
| funcPtr = EVP_aes_128_cbc; | |||||
| break; | |||||
| case 24: | |||||
| funcPtr = EVP_aes_192_cbc; | |||||
| break; | |||||
| case 32: | |||||
| funcPtr = EVP_aes_256_cbc; | |||||
| break; | |||||
| default: | |||||
| MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << "."; | |||||
| } | |||||
| } | |||||
| if (flag == 0) { | |||||
| ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv); | |||||
| } else if (flag == 1) { | |||||
| ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv); | |||||
| } | |||||
| if (ret != 1) { | |||||
| MS_LOG(ERROR) << "EVP_EncryptInit_ex failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1); | |||||
| return ctx; | |||||
| } | |||||
| bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key, | |||||
| const int32_t key_len, const std::string &enc_mode) { | |||||
| // Encrypted according to enc_key and enc_mode, the format of the returned encrypted data block is "total length + | |||||
| // iv length + iv + plain text length + cipher text length + cipher text" | |||||
| int32_t cipher_len = 0; // cipher length | |||||
| int32_t iv_len = AES_BLOCK_SIZE; | |||||
| Byte *iv = new Byte[iv_len]; | |||||
| RAND_bytes(iv, sizeof(Byte) * iv_len); | |||||
| Byte *iv_cpy = new Byte[16]; | |||||
| memcpy(iv_cpy, iv, 16); | |||||
| // set the encryption length | |||||
| int32_t ret = 0; | |||||
| int32_t flen = 0; | |||||
| std::string alg_mode; | |||||
| std::string work_mode; | |||||
| if (!ParseMode(enc_mode, &alg_mode, &work_mode)) { | |||||
| return false; | |||||
| } | |||||
| auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0); | |||||
| if (ctx == nullptr) { | |||||
| MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; | |||||
| return false; | |||||
| } | |||||
| Byte *cipher_data; | |||||
| cipher_data = new Byte[plain_len + 16]; | |||||
| ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len); | |||||
| if (ret != 1) { | |||||
| MS_LOG(ERROR) << "EVP_EncryptUpdate failed"; | |||||
| delete[] cipher_data; | |||||
| return false; | |||||
| } | |||||
| if (work_mode == "CBC") { | |||||
| EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen); | |||||
| cipher_len += flen; | |||||
| } | |||||
| EVP_CIPHER_CTX_free(ctx); | |||||
| int64_t cur = 0; | |||||
| *encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len; // 按iv长度、iv、明文长度、密文长度、密文进行拼接 | |||||
| memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4); | |||||
| cur += 4; | |||||
| memcpy(encrypt_data + cur, intToByte(iv_len), 4); | |||||
| cur += 4; | |||||
| memcpy(encrypt_data + cur, iv_cpy, iv_len); | |||||
| cur += iv_len; | |||||
| memcpy(encrypt_data + cur, intToByte(cipher_len), 4); | |||||
| cur += 4; | |||||
| memcpy(encrypt_data + cur, cipher_data, cipher_len); | |||||
| *encrypt_data_len += 4; | |||||
| delete[] cipher_data; | |||||
| return true; | |||||
| } | |||||
| bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key, | |||||
| const int32_t key_len, const std::string &dec_mode) { | |||||
| // Decrypt according to dec_key and dec_mode, the format of the encrypted data block is "iv length + iv + | |||||
| // plain text data length + cipher text data length + cipher text data" | |||||
| std::string alg_mode; | |||||
| std::string work_mode; | |||||
| if (!ParseMode(dec_mode, &alg_mode, &work_mode)) { | |||||
| return false; | |||||
| } | |||||
| // 解析加密数据 | |||||
| int32_t iv_len = 0; | |||||
| int32_t cipher_len = 0; | |||||
| Byte *iv = NULL; | |||||
| Byte *cipher_data = NULL; | |||||
| if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) { | |||||
| return false; | |||||
| } | |||||
| *plain_data = new Byte[cipher_len + 16]; | |||||
| if (*plain_data == NULL) { | |||||
| MS_LOG(ERROR) << "Unable to allocate memory for decrypt_string."; | |||||
| return false; | |||||
| } | |||||
| // 解密密文 | |||||
| int ret = 0; | |||||
| int mlen = 0; | |||||
| auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1); | |||||
| if (ctx == nullptr) { | |||||
| MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX."; | |||||
| return false; | |||||
| } | |||||
| ret = EVP_DecryptUpdate(ctx, *plain_data, plain_len, cipher_data, cipher_len); | |||||
| if (ret != 1) { | |||||
| MS_LOG(ERROR) << "EVP_DecryptUpdate failed"; | |||||
| return false; | |||||
| } | |||||
| if (work_mode == "CBC") { | |||||
| ret = EVP_DecryptFinal_ex(ctx, *plain_data + *plain_len, &mlen); | |||||
| if (ret != 1) { | |||||
| MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed"; | |||||
| return false; | |||||
| } | |||||
| *plain_len += mlen; | |||||
| } | |||||
| delete[] iv; | |||||
| delete[] cipher_data; | |||||
| EVP_CIPHER_CTX_free(ctx); | |||||
| return true; | |||||
| } | |||||
| Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, | |||||
| const std::string &enc_mode) { | |||||
| int64_t cur_pos = 0; | |||||
| int64_t block_enc_len = 0; | |||||
| int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100; | |||||
| Byte *encrypt_data = new Byte[encrypt_buf_len]; | |||||
| Byte *block_buf = new Byte[MAX_BLOCK_SIZE]; | |||||
| Byte *block_enc_buf = new Byte[MAX_BLOCK_SIZE + 100]; | |||||
| *encrypt_len = 0; | |||||
| while (cur_pos < plain_len) { | |||||
| int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos); | |||||
| memcpy(block_buf, plain_data + cur_pos, cur_block_size); | |||||
| if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) { | |||||
| delete[] block_buf; | |||||
| delete[] block_enc_buf; | |||||
| delete[] encrypt_data; | |||||
| MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid."; | |||||
| } | |||||
| memcpy(encrypt_data + *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t)); | |||||
| *encrypt_len += sizeof(int32_t); | |||||
| memcpy(encrypt_data + *encrypt_len, block_enc_buf, block_enc_len); | |||||
| *encrypt_len += block_enc_len; | |||||
| cur_pos += cur_block_size; | |||||
| } | |||||
| delete[] block_buf; | |||||
| delete[] block_enc_buf; | |||||
| return encrypt_data; | |||||
| } | |||||
| Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, | |||||
| const std::string &dec_mode) { | |||||
| Byte *decrypt_data = nullptr; | |||||
| char *block_buf = new char[MAX_BLOCK_SIZE * 2]; | |||||
| char *int_buf = new char[4]; | |||||
| // Byte *decrypt_block_buf = new Byte[100]; | |||||
| Byte *decrypt_block_buf = nullptr; | |||||
| int32_t decrypt_block_len; | |||||
| std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary); | |||||
| if (!fid) { | |||||
| MS_LOG(ERROR) << "Open file failed"; | |||||
| exit(-1); | |||||
| } | |||||
| fid.seekg(0, std::ios_base::end); | |||||
| int64_t file_size = fid.tellg(); | |||||
| fid.clear(); | |||||
| fid.seekg(0); | |||||
| decrypt_data = new Byte[file_size]; | |||||
| *decrypt_len = 0; | |||||
| while (fid.tellg() < file_size) { | |||||
| fid.read(int_buf, sizeof(int32_t)); | |||||
| int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf)); | |||||
| if (cipher_flag != MAGIC_NUM) { | |||||
| MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path | |||||
| << "\"is not an encrypted file and cannot be decrypted"; | |||||
| } | |||||
| fid.read(int_buf, sizeof(int32_t)); | |||||
| int64_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf)); | |||||
| fid.read(block_buf, sizeof(char) * block_size); | |||||
| if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key, | |||||
| key_len, dec_mode))) { | |||||
| delete[] block_buf; | |||||
| delete[] int_buf; | |||||
| delete[] decrypt_data; | |||||
| MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid"; | |||||
| } | |||||
| memcpy(decrypt_data, decrypt_block_buf, decrypt_block_len); | |||||
| *decrypt_len += decrypt_block_len; | |||||
| } | |||||
| fid.close(); | |||||
| delete[] block_buf; | |||||
| delete[] int_buf; | |||||
| return decrypt_data; | |||||
| } | |||||
| #endif | |||||
| } // namespace crypto | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_H | |||||
| #define MINDSPORE_CCSRC_CRYPTO_CRYPTO_H | |||||
| #if not defined(_WIN32) | |||||
| #include <openssl/aes.h> | |||||
| #include <openssl/evp.h> | |||||
| #include <openssl/rand.h> | |||||
| #endif | |||||
| #include <stdio.h> | |||||
| #include <fstream> | |||||
| #include <string> | |||||
| #include <regex> | |||||
| #include "utils/log_adapter.h" | |||||
| typedef unsigned char Byte; | |||||
| namespace mindspore { | |||||
| namespace crypto { | |||||
| const int MAX_BLOCK_SIZE = 512 * 1024 * 1024; // Maximum ciphertext segment 512MB | |||||
| const unsigned int MAGIC_NUM = 0x7F3A5ED8; // Magic number | |||||
| Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len, | |||||
| const std::string &enc_mode); | |||||
| Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len, | |||||
| const std::string &dec_mode); | |||||
| bool IsCipherFile(const std::string file_path); | |||||
| } // namespace crypto | |||||
| } // namespace mindspore | |||||
| #endif | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "crypto/crypto_pybind.h" | |||||
| namespace mindspore { | |||||
| namespace crypto { | |||||
| py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode) { | |||||
| int64_t encrypt_len; | |||||
| char *encrypt_data; | |||||
| encrypt_data = reinterpret_cast<char *>(Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len, | |||||
| reinterpret_cast<Byte *>(key), key_len, enc_mode)); | |||||
| return py::bytes(encrypt_data, encrypt_len); | |||||
| } | |||||
| py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode) { | |||||
| int64_t decrypt_len; | |||||
| char *decrypt_data; | |||||
| decrypt_data = reinterpret_cast<char *>( | |||||
| Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode)); | |||||
| return py::bytes(decrypt_data, decrypt_len); | |||||
| } | |||||
| bool PyIsCipherFile(std::string file_path) { return IsCipherFile(file_path); } | |||||
| } // namespace crypto | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H | |||||
| #define MINDSPORE_CCSRC_CRYPTO_CRYPTO_PYBIND_H | |||||
| #include "crypto/crypto.h" | |||||
| #include <pybind11/pybind11.h> | |||||
| #include <string> | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | |||||
| namespace crypto { | |||||
| py::bytes PyEncrypt(char *plain_data, const int64_t plain_len, char *key, const int32_t key_len, std::string enc_mode); | |||||
| py::bytes PyDecrypt(std::string encrypt_data_path, char *key, const int32_t key_len, std::string dec_mode); | |||||
| bool PyIsCipherFile(std::string file_path); | |||||
| } // namespace crypto | |||||
| } // namespace mindspore | |||||
| #endif | |||||
| @@ -28,6 +28,7 @@ | |||||
| #include "utils/mpi/mpi_config.h" | #include "utils/mpi/mpi_config.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #include "frontend/parallel/costmodel_context.h" | #include "frontend/parallel/costmodel_context.h" | ||||
| #include "crypto/crypto_pybind.h" | |||||
| #ifdef ENABLE_GPU_COLLECTIVE | #ifdef ENABLE_GPU_COLLECTIVE | ||||
| #include "runtime/device/gpu/distribution/collective_init.h" | #include "runtime/device/gpu/distribution/collective_init.h" | ||||
| #else | #else | ||||
| @@ -330,4 +331,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info."); | .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info."); | ||||
| (void)m.def("_encrypt", &mindspore::crypto::PyEncrypt, "Encrypt the data."); | |||||
| (void)m.def("_decrypt", &mindspore::crypto::PyDecrypt, "Decrypt the data."); | |||||
| (void)m.def("_is_cipher_file", &mindspore::crypto::PyIsCipherFile, "Determine whether the file is encrypted"); | |||||
| } | } | ||||
| @@ -82,6 +82,10 @@ class CheckpointConfig: | |||||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False. | async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False. | ||||
| saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation | saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation | ||||
| with the network in training, the initial value of saved_network will be saved. Default: None. | with the network in training, the initial value of saved_network will be saved. Default: None. | ||||
| enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption | |||||
| is not required. Default: None. | |||||
| enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption | |||||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||||
| Raises: | Raises: | ||||
| ValueError: If the input_param is None or 0. | ValueError: If the input_param is None or 0. | ||||
| @@ -126,7 +130,9 @@ class CheckpointConfig: | |||||
| keep_checkpoint_per_n_minutes=0, | keep_checkpoint_per_n_minutes=0, | ||||
| integrated_save=True, | integrated_save=True, | ||||
| async_save=False, | async_save=False, | ||||
| saved_network=None): | |||||
| saved_network=None, | |||||
| enc_key=None, | |||||
| enc_mode='AES-GCM'): | |||||
| if save_checkpoint_steps is not None: | if save_checkpoint_steps is not None: | ||||
| save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps) | save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps) | ||||
| @@ -160,6 +166,8 @@ class CheckpointConfig: | |||||
| self._integrated_save = Validator.check_bool(integrated_save) | self._integrated_save = Validator.check_bool(integrated_save) | ||||
| self._async_save = Validator.check_bool(async_save) | self._async_save = Validator.check_bool(async_save) | ||||
| self._saved_network = saved_network | self._saved_network = saved_network | ||||
| self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) | |||||
| self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) | |||||
| @property | @property | ||||
| def save_checkpoint_steps(self): | def save_checkpoint_steps(self): | ||||
| @@ -196,6 +204,16 @@ class CheckpointConfig: | |||||
| """Get the value of _saved_network""" | """Get the value of _saved_network""" | ||||
| return self._saved_network | return self._saved_network | ||||
| @property | |||||
| def enc_key(self): | |||||
| """Get the value of _enc_key""" | |||||
| return self._enc_key | |||||
| @property | |||||
| def enc_mode(self): | |||||
| """Get the value of _enc_mode""" | |||||
| return self._enc_mode | |||||
| def get_checkpoint_policy(self): | def get_checkpoint_policy(self): | ||||
| """Get the policy of checkpoint.""" | """Get the policy of checkpoint.""" | ||||
| checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps, | checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps, | ||||
| @@ -355,7 +373,7 @@ class ModelCheckpoint(Callback): | |||||
| network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network | network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network | ||||
| save_checkpoint(network, cur_file, self._config.integrated_save, | save_checkpoint(network, cur_file, self._config.integrated_save, | ||||
| self._config.async_save) | |||||
| self._config.async_save, self._config.enc_key, self._config.enc_mode) | |||||
| self._latest_ckpt_file_name = cur_file | self._latest_ckpt_file_name = cur_file | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Model and parameters serialization.""" | """Model and parameters serialization.""" | ||||
| import os | import os | ||||
| import sys | import sys | ||||
| import stat | import stat | ||||
| import math | import math | ||||
| @@ -40,7 +41,7 @@ from mindspore._checkparam import check_input_data, Validator | |||||
| from mindspore.compression.export import quant_export | from mindspore.compression.export import quant_export | ||||
| from mindspore.parallel._tensor import _load_tensor | from mindspore.parallel._tensor import _load_tensor | ||||
| from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices | from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices | ||||
| from .._c_expression import load_mindir | |||||
| from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file | |||||
| tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | ||||
| @@ -120,14 +121,19 @@ def _update_param(param, new_param): | |||||
| param.set_data(type(param.data)(new_param.data)) | param.set_data(type(param.data)(new_param.data)) | ||||
| def _exec_save(ckpt_file_name, data_list): | |||||
| def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): | |||||
| """Execute the process of saving checkpoint into file.""" | """Execute the process of saving checkpoint into file.""" | ||||
| try: | try: | ||||
| MAX_BLOCK_SIZE = 1024*1024*512 | |||||
| with _ckpt_mutex: | with _ckpt_mutex: | ||||
| if os.path.exists(ckpt_file_name): | if os.path.exists(ckpt_file_name): | ||||
| os.remove(ckpt_file_name) | os.remove(ckpt_file_name) | ||||
| with open(ckpt_file_name, "ab") as f: | with open(ckpt_file_name, "ab") as f: | ||||
| if enc_key is not None: | |||||
| plain_data = bytes(0) | |||||
| cipher_data = bytes(0) | |||||
| for name, value in data_list.items(): | for name, value in data_list.items(): | ||||
| data_size = value[2].nbytes / 1024 | data_size = value[2].nbytes / 1024 | ||||
| if data_size > SLICE_SIZE: | if data_size > SLICE_SIZE: | ||||
| @@ -145,7 +151,19 @@ def _exec_save(ckpt_file_name, data_list): | |||||
| param_tensor.tensor_type = value[1] | param_tensor.tensor_type = value[1] | ||||
| param_tensor.tensor_content = param_slice.tobytes() | param_tensor.tensor_content = param_slice.tobytes() | ||||
| f.write(checkpoint_list.SerializeToString()) | |||||
| if enc_key is None: | |||||
| f.write(checkpoint_list.SerializeToString()) | |||||
| else: | |||||
| plain_data += checkpoint_list.SerializeToString() | |||||
| while len(plain_data) >= MAX_BLOCK_SIZE: | |||||
| cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key, | |||||
| len(enc_key), enc_mode) | |||||
| plain_data = plain_data[MAX_BLOCK_SIZE:] | |||||
| if enc_key is not None: | |||||
| if plain_data: | |||||
| cipher_data += _encrypt(plain_data, len(plain_data), enc_key, len(enc_key), enc_mode) | |||||
| f.write(cipher_data) | |||||
| os.chmod(ckpt_file_name, stat.S_IRUSR) | os.chmod(ckpt_file_name, stat.S_IRUSR) | ||||
| @@ -154,7 +172,7 @@ def _exec_save(ckpt_file_name, data_list): | |||||
| raise e | raise e | ||||
| def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False): | |||||
| def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"): | |||||
| """ | """ | ||||
| Saves checkpoint info to a specified file. | Saves checkpoint info to a specified file. | ||||
| @@ -166,6 +184,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||||
| ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | ||||
| integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | ||||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | ||||
| enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption | |||||
| is not required. Default: None. | |||||
| enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption | |||||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||||
| Raises: | Raises: | ||||
| TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter | TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter | ||||
| @@ -176,6 +198,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||||
| raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | ||||
| integrated_save = Validator.check_bool(integrated_save) | integrated_save = Validator.check_bool(integrated_save) | ||||
| async_save = Validator.check_bool(async_save) | async_save = Validator.check_bool(async_save) | ||||
| enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes)) | |||||
| enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str) | |||||
| logger.info("Execute the process of saving checkpoint files.") | logger.info("Execute the process of saving checkpoint files.") | ||||
| @@ -218,10 +242,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||||
| data_list[key].append(data) | data_list[key].append(data) | ||||
| if async_save: | if async_save: | ||||
| thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt") | |||||
| thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list, enc_key, enc_mode), name="asyn_save_ckpt") | |||||
| thr.start() | thr.start() | ||||
| else: | else: | ||||
| _exec_save(ckpt_file_name, data_list) | |||||
| _exec_save(ckpt_file_name, data_list, enc_key, enc_mode) | |||||
| logger.info("Saving checkpoint process is finished.") | logger.info("Saving checkpoint process is finished.") | ||||
| @@ -278,7 +302,7 @@ def load(file_name): | |||||
| return graph | return graph | ||||
| def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None): | |||||
| def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"): | |||||
| """ | """ | ||||
| Loads checkpoint info from a specified file. | Loads checkpoint info from a specified file. | ||||
| @@ -289,6 +313,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||||
| in the param_dict into net with the same suffix. Default: False | in the param_dict into net with the same suffix. Default: False | ||||
| filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix | filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix | ||||
| will not be loaded. Default: None. | will not be loaded. Default: None. | ||||
| dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption | |||||
| is not required. Default: None. | |||||
| dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption | |||||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||||
| Returns: | Returns: | ||||
| Dict, key is parameter name, value is a Parameter. | Dict, key is parameter name, value is a Parameter. | ||||
| @@ -303,15 +331,25 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||||
| >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | ||||
| """ | """ | ||||
| ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix) | ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix) | ||||
| dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) | |||||
| dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) | |||||
| logger.info("Execute the process of loading checkpoint files.") | logger.info("Execute the process of loading checkpoint files.") | ||||
| checkpoint_list = Checkpoint() | checkpoint_list = Checkpoint() | ||||
| try: | try: | ||||
| with open(ckpt_file_name, "rb") as f: | |||||
| pb_content = f.read() | |||||
| if dec_key is None: | |||||
| with open(ckpt_file_name, "rb") as f: | |||||
| pb_content = f.read() | |||||
| else: | |||||
| pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) | |||||
| checkpoint_list.ParseFromString(pb_content) | checkpoint_list.ParseFromString(pb_content) | ||||
| except BaseException as e: | except BaseException as e: | ||||
| logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name) | |||||
| if _is_cipher_file(ckpt_file_name): | |||||
| logger.error("Failed to read the checkpoint file `%s`. The file may be encrypted, please pass in the " | |||||
| "dec_key.", ckpt_file_name) | |||||
| else: | |||||
| logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", \ | |||||
| ckpt_file_name) | |||||
| raise ValueError(e.__str__()) | raise ValueError(e.__str__()) | ||||
| parameter_dict = {} | parameter_dict = {} | ||||
| @@ -1075,7 +1113,7 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| return merged_parameter | return merged_parameter | ||||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None): | |||||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, dec_key=None, dec_mode='AES-GCM'): | |||||
| """ | """ | ||||
| Load checkpoint into net for distributed predication. | Load checkpoint into net for distributed predication. | ||||
| @@ -1088,6 +1126,10 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||||
| elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | ||||
| it means that the predication process just uses single device. | it means that the predication process just uses single device. | ||||
| Default: None. | Default: None. | ||||
| dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption | |||||
| is not required. Default: None. | |||||
| dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption | |||||
| mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. | |||||
| Raises: | Raises: | ||||
| TypeError: The type of inputs do not match the requirements. | TypeError: The type of inputs do not match the requirements. | ||||
| @@ -1106,6 +1148,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||||
| f"dev_matrix (list[int]), tensor_map (list[int]), " | f"dev_matrix (list[int]), tensor_map (list[int]), " | ||||
| f"param_split_shape (list[int]) and field_size (zero).") | f"param_split_shape (list[int]) and field_size (zero).") | ||||
| dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) | |||||
| dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) | |||||
| train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") | train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") | ||||
| _train_strategy = build_searched_strategy(train_strategy_filename) | _train_strategy = build_searched_strategy(train_strategy_filename) | ||||
| train_strategy = _convert_to_list(_train_strategy) | train_strategy = _convert_to_list(_train_strategy) | ||||
| @@ -1128,7 +1173,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||||
| param_rank = rank_list[param.name][0] | param_rank = rank_list[param.name][0] | ||||
| skip_merge_split = rank_list[param.name][1] | skip_merge_split = rank_list[param.name][1] | ||||
| for rank in param_rank: | for rank in param_rank: | ||||
| sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name] | |||||
| sliced_param = load_checkpoint(checkpoint_filenames[rank], dec_key=dec_key, dec_mode=dec_mode)[param.name] | |||||
| sliced_params.append(sliced_param) | sliced_params.append(sliced_param) | ||||
| if skip_merge_split: | if skip_merge_split: | ||||
| split_param = sliced_params[0] | split_param = sliced_params[0] | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """test callback function.""" | """test callback function.""" | ||||
| import os | import os | ||||
| import platform | |||||
| import stat | import stat | ||||
| from unittest import mock | from unittest import mock | ||||
| @@ -246,6 +247,43 @@ def test_checkpoint_save_ckpt_seconds(): | |||||
| ckpt_cb2.step_end(run_context) | ckpt_cb2.step_end(run_context) | ||||
| def test_checkpoint_save_ckpt_with_encryption(): | |||||
| """Test checkpoint save ckpt with encryption.""" | |||||
| train_config = CheckpointConfig( | |||||
| save_checkpoint_steps=16, | |||||
| save_checkpoint_seconds=0, | |||||
| keep_checkpoint_max=5, | |||||
| keep_checkpoint_per_n_minutes=0, | |||||
| enc_key=os.urandom(16), | |||||
| enc_mode="AES-GCM") | |||||
| ckpt_cb = ModelCheckpoint(config=train_config) | |||||
| cb_params = _InternalCallbackParam() | |||||
| net = Net() | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||||
| optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| network_ = WithLossCell(net, loss) | |||||
| _train_network = TrainOneStepCell(network_, optim) | |||||
| cb_params.train_network = _train_network | |||||
| cb_params.epoch_num = 10 | |||||
| cb_params.cur_epoch_num = 5 | |||||
| cb_params.cur_step_num = 160 | |||||
| cb_params.batch_num = 32 | |||||
| run_context = RunContext(cb_params) | |||||
| ckpt_cb.begin(run_context) | |||||
| ckpt_cb.step_end(run_context) | |||||
| ckpt_cb2 = ModelCheckpoint(config=train_config) | |||||
| cb_params.cur_epoch_num = 1 | |||||
| cb_params.cur_step_num = 15 | |||||
| if platform.system().lower() == "windows": | |||||
| with pytest.raises(NotImplementedError): | |||||
| ckpt_cb2.begin(run_context) | |||||
| ckpt_cb2.step_end(run_context) | |||||
| else: | |||||
| ckpt_cb2.begin(run_context) | |||||
| ckpt_cb2.step_end(run_context) | |||||
| def test_CallbackManager(): | def test_CallbackManager(): | ||||
| """TestCallbackManager.""" | """TestCallbackManager.""" | ||||
| ck_obj = ModelCheckpoint() | ck_obj = ModelCheckpoint() | ||||
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ut for model serialize(save/load)""" | """ut for model serialize(save/load)""" | ||||
| import os | import os | ||||
| import platform | |||||
| import stat | import stat | ||||
| import time | import time | ||||
| @@ -299,6 +300,30 @@ def test_load_checkpoint_empty_file(): | |||||
| load_checkpoint("empty.ckpt") | load_checkpoint("empty.ckpt") | ||||
| def test_save_and_load_checkpoint_for_network_with_encryption(): | |||||
| """ test save and checkpoint for network with encryption""" | |||||
| net = Net() | |||||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True) | |||||
| opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024) | |||||
| loss_net = WithLossCell(net, loss) | |||||
| train_network = TrainOneStepCell(loss_net, opt) | |||||
| key = os.urandom(16) | |||||
| mode = "AES-GCM" | |||||
| ckpt_path = "./encrypt_ckpt.ckpt" | |||||
| if platform.system().lower() == "windows": | |||||
| with pytest.raises(NotImplementedError): | |||||
| save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode) | |||||
| param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM") | |||||
| load_param_into_net(net, param_dict) | |||||
| else: | |||||
| save_checkpoint(train_network, ckpt_file_name=ckpt_path, enc_key=key, enc_mode=mode) | |||||
| param_dict = load_checkpoint(ckpt_path, dec_key=key, dec_mode="AES-GCM") | |||||
| load_param_into_net(net, param_dict) | |||||
| if os.path.exists(ckpt_path): | |||||
| os.remove(ckpt_path) | |||||
| class MYNET(nn.Cell): | class MYNET(nn.Cell): | ||||
| """ NET definition """ | """ NET definition """ | ||||