You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

rc4_cryption_impl.cpp 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. /**
  2. * \file src/decryption/rc4/rc4_cryption_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "rc4_cryption_impl.h"
  12. #include "../../misc.h"
  13. #include <cstring>
  14. using namespace lite;
  15. /*!
  16. * \brief Read the input stream once in order to initialize the decryption
  17. * state.
  18. */
  19. void RC4Impl::init_rc4_state() {
  20. rc4::RC4RandStream enc_stream(m_enc_key);
  21. rc4::FastHash64 dechash(m_hash_key);
  22. size_t offset = 0;
  23. std::vector<uint64_t> buffer(128);
  24. size_t remaining = m_model_length - sizeof(uint64_t);
  25. while (remaining > 0) {
  26. size_t toread = std::min(remaining, buffer.size() * sizeof(uint64_t));
  27. memcpy(buffer.data(), static_cast<const uint8_t*>(m_model_mem) + offset,
  28. toread);
  29. offset += toread;
  30. remaining -= toread;
  31. for (size_t i = 0; i < toread / sizeof(uint64_t); ++i) {
  32. uint64_t value = buffer[i];
  33. value ^= enc_stream.next64();
  34. dechash.feed(value);
  35. }
  36. }
  37. uint64_t hashvalue;
  38. memcpy(&hashvalue, static_cast<const uint8_t*>(m_model_mem) + offset,
  39. sizeof(hashvalue));
  40. offset += sizeof(hashvalue);
  41. hashvalue ^= dechash.get() ^ enc_stream.next64();
  42. m_state.hash_stream.reset(hashvalue);
  43. m_state.enc_stream.reset(m_enc_key);
  44. }
  45. std::vector<uint8_t> RC4Impl::decrypt_model() {
  46. std::vector<uint8_t> result(m_model_length, 0);
  47. uint8_t* ptr = result.data();
  48. for (size_t i = 0; i < m_model_length; ++i) {
  49. ptr[i] = static_cast<const uint8_t*>(m_model_mem)[i];
  50. ptr[i] ^= m_state.hash_stream.next8() ^ m_state.enc_stream.next8();
  51. }
  52. return result;
  53. }
  54. /*! \brief Encrypt the data in m_buffer.
  55. *
  56. * The basic idea is to calculate a 64-bit hash from the buffer and append
  57. * it to the end of the buffer. The basic requirement is that the change of
  58. * every byte including the hash value will destroy the whole model in every
  59. * byte.
  60. *
  61. * Encryption:
  62. *
  63. * 1. First calculate a 64-bit hash, called plain hash value, from the
  64. * buffer.
  65. * 2. Initialize a RC4 stream with the plain hash value.
  66. * 3. Obfuscate the model body with the RC4 stream defined in step 2.
  67. * 4. Calculate the hash value of the obfuscated model, called hash value
  68. * after hashing.
  69. * 5. Encrypt the model body with a RC4 stream made from the encryption key.
  70. * 6. Bit-xor the hash value after hashing with the plain hash value, called
  71. * mixed hash.
  72. * 7. Encrypt the mixed hash with the RC4 stream defined in step 5, called
  73. * the protected hash.
  74. * 8. Append the protected hash to the buffer.
  75. *
  76. * Decryption:
  77. * 1. Decrypt the model body with a RC4 stream made from the encryption key,
  78. * which is the reverse of step 5 and 7 of encryption and get the mixed
  79. * hash.
  80. * 2. Calculate the hash value of the decrypted model, which equals to the
  81. * hash value after hashing in step 4 of encryption.
  82. * 3. Bit-xor the hash value after hashing and the mixed hash to get the
  83. * plain hash value, which is the reverse of step 6 of encryption.
  84. * 4. Un-obfuscate the model body with the plain hash value, which is the
  85. * reverse of step 3 of encryption.
  86. *
  87. * Think:
  88. * 1. If any byte in the model body is broken, the hash value after hashing
  89. * will be broken in step 2, and hence the plain hash value in step 3
  90. * will be also broken, and finally, the model body will be broken in
  91. * step 4.
  92. * 2. If the protected hash is broken, the plain hash value in step 3 will
  93. * be broken, and finally the model body will be broken.
  94. */
  95. std::vector<uint8_t> RC4Impl::encrypt_model() {
  96. size_t total_length =
  97. (m_model_length + (sizeof(size_t) - 1)) / sizeof(size_t) * sizeof(size_t);
  98. std::vector<uint8_t> pad_model(total_length, 0);
  99. memcpy(pad_model.data(), m_model_mem, m_model_length);
  100. // Calculate the hash of the model.
  101. rc4::FastHash64 plainhash(m_hash_key);
  102. uint64_t* ptr = reinterpret_cast<uint64_t*>(pad_model.data());
  103. size_t len = pad_model.size() / sizeof(uint64_t);
  104. for (size_t i = 0; i < len; ++i)
  105. plainhash.feed(ptr[i]);
  106. uint64_t plainhash_value = plainhash.get();
  107. // Encrypt the model.
  108. rc4::RC4RandStream hash_enc(plainhash_value);
  109. rc4::RC4RandStream outmost_enc(m_enc_key);
  110. rc4::FastHash64 afterhashenc_hash(m_hash_key);
  111. for (size_t i = 0; i < len; ++i) {
  112. uint64_t value = ptr[i] ^ hash_enc.next64();
  113. afterhashenc_hash.feed(value);
  114. ptr[i] = value ^ outmost_enc.next64();
  115. }
  116. uint64_t protected_hash =
  117. plainhash_value ^ afterhashenc_hash.get() ^ outmost_enc.next64();
  118. size_t end = pad_model.size();
  119. pad_model.resize(pad_model.size() + sizeof(uint64_t));
  120. ptr = reinterpret_cast<uint64_t*>(&pad_model[end]);
  121. *ptr = protected_hash;
  122. return pad_model;
  123. }
  124. /*!
  125. * \brief Read the input stream once in order to initialize the decryption
  126. * state.
  127. */
  128. void SimpleFastRC4Impl::init_sfrc4_state() {
  129. rc4::RC4RandStream enc_stream(m_enc_key);
  130. rc4::FastHash64 dechash(m_hash_key);
  131. size_t offset = 0;
  132. std::vector<uint64_t> buffer(128);
  133. size_t remaining = m_model_length - sizeof(uint64_t);
  134. while (remaining > 0) {
  135. size_t toread = std::min(remaining, buffer.size() * sizeof(uint64_t));
  136. memcpy(buffer.data(), static_cast<const uint8_t*>(m_model_mem) + offset,
  137. toread);
  138. offset += toread;
  139. remaining -= toread;
  140. for (size_t i = 0; i < toread / sizeof(uint64_t); ++i) {
  141. uint64_t value = buffer[i];
  142. dechash.feed(value);
  143. }
  144. }
  145. uint64_t hashvalue;
  146. memcpy(&hashvalue, static_cast<const uint8_t*>(m_model_mem) + offset,
  147. sizeof(hashvalue));
  148. offset += sizeof(hashvalue);
  149. /*! \brief test the hash_val. */
  150. if (hashvalue != dechash.get())
  151. LITE_THROW(
  152. "The checksum of the file cannot be verified. The file may "
  153. "be encrypted in the wrong algorithm or different keys.");
  154. m_state.hash_stream.reset(m_hash_key);
  155. m_state.enc_stream.reset(m_enc_key);
  156. }
  157. std::vector<uint8_t> SimpleFastRC4Impl::decrypt_model() {
  158. std::vector<uint8_t> result(m_model_length, 0);
  159. uint8_t* ptr = result.data();
  160. for (size_t i = 0; i < m_model_length; ++i) {
  161. ptr[i] = static_cast<const uint8_t*>(m_model_mem)[i];
  162. ptr[i] ^= m_state.enc_stream.next8();
  163. }
  164. return result;
  165. }
  166. std::vector<uint8_t> SimpleFastRC4Impl::encrypt_model() {
  167. size_t total_length =
  168. (m_model_length + (sizeof(size_t) - 1)) / sizeof(size_t) * sizeof(size_t);
  169. std::vector<uint8_t> pad_model(total_length, 0);
  170. memcpy(pad_model.data(), m_model_mem, m_model_length);
  171. // Calculate the hash of the model.
  172. rc4::FastHash64 enchash(m_hash_key);
  173. uint64_t* ptr = reinterpret_cast<uint64_t*>(pad_model.data());
  174. size_t len = pad_model.size() / sizeof(uint64_t);
  175. // Encrypt the model.
  176. rc4::RC4RandStream out_enc(m_enc_key);
  177. for (size_t i = 0; i < len; ++i) {
  178. ptr[i] = ptr[i] ^ out_enc.next64();
  179. enchash.feed(ptr[i]);
  180. }
  181. uint64_t hash_value = enchash.get();
  182. size_t end = pad_model.size();
  183. pad_model.resize(pad_model.size() + sizeof(uint64_t));
  184. ptr = reinterpret_cast<uint64_t*>(&pad_model[end]);
  185. *ptr = hash_value;
  186. return pad_model;
  187. }
  188. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}