|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- /**
- * \file src/decryption/rc4/rc4_cryption_impl.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #include "rc4_cryption_impl.h"
- #include "../../misc.h"
-
- #include <cstring>
-
- using namespace lite;
-
- /*!
- * \brief Read the input stream once in order to initialize the decryption
- * state.
- */
- void RC4Impl::init_rc4_state() {
- rc4::RC4RandStream enc_stream(m_enc_key);
- rc4::FastHash64 dechash(m_hash_key);
-
- size_t offset = 0;
-
- std::vector<uint64_t> buffer(128);
- size_t remaining = m_model_length - sizeof(uint64_t);
- while (remaining > 0) {
- size_t toread = std::min(remaining, buffer.size() * sizeof(uint64_t));
- memcpy(buffer.data(), static_cast<const uint8_t*>(m_model_mem) + offset,
- toread);
- offset += toread;
- remaining -= toread;
-
- for (size_t i = 0; i < toread / sizeof(uint64_t); ++i) {
- uint64_t value = buffer[i];
- value ^= enc_stream.next64();
- dechash.feed(value);
- }
- }
-
- uint64_t hashvalue;
- memcpy(&hashvalue, static_cast<const uint8_t*>(m_model_mem) + offset,
- sizeof(hashvalue));
- offset += sizeof(hashvalue);
-
- hashvalue ^= dechash.get() ^ enc_stream.next64();
- m_state.hash_stream.reset(hashvalue);
- m_state.enc_stream.reset(m_enc_key);
- }
-
- std::vector<uint8_t> RC4Impl::decrypt_model() {
- std::vector<uint8_t> result(m_model_length, 0);
-
- uint8_t* ptr = result.data();
- for (size_t i = 0; i < m_model_length; ++i) {
- ptr[i] = static_cast<const uint8_t*>(m_model_mem)[i];
- ptr[i] ^= m_state.hash_stream.next8() ^ m_state.enc_stream.next8();
- }
- return result;
- }
-
- /*! \brief Encrypt the data in m_buffer.
- *
- * The basic idea is to calculate a 64-bit hash from the buffer and append
- * it to the end of the buffer. The basic requirement is that the change of
- * every byte including the hash value will destroy the whole model in every
- * byte.
- *
- * Encryption:
- *
- * 1. First calculate a 64-bit hash, called plain hash value, from the
- * buffer.
- * 2. Initialize a RC4 stream with the plain hash value.
- * 3. Obfuscate the model body with the RC4 stream defined in step 2.
- * 4. Calculate the hash value of the obfuscated model, called hash value
- * after hashing.
- * 5. Encrypt the model body with a RC4 stream made from the encryption key.
- * 6. Bit-xor the hash value after hashing with the plain hash value, called
- * mixed hash.
- * 7. Encrypt the mixed hash with the RC4 stream defined in step 5, called
- * the protected hash.
- * 8. Append the protected hash to the buffer.
- *
- * Decryption:
- * 1. Decrypt the model body with a RC4 stream made from the encryption key,
- * which is the reverse of step 5 and 7 of encryption and get the mixed
- * hash.
- * 2. Calculate the hash value of the decrypted model, which equals to the
- * hash value after hashing in step 4 of encryption.
- * 3. Bit-xor the hash value after hashing and the mixed hash to get the
- * plain hash value, which is the reverse of step 6 of encryption.
- * 4. Un-obfuscate the model body with the plain hash value, which is the
- * reverse of step 3 of encryption.
- *
- * Think:
- * 1. If any byte in the model body is broken, the hash value after hashing
- * will be broken in step 2, and hence the plain hash value in step 3
- * will be also broken, and finally, the model body will be broken in
- * step 4.
- * 2. If the protected hash is broken, the plain hash value in step 3 will
- * be broken, and finally the model body will be broken.
- */
- std::vector<uint8_t> RC4Impl::encrypt_model() {
- size_t total_length =
- (m_model_length + (sizeof(size_t) - 1)) / sizeof(size_t) * sizeof(size_t);
- std::vector<uint8_t> pad_model(total_length, 0);
- memcpy(pad_model.data(), m_model_mem, m_model_length);
-
- // Calculate the hash of the model.
- rc4::FastHash64 plainhash(m_hash_key);
- uint64_t* ptr = reinterpret_cast<uint64_t*>(pad_model.data());
- size_t len = pad_model.size() / sizeof(uint64_t);
-
- for (size_t i = 0; i < len; ++i)
- plainhash.feed(ptr[i]);
- uint64_t plainhash_value = plainhash.get();
-
- // Encrypt the model.
- rc4::RC4RandStream hash_enc(plainhash_value);
- rc4::RC4RandStream outmost_enc(m_enc_key);
- rc4::FastHash64 afterhashenc_hash(m_hash_key);
-
- for (size_t i = 0; i < len; ++i) {
- uint64_t value = ptr[i] ^ hash_enc.next64();
- afterhashenc_hash.feed(value);
- ptr[i] = value ^ outmost_enc.next64();
- }
-
- uint64_t protected_hash =
- plainhash_value ^ afterhashenc_hash.get() ^ outmost_enc.next64();
-
- size_t end = pad_model.size();
- pad_model.resize(pad_model.size() + sizeof(uint64_t));
- ptr = reinterpret_cast<uint64_t*>(&pad_model[end]);
- *ptr = protected_hash;
- return pad_model;
- }
-
- /*!
- * \brief Read the input stream once in order to initialize the decryption
- * state.
- */
- void SimpleFastRC4Impl::init_sfrc4_state() {
- rc4::RC4RandStream enc_stream(m_enc_key);
- rc4::FastHash64 dechash(m_hash_key);
-
- size_t offset = 0;
- std::vector<uint64_t> buffer(128);
- size_t remaining = m_model_length - sizeof(uint64_t);
- while (remaining > 0) {
- size_t toread = std::min(remaining, buffer.size() * sizeof(uint64_t));
- memcpy(buffer.data(), static_cast<const uint8_t*>(m_model_mem) + offset,
- toread);
- offset += toread;
- remaining -= toread;
-
- for (size_t i = 0; i < toread / sizeof(uint64_t); ++i) {
- uint64_t value = buffer[i];
- dechash.feed(value);
- }
- }
-
- uint64_t hashvalue;
- memcpy(&hashvalue, static_cast<const uint8_t*>(m_model_mem) + offset,
- sizeof(hashvalue));
-
- offset += sizeof(hashvalue);
-
- /*! \brief test the hash_val. */
- if (hashvalue != dechash.get())
- LITE_THROW(
- "The checksum of the file cannot be verified. The file may "
- "be encrypted in the wrong algorithm or different keys.");
-
- m_state.hash_stream.reset(m_hash_key);
- m_state.enc_stream.reset(m_enc_key);
- }
-
- std::vector<uint8_t> SimpleFastRC4Impl::decrypt_model() {
- std::vector<uint8_t> result(m_model_length, 0);
- uint8_t* ptr = result.data();
- for (size_t i = 0; i < m_model_length; ++i) {
- ptr[i] = static_cast<const uint8_t*>(m_model_mem)[i];
- ptr[i] ^= m_state.enc_stream.next8();
- }
- return result;
- }
-
- std::vector<uint8_t> SimpleFastRC4Impl::encrypt_model() {
- size_t total_length =
- (m_model_length + (sizeof(size_t) - 1)) / sizeof(size_t) * sizeof(size_t);
- std::vector<uint8_t> pad_model(total_length, 0);
- memcpy(pad_model.data(), m_model_mem, m_model_length);
-
- // Calculate the hash of the model.
- rc4::FastHash64 enchash(m_hash_key);
- uint64_t* ptr = reinterpret_cast<uint64_t*>(pad_model.data());
- size_t len = pad_model.size() / sizeof(uint64_t);
-
- // Encrypt the model.
- rc4::RC4RandStream out_enc(m_enc_key);
- for (size_t i = 0; i < len; ++i) {
- ptr[i] = ptr[i] ^ out_enc.next64();
- enchash.feed(ptr[i]);
- }
-
- uint64_t hash_value = enchash.get();
-
- size_t end = pad_model.size();
- pad_model.resize(pad_model.size() + sizeof(uint64_t));
- ptr = reinterpret_cast<uint64_t*>(&pad_model[end]);
- *ptr = hash_value;
-
- return pad_model;
- }
-
- // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
|