You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

crypto.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "utils/crypto.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <regex>
#include <string>
#include <vector>
#include "utils/log_adapter.h"
#ifdef ENABLE_OPENSSL
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#endif
  27. namespace mindspore {
  28. void IntToByte(std::vector<Byte> *byteArray, int32_t n) {
  29. if (byteArray == nullptr) {
  30. MS_LOG(ERROR) << "byteArray is nullptr";
  31. return;
  32. }
  33. auto ptr = reinterpret_cast<const Byte *>(&n);
  34. (*byteArray).assign(ptr, ptr + sizeof(int32_t));
  35. }
  36. int32_t ByteToInt(const Byte *byteArray, size_t length) {
  37. if (length < sizeof(int32_t)) {
  38. MS_LOG(ERROR) << "Length of byteArray is " << length << ", less than sizeof(int32_t): 4.";
  39. return -1;
  40. }
  41. return *(reinterpret_cast<const int32_t *>(byteArray));
  42. }
  43. bool IsCipherFile(const std::string &file_path) {
  44. std::ifstream fid(file_path, std::ios::in | std::ios::binary);
  45. if (!fid) {
  46. MS_LOG(ERROR) << "Failed to open file " << file_path;
  47. return false;
  48. }
  49. std::vector<char> int_buf(sizeof(int32_t));
  50. fid.read(int_buf.data(), static_cast<int64_t>(sizeof(int32_t)));
  51. fid.close();
  52. auto flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
  53. return static_cast<unsigned int>(flag) == MAGIC_NUM;
  54. }
  55. bool IsCipherFile(const Byte *model_data) {
  56. MS_EXCEPTION_IF_NULL(model_data);
  57. std::vector<Byte> int_buf;
  58. int_buf.assign(model_data, model_data + sizeof(int32_t));
  59. auto flag = ByteToInt(int_buf.data(), int_buf.size());
  60. return static_cast<unsigned int>(flag) == MAGIC_NUM;
  61. }
  62. #ifndef ENABLE_OPENSSL
  63. std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len, const Byte *key,
  64. size_t key_len, const std::string &enc_mode) {
  65. MS_LOG(ERROR) << "The feature is only supported on the Linux platform "
  66. "when the OPENSSL compilation option is enabled.";
  67. return nullptr;
  68. }
  69. std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
  70. size_t key_len, const std::string &dec_mode) {
  71. MS_LOG(ERROR) << "The feature is only supported on the Linux platform "
  72. "when the OPENSSL compilation option is enabled.";
  73. return nullptr;
  74. }
  75. std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, size_t data_size, const Byte *key,
  76. size_t key_len, const std::string &dec_mode) {
  77. MS_LOG(ERROR) << "The feature is only supported on the Linux platform "
  78. "when the OPENSSL compilation option is enabled.";
  79. return nullptr;
  80. }
  81. #else
  82. bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector<Byte> *iv,
  83. std::vector<Byte> *cipher_data) {
  84. // encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
  85. std::vector<Byte> int_buf(sizeof(int32_t));
  86. int_buf.assign(encrypt_data, encrypt_data + sizeof(int32_t));
  87. auto iv_len = ByteToInt(int_buf.data(), int_buf.size());
  88. int_buf.assign(encrypt_data + iv_len + sizeof(int32_t), encrypt_data + iv_len + sizeof(int32_t) + sizeof(int32_t));
  89. auto cipher_len = ByteToInt(int_buf.data(), int_buf.size());
  90. if (iv_len <= 0 || cipher_len <= 0 ||
  91. ((static_cast<size_t>(iv_len) + sizeof(int32_t) + static_cast<size_t>(cipher_len) + sizeof(int32_t)) !=
  92. encrypt_len)) {
  93. MS_LOG(ERROR) << "Failed to parse encrypt data.";
  94. return false;
  95. }
  96. (*iv).assign(encrypt_data + sizeof(int32_t), encrypt_data + sizeof(int32_t) + iv_len);
  97. (*cipher_data)
  98. .assign(encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t),
  99. encrypt_data + sizeof(int32_t) + iv_len + sizeof(int32_t) + cipher_len);
  100. return true;
  101. }
  102. bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work_mode) {
  103. std::smatch results;
  104. std::regex re("([A-Z]{3})-([A-Z]{3})");
  105. if (!(std::regex_match(mode.c_str(), re) && std::regex_search(mode, results, re))) {
  106. MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
  107. return false;
  108. }
  109. *alg_mode = results[1];
  110. *work_mode = results[2];
  111. return true;
  112. }
  113. EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &work_mode, const Byte *key, int32_t key_len, const Byte *iv,
  114. bool is_encrypt) {
  115. constexpr int32_t key_length_16 = 16;
  116. constexpr int32_t key_length_24 = 24;
  117. constexpr int32_t key_length_32 = 32;
  118. const EVP_CIPHER *(*funcPtr)() = nullptr;
  119. if (work_mode == "GCM") {
  120. switch (key_len) {
  121. case key_length_16:
  122. funcPtr = EVP_aes_128_gcm;
  123. break;
  124. case key_length_24:
  125. funcPtr = EVP_aes_192_gcm;
  126. break;
  127. case key_length_32:
  128. funcPtr = EVP_aes_256_gcm;
  129. break;
  130. default:
  131. MS_LOG(ERROR) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
  132. return nullptr;
  133. }
  134. } else if (work_mode == "CBC") {
  135. switch (key_len) {
  136. case key_length_16:
  137. funcPtr = EVP_aes_128_cbc;
  138. break;
  139. case key_length_24:
  140. funcPtr = EVP_aes_192_cbc;
  141. break;
  142. case key_length_32:
  143. funcPtr = EVP_aes_256_cbc;
  144. break;
  145. default:
  146. MS_LOG(ERROR) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
  147. return nullptr;
  148. }
  149. } else {
  150. MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
  151. return nullptr;
  152. }
  153. int32_t ret = 0;
  154. auto ctx = EVP_CIPHER_CTX_new();
  155. if (is_encrypt) {
  156. ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
  157. } else {
  158. ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
  159. }
  160. if (ret != 1) {
  161. MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
  162. return nullptr;
  163. }
  164. if (work_mode == "CBC") {
  165. ret = EVP_CIPHER_CTX_set_padding(ctx, 1);
  166. if (ret != 1) {
  167. MS_LOG(ERROR) << "EVP_CIPHER_CTX_set_padding failed";
  168. return nullptr;
  169. }
  170. }
  171. return ctx;
  172. }
  173. bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector<Byte> &plain_data, const Byte *key,
  174. int32_t key_len, const std::string &enc_mode) {
  175. size_t encrypt_data_buf_len = *encrypt_data_len;
  176. int32_t cipher_len = 0;
  177. int32_t iv_len = AES_BLOCK_SIZE;
  178. std::vector<Byte> iv(iv_len);
  179. auto ret = RAND_bytes(iv.data(), iv_len);
  180. if (ret != 1) {
  181. MS_LOG(ERROR) << "RAND_bytes error, failed to init iv.";
  182. return false;
  183. }
  184. std::vector<Byte> iv_cpy(iv);
  185. std::string alg_mode;
  186. std::string work_mode;
  187. if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
  188. return false;
  189. }
  190. auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), true);
  191. if (ctx == nullptr) {
  192. MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
  193. return false;
  194. }
  195. std::vector<Byte> cipher_data_buf(plain_data.size() + AES_BLOCK_SIZE);
  196. auto ret_evp = EVP_EncryptUpdate(ctx, cipher_data_buf.data(), &cipher_len, plain_data.data(),
  197. static_cast<int32_t>(plain_data.size()));
  198. if (ret_evp != 1) {
  199. MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
  200. return false;
  201. }
  202. if (work_mode == "CBC") {
  203. int32_t flen = 0;
  204. ret_evp = EVP_EncryptFinal_ex(ctx, cipher_data_buf.data() + cipher_len, &flen);
  205. if (ret_evp != 1) {
  206. MS_LOG(ERROR) << "EVP_EncryptFinal_ex failed";
  207. return false;
  208. }
  209. cipher_len += flen;
  210. }
  211. EVP_CIPHER_CTX_free(ctx);
  212. size_t offset = 0;
  213. std::vector<Byte> int_buf(sizeof(int32_t));
  214. *encrypt_data_len = sizeof(int32_t) + static_cast<size_t>(iv_len) + sizeof(int32_t) + static_cast<size_t>(cipher_len);
  215. IntToByte(&int_buf, static_cast<int32_t>(*encrypt_data_len));
  216. ret = memcpy_s(encrypt_data, encrypt_data_buf_len, int_buf.data(), int_buf.size());
  217. if (ret != 0) {
  218. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  219. }
  220. offset += int_buf.size();
  221. IntToByte(&int_buf, iv_len);
  222. ret = memcpy_s(encrypt_data + offset, encrypt_data_buf_len - offset, int_buf.data(), int_buf.size());
  223. if (ret != 0) {
  224. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  225. }
  226. offset += int_buf.size();
  227. ret = memcpy_s(encrypt_data + offset, encrypt_data_buf_len - offset, iv_cpy.data(), iv_cpy.size());
  228. if (ret != 0) {
  229. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  230. }
  231. offset += iv_cpy.size();
  232. IntToByte(&int_buf, cipher_len);
  233. ret = memcpy_s(encrypt_data + offset, encrypt_data_buf_len - offset, int_buf.data(), int_buf.size());
  234. if (ret != 0) {
  235. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  236. }
  237. offset += int_buf.size();
  238. ret = memcpy_s(encrypt_data + offset, encrypt_data_buf_len - offset, cipher_data_buf.data(),
  239. static_cast<size_t>(cipher_len));
  240. if (ret != 0) {
  241. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  242. }
  243. *encrypt_data_len += sizeof(int32_t);
  244. return true;
  245. }
  246. bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key,
  247. int32_t key_len, const std::string &dec_mode) {
  248. std::string alg_mode;
  249. std::string work_mode;
  250. if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
  251. return false;
  252. }
  253. std::vector<Byte> iv;
  254. std::vector<Byte> cipher_data;
  255. if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &cipher_data)) {
  256. return false;
  257. }
  258. auto ctx = GetEvpCipherCtx(work_mode, key, key_len, iv.data(), false);
  259. if (ctx == nullptr) {
  260. MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
  261. return false;
  262. }
  263. auto ret =
  264. EVP_DecryptUpdate(ctx, plain_data, plain_len, cipher_data.data(), static_cast<int32_t>(cipher_data.size()));
  265. if (ret != 1) {
  266. MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
  267. return false;
  268. }
  269. if (work_mode == "CBC") {
  270. int32_t mlen = 0;
  271. ret = EVP_DecryptFinal_ex(ctx, plain_data + *plain_len, &mlen);
  272. if (ret != 1) {
  273. MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
  274. return false;
  275. }
  276. *plain_len += mlen;
  277. }
  278. EVP_CIPHER_CTX_free(ctx);
  279. return true;
  280. }
  281. std::unique_ptr<Byte[]> Encrypt(size_t *encrypt_len, const Byte *plain_data, size_t plain_len, const Byte *key,
  282. size_t key_len, const std::string &enc_mode) {
  283. MS_EXCEPTION_IF_NULL(plain_data);
  284. MS_EXCEPTION_IF_NULL(key);
  285. size_t block_enc_buf_len = MAX_BLOCK_SIZE + RESERVED_BYTE_PER_BLOCK;
  286. size_t encrypt_buf_len = plain_len + (plain_len + MAX_BLOCK_SIZE) / MAX_BLOCK_SIZE * RESERVED_BYTE_PER_BLOCK;
  287. std::vector<Byte> int_buf(sizeof(int32_t));
  288. std::vector<Byte> block_buf;
  289. std::vector<Byte> block_enc_buf(block_enc_buf_len);
  290. auto encrypt_data = std::make_unique<Byte[]>(encrypt_buf_len);
  291. size_t offset = 0;
  292. *encrypt_len = 0;
  293. while (offset < plain_len) {
  294. size_t block_enc_len = block_enc_buf.size();
  295. size_t cur_block_size = std::min(MAX_BLOCK_SIZE, plain_len - offset);
  296. block_buf.assign(plain_data + offset, plain_data + offset + cur_block_size);
  297. if (!BlockEncrypt(block_enc_buf.data(), &block_enc_len, block_buf, key, static_cast<int32_t>(key_len), enc_mode)) {
  298. MS_LOG(ERROR) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
  299. return nullptr;
  300. }
  301. IntToByte(&int_buf, static_cast<int32_t>(MAGIC_NUM));
  302. size_t capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN); // avoid dest size over 2gb
  303. auto ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, int_buf.data(), sizeof(int32_t));
  304. if (ret != 0) {
  305. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  306. }
  307. *encrypt_len += sizeof(int32_t);
  308. capacity = std::min(encrypt_buf_len - *encrypt_len, SECUREC_MEM_MAX_LEN);
  309. ret = memcpy_s(encrypt_data.get() + *encrypt_len, capacity, block_enc_buf.data(), block_enc_len);
  310. if (ret != 0) {
  311. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  312. }
  313. *encrypt_len += block_enc_len;
  314. offset += cur_block_size;
  315. }
  316. return encrypt_data;
  317. }
  318. std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const std::string &encrypt_data_path, const Byte *key,
  319. size_t key_len, const std::string &dec_mode) {
  320. MS_EXCEPTION_IF_NULL(key);
  321. std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
  322. if (!fid) {
  323. MS_LOG(ERROR) << "Open file '" << encrypt_data_path << "' failed, please check the correct of the file.";
  324. return nullptr;
  325. }
  326. fid.seekg(0, std::ios_base::end);
  327. size_t file_size = static_cast<size_t>(fid.tellg());
  328. fid.clear();
  329. fid.seekg(0);
  330. std::vector<char> block_buf(MAX_BLOCK_SIZE + RESERVED_BYTE_PER_BLOCK);
  331. std::vector<char> int_buf(sizeof(int32_t));
  332. std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE);
  333. auto decrypt_data = std::make_unique<Byte[]>(file_size);
  334. int32_t decrypt_block_len;
  335. *decrypt_len = 0;
  336. while (static_cast<size_t>(fid.tellg()) < file_size) {
  337. fid.read(int_buf.data(), static_cast<int32_t>(sizeof(int32_t)));
  338. auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
  339. if (static_cast<unsigned int>(cipher_flag) != MAGIC_NUM) {
  340. MS_LOG(ERROR) << "File \"" << encrypt_data_path << "\" is not an encrypted file and cannot be decrypted";
  341. return nullptr;
  342. }
  343. fid.read(int_buf.data(), static_cast<int64_t>(sizeof(int32_t)));
  344. auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
  345. if (block_size < 0) {
  346. MS_LOG(ERROR) << "The block_size read from the cipher file must be not negative, but got " << block_size;
  347. return nullptr;
  348. }
  349. fid.read(block_buf.data(), static_cast<int64_t>(block_size));
  350. if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
  351. static_cast<size_t>(block_size), key, static_cast<int32_t>(key_len), dec_mode))) {
  352. MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
  353. return nullptr;
  354. }
  355. size_t capacity = std::min(file_size - *decrypt_len, SECUREC_MEM_MAX_LEN);
  356. auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(),
  357. static_cast<int32_t>(decrypt_block_len));
  358. if (ret != 0) {
  359. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  360. }
  361. *decrypt_len += static_cast<size_t>(decrypt_block_len);
  362. }
  363. fid.close();
  364. return decrypt_data;
  365. }
  366. std::unique_ptr<Byte[]> Decrypt(size_t *decrypt_len, const Byte *model_data, size_t data_size, const Byte *key,
  367. size_t key_len, const std::string &dec_mode) {
  368. MS_EXCEPTION_IF_NULL(model_data);
  369. MS_EXCEPTION_IF_NULL(key);
  370. std::vector<char> block_buf;
  371. std::vector<char> int_buf(sizeof(int32_t));
  372. std::vector<Byte> decrypt_block_buf(MAX_BLOCK_SIZE);
  373. auto decrypt_data = std::make_unique<Byte[]>(data_size);
  374. int32_t decrypt_block_len;
  375. size_t offset = 0;
  376. *decrypt_len = 0;
  377. while (offset < data_size) {
  378. int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
  379. offset += int_buf.size();
  380. auto cipher_flag = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
  381. if (static_cast<unsigned int>(cipher_flag) != MAGIC_NUM) {
  382. MS_LOG(ERROR) << "model_data is not encrypted and therefore cannot be decrypted.";
  383. return nullptr;
  384. }
  385. int_buf.assign(model_data + offset, model_data + offset + sizeof(int32_t));
  386. offset += int_buf.size();
  387. auto block_size = ByteToInt(reinterpret_cast<Byte *>(int_buf.data()), int_buf.size());
  388. if (block_size < 0) {
  389. MS_LOG(ERROR) << "The block_size read from the cipher data must be not negative, but got " << block_size;
  390. return nullptr;
  391. }
  392. block_buf.assign(model_data + offset, model_data + offset + block_size);
  393. offset += block_buf.size();
  394. if (!(BlockDecrypt(decrypt_block_buf.data(), &decrypt_block_len, reinterpret_cast<Byte *>(block_buf.data()),
  395. block_buf.size(), key, static_cast<int32_t>(key_len), dec_mode))) {
  396. MS_LOG(ERROR) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
  397. return nullptr;
  398. }
  399. size_t capacity = std::min(data_size - *decrypt_len, SECUREC_MEM_MAX_LEN);
  400. auto ret = memcpy_s(decrypt_data.get() + *decrypt_len, capacity, decrypt_block_buf.data(),
  401. static_cast<size_t>(decrypt_block_len));
  402. if (ret != 0) {
  403. MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret;
  404. }
  405. *decrypt_len += static_cast<size_t>(decrypt_block_len);
  406. }
  407. return decrypt_data;
  408. }
  409. #endif
  410. } // namespace mindspore