You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

crypto.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "crypto/crypto.h"

  #include <memory>
  17. namespace mindspore {
  18. namespace crypto {
  // Returns the smaller of two signed 64-bit values; when both compare equal, `b` is returned.
  int64_t Min(int64_t a, int64_t b) {
    if (a < b) {
      return a;
    }
    return b;
  }
  20. Byte *intToByte(const int32_t &n) {
  21. Byte *byte = new Byte[4];
  22. memset(byte, 0, sizeof(Byte) * 4);
  23. byte[0] = (Byte)(0xFF & n);
  24. byte[1] = (Byte)((0xFF00 & n) >> 8);
  25. byte[2] = (Byte)((0xFF0000 & n) >> 16);
  26. byte[3] = (Byte)((0xFF000000 & n) >> 24);
  27. return byte;
  28. }
  29. int32_t ByteToint(const Byte *byteArray) {
  30. int32_t res = byteArray[0] & 0xFF;
  31. res |= ((byteArray[1] << 8) & 0xFF00);
  32. res |= ((byteArray[2] << 16) & 0xFF0000);
  33. res += ((byteArray[3] << 24) & 0xFF000000);
  34. return res;
  35. }
  36. bool IsCipherFile(std::string file_path) {
  37. char *int_buf = new char[4];
  38. int flag = 0;
  39. std::ifstream fid(file_path, std::ios::in | std::ios::binary);
  40. if (!fid) {
  41. MS_LOG(ERROR) << "Open file failed";
  42. exit(-1);
  43. }
  44. fid.read(int_buf, sizeof(int32_t));
  45. fid.close();
  46. flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
  47. delete[] int_buf;
  48. return flag == MAGIC_NUM;
  49. }
  50. #if defined(_WIN32)
  51. Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
  52. const std::string &enc_mode) {
  53. MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
  54. }
  55. Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
  56. const std::string &dec_mode) {
  57. MS_EXCEPTION(NotSupportError) << "Unsupported feature in Windows platform.";
  58. }
  59. #else
  60. bool ParseEncryptData(const Byte *encrypt_data, const int32_t encrypt_len, Byte **iv, int32_t *iv_len,
  61. Byte **cipher_data, int32_t *cipher_len) {
  62. // encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data
  63. Byte buf[4];
  64. memcpy_s(buf, 4, encrypt_data, 4);
  65. *iv_len = ByteToint(buf);
  66. memcpy_s(buf, 4, encrypt_data + *iv_len + 4, 4);
  67. *cipher_len = ByteToint(buf);
  68. if (*iv_len <= 0 || *cipher_len <= 0 || *iv_len + *cipher_len + 8 != encrypt_len) {
  69. MS_LOG(ERROR) << "Failed to parse encrypt data.";
  70. return false;
  71. }
  72. *iv = new Byte[*iv_len];
  73. memcpy_s(*iv, *iv_len, encrypt_data + 4, *iv_len);
  74. *cipher_data = new Byte[*cipher_len];
  75. memcpy_s(*cipher_data, *cipher_len, encrypt_data + *iv_len + 8, *cipher_len);
  76. return true;
  77. }
  78. bool ParseMode(std::string mode, std::string *alg_mode, std::string *work_mode) {
  79. std::smatch results;
  80. std::regex re("([A-Z]{3})-([A-Z]{3})");
  81. if (!std::regex_match(mode.c_str(), re)) {
  82. MS_LOG(ERROR) << "Mode " << mode << " is invalid.";
  83. return false;
  84. }
  85. std::regex_search(mode, results, re);
  86. *alg_mode = results[1];
  87. *work_mode = results[2];
  88. return true;
  89. }
  90. EVP_CIPHER_CTX *GetEVP_CIPHER_CTX(const std::string &work_mode, const Byte *key, const int32_t key_len, const Byte *iv,
  91. int flag) {
  92. int ret = 0;
  93. EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
  94. if (work_mode != "GCM" && work_mode != "CBC") {
  95. MS_LOG(ERROR) << "Work mode " << work_mode << " is invalid.";
  96. return nullptr;
  97. }
  98. const EVP_CIPHER *(*funcPtr)() = nullptr;
  99. if (work_mode == "GCM") {
  100. switch (key_len) {
  101. case 16:
  102. funcPtr = EVP_aes_128_gcm;
  103. break;
  104. case 24:
  105. funcPtr = EVP_aes_192_gcm;
  106. break;
  107. case 32:
  108. funcPtr = EVP_aes_256_gcm;
  109. break;
  110. default:
  111. MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
  112. }
  113. } else if (work_mode == "CBC") {
  114. switch (key_len) {
  115. case 16:
  116. funcPtr = EVP_aes_128_cbc;
  117. break;
  118. case 24:
  119. funcPtr = EVP_aes_192_cbc;
  120. break;
  121. case 32:
  122. funcPtr = EVP_aes_256_cbc;
  123. break;
  124. default:
  125. MS_EXCEPTION(ValueError) << "The key length must be 16, 24 or 32, but got key length is " << key_len << ".";
  126. }
  127. }
  128. if (flag == 0) {
  129. ret = EVP_EncryptInit_ex(ctx, funcPtr(), NULL, key, iv);
  130. } else if (flag == 1) {
  131. ret = EVP_DecryptInit_ex(ctx, funcPtr(), NULL, key, iv);
  132. }
  133. if (ret != 1) {
  134. MS_LOG(ERROR) << "EVP_EncryptInit_ex failed";
  135. return nullptr;
  136. }
  137. if (work_mode == "CBC") EVP_CIPHER_CTX_set_padding(ctx, 1);
  138. return ctx;
  139. }
  140. bool _BlockEncrypt(Byte *encrypt_data, int64_t *encrypt_data_len, Byte *plain_data, const int64_t plain_len, Byte *key,
  141. const int32_t key_len, const std::string &enc_mode) {
  142. int32_t cipher_len = 0;
  143. int32_t iv_len = AES_BLOCK_SIZE;
  144. Byte *iv = new Byte[iv_len];
  145. RAND_bytes(iv, sizeof(Byte) * iv_len);
  146. Byte *iv_cpy = new Byte[16];
  147. memcpy_s(iv_cpy, 16, iv, 16);
  148. int32_t ret = 0;
  149. int32_t flen = 0;
  150. std::string alg_mode;
  151. std::string work_mode;
  152. if (!ParseMode(enc_mode, &alg_mode, &work_mode)) {
  153. return false;
  154. }
  155. auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 0);
  156. if (ctx == nullptr) {
  157. MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
  158. return false;
  159. }
  160. Byte *cipher_data;
  161. cipher_data = new Byte[plain_len + 16];
  162. ret = EVP_EncryptUpdate(ctx, cipher_data, &cipher_len, plain_data, plain_len);
  163. if (ret != 1) {
  164. MS_LOG(ERROR) << "EVP_EncryptUpdate failed";
  165. delete[] cipher_data;
  166. return false;
  167. }
  168. if (work_mode == "CBC") {
  169. EVP_EncryptFinal_ex(ctx, cipher_data + cipher_len, &flen);
  170. cipher_len += flen;
  171. }
  172. EVP_CIPHER_CTX_free(ctx);
  173. int64_t cur = 0;
  174. *encrypt_data_len = sizeof(int32_t) * 2 + iv_len + cipher_len;
  175. memcpy(encrypt_data + cur, intToByte(*encrypt_data_len), 4);
  176. cur += 4;
  177. memcpy(encrypt_data + cur, intToByte(iv_len), 4);
  178. cur += 4;
  179. memcpy(encrypt_data + cur, iv_cpy, iv_len);
  180. cur += iv_len;
  181. memcpy(encrypt_data + cur, intToByte(cipher_len), 4);
  182. cur += 4;
  183. memcpy(encrypt_data + cur, cipher_data, cipher_len);
  184. *encrypt_data_len += 4;
  185. delete[] cipher_data;
  186. return true;
  187. }
  188. bool _BlockDecrypt(Byte **plain_data, int32_t *plain_len, Byte *encrypt_data, const int64_t encrypt_len, Byte *key,
  189. const int32_t key_len, const std::string &dec_mode) {
  190. std::string alg_mode;
  191. std::string work_mode;
  192. if (!ParseMode(dec_mode, &alg_mode, &work_mode)) {
  193. return false;
  194. }
  195. int32_t iv_len = 0;
  196. int32_t cipher_len = 0;
  197. Byte *iv = NULL;
  198. Byte *cipher_data = NULL;
  199. if (!ParseEncryptData(encrypt_data, encrypt_len, &iv, &iv_len, &cipher_data, &cipher_len)) {
  200. return false;
  201. }
  202. *plain_data = new Byte[cipher_len + 16];
  203. if (*plain_data == NULL) {
  204. MS_LOG(ERROR) << "Unable to allocate memory for decrypt_string.";
  205. return false;
  206. }
  207. int ret = 0;
  208. int mlen = 0;
  209. auto ctx = GetEVP_CIPHER_CTX(work_mode, key, key_len, iv, 1);
  210. if (ctx == nullptr) {
  211. MS_LOG(ERROR) << "Failed to get EVP_CIPHER_CTX.";
  212. return false;
  213. }
  214. ret = EVP_DecryptUpdate(ctx, *plain_data, plain_len, cipher_data, cipher_len);
  215. if (ret != 1) {
  216. MS_LOG(ERROR) << "EVP_DecryptUpdate failed";
  217. return false;
  218. }
  219. if (work_mode == "CBC") {
  220. ret = EVP_DecryptFinal_ex(ctx, *plain_data + *plain_len, &mlen);
  221. if (ret != 1) {
  222. MS_LOG(ERROR) << "EVP_DecryptFinal_ex failed";
  223. return false;
  224. }
  225. *plain_len += mlen;
  226. }
  227. delete[] iv;
  228. delete[] cipher_data;
  229. EVP_CIPHER_CTX_free(ctx);
  230. return true;
  231. }
  232. Byte *Encrypt(int64_t *encrypt_len, Byte *plain_data, const int64_t plain_len, Byte *key, const int32_t key_len,
  233. const std::string &enc_mode) {
  234. int64_t cur_pos = 0;
  235. int64_t block_enc_len = 0;
  236. int64_t encrypt_buf_len = plain_len + (plain_len / MAX_BLOCK_SIZE + 1) * 100;
  237. Byte *encrypt_data = new Byte[encrypt_buf_len];
  238. Byte *block_buf = new Byte[MAX_BLOCK_SIZE];
  239. Byte *block_enc_buf = new Byte[MAX_BLOCK_SIZE + 100];
  240. *encrypt_len = 0;
  241. while (cur_pos < plain_len) {
  242. int64_t cur_block_size = Min(MAX_BLOCK_SIZE, plain_len - cur_pos);
  243. memcpy_s(block_buf, MAX_BLOCK_SIZE, plain_data + cur_pos, cur_block_size);
  244. if (!_BlockEncrypt(block_enc_buf, &block_enc_len, block_buf, cur_block_size, key, key_len, enc_mode)) {
  245. delete[] block_buf;
  246. delete[] block_enc_buf;
  247. delete[] encrypt_data;
  248. MS_EXCEPTION(ValueError) << "Failed to encrypt data, please check if enc_key or enc_mode is valid.";
  249. }
  250. memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, intToByte(MAGIC_NUM), sizeof(int32_t));
  251. *encrypt_len += sizeof(int32_t);
  252. memcpy_s(encrypt_data + *encrypt_len, encrypt_buf_len - *encrypt_len, block_enc_buf, block_enc_len);
  253. *encrypt_len += block_enc_len;
  254. cur_pos += cur_block_size;
  255. }
  256. delete[] block_buf;
  257. delete[] block_enc_buf;
  258. return encrypt_data;
  259. }
  260. Byte *Decrypt(int64_t *decrypt_len, const std::string &encrypt_data_path, Byte *key, const int32_t key_len,
  261. const std::string &dec_mode) {
  262. Byte *decrypt_data = nullptr;
  263. char *block_buf = new char[MAX_BLOCK_SIZE * 2];
  264. char *int_buf = new char[4];
  265. Byte *decrypt_block_buf = nullptr;
  266. int32_t decrypt_block_len;
  267. std::ifstream fid(encrypt_data_path, std::ios::in | std::ios::binary);
  268. if (!fid) {
  269. MS_LOG(ERROR) << "Open file failed";
  270. exit(-1);
  271. }
  272. fid.seekg(0, std::ios_base::end);
  273. int64_t file_size = fid.tellg();
  274. fid.clear();
  275. fid.seekg(0);
  276. decrypt_data = new Byte[file_size];
  277. *decrypt_len = 0;
  278. while (fid.tellg() < file_size) {
  279. fid.read(int_buf, sizeof(int32_t));
  280. int cipher_flag = ByteToint(reinterpret_cast<Byte *>(int_buf));
  281. if (cipher_flag != MAGIC_NUM) {
  282. MS_EXCEPTION(ValueError) << "File \"" << encrypt_data_path
  283. << "\"is not an encrypted file and cannot be decrypted";
  284. }
  285. fid.read(int_buf, sizeof(int32_t));
  286. int32_t block_size = ByteToint(reinterpret_cast<Byte *>(int_buf));
  287. fid.read(block_buf, sizeof(char) * block_size);
  288. if (!(_BlockDecrypt(&decrypt_block_buf, &decrypt_block_len, reinterpret_cast<Byte *>(block_buf), block_size, key,
  289. key_len, dec_mode))) {
  290. delete[] block_buf;
  291. delete[] int_buf;
  292. delete[] decrypt_data;
  293. MS_EXCEPTION(ValueError) << "Failed to decrypt data, please check if dec_key or dec_mode is valid";
  294. }
  295. memcpy_s(decrypt_data + *decrypt_len, file_size - *decrypt_len, decrypt_block_buf, decrypt_block_len);
  296. *decrypt_len += decrypt_block_len;
  297. }
  298. fid.close();
  299. delete[] block_buf;
  300. delete[] int_buf;
  301. return decrypt_data;
  302. }
  303. #endif
  304. } // namespace crypto
  305. } // namespace mindspore