/**
 * \file src/decryption/rc4/rc4_cryption_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "rc4_cryption_impl.h"
#include "../../misc.h"

#include <cstdint>
#include <cstring>
#include <vector>

using namespace lite;

namespace {

//! Size of one cipher/hash word. Encrypted payloads are a sequence of these
//! words followed by one trailing word that stores a hash.
constexpr size_t kWordSize = sizeof(uint64_t);

/*!
 * \brief Round \p length up to a whole number of 64-bit words.
 *
 * The encryptors consume the model one uint64_t at a time, so the plaintext
 * must be padded to a multiple of sizeof(uint64_t). Rounding to
 * sizeof(size_t) instead is not enough on 32-bit targets: the last 4 bytes
 * would be neither encrypted nor hashed, yet decrypt_model() XORs every byte.
 */
size_t round_up_to_words(size_t length) {
    return (length + kWordSize - 1) / kWordSize * kWordSize;
}

//! Copy \p length bytes from \p mem into a new buffer that is zero-padded to
//! a whole number of 64-bit words.
std::vector<uint8_t> make_padded_copy(const void* mem, size_t length) {
    std::vector<uint8_t> padded(round_up_to_words(length), 0);
    // memcpy with a null source is undefined even for size 0.
    if (length > 0) {
        std::memcpy(padded.data(), mem, length);
    }
    return padded;
}

//! Load the \p index-th 64-bit word (native byte order) starting at \p bytes.
uint64_t load_word(const uint8_t* bytes, size_t index) {
    uint64_t word;
    // memcpy instead of a uint64_t* cast: nothing here guarantees 8-byte
    // alignment of the input, and memcpy does not violate aliasing rules.
    std::memcpy(&word, bytes + index * kWordSize, kWordSize);
    return word;
}

//! Store \p word (native byte order) as the \p index-th 64-bit word at
//! \p bytes.
void store_word(uint8_t* bytes, size_t index, uint64_t word) {
    std::memcpy(bytes + index * kWordSize, &word, kWordSize);
}

//! Append \p word (native byte order) to the end of \p buffer.
void append_word(std::vector<uint8_t>& buffer, uint64_t word) {
    const size_t end = buffer.size();
    buffer.resize(end + kWordSize);
    std::memcpy(buffer.data() + end, &word, kWordSize);
}

//! Throw unless an encrypted input of \p length bytes is long enough to hold
//! the trailing 64-bit hash. Shorter inputs would make the body length
//! computation underflow and the subsequent reads run out of bounds.
void check_has_trailing_hash(size_t length) {
    if (length < kWordSize) {
        LITE_THROW(
                "The encrypted model is too short: it must contain at least "
                "the 8-byte trailing hash.");
    }
}

//! Number of complete 64-bit words in the encrypted body, i.e. in everything
//! except the trailing 8-byte hash. A partial trailing word is not counted.
size_t body_word_count(size_t length) {
    return (length - kWordSize) / kWordSize;
}

//! Read the 64-bit word stored in the last 8 bytes of a buffer of \p length
//! bytes (\p length must be at least 8).
uint64_t read_trailing_hash(const uint8_t* bytes, size_t length) {
    return load_word(bytes + (length - kWordSize), 0);
}

}  // namespace

/*!
 * \brief Read the input stream once in order to initialize the decryption
 *        state.
 *
 * Strips the encryption-key layer from the body on the fly (the reverse of
 * step 5 of RC4Impl::encrypt_model) and hashes the result. It then combines
 * that hash with the stored protected hash to recover the plain hash value,
 * which seeds m_state.hash_stream. m_state.enc_stream is reset to the
 * encryption key so that decrypt_model() starts from the beginning.
 *
 * Throws (via LITE_THROW) if the input is shorter than the trailing hash.
 */
void RC4Impl::init_rc4_state() {
    check_has_trailing_hash(m_model_length);
    const auto* src = static_cast<const uint8_t*>(m_model_mem);

    rc4::RC4RandStream enc_stream(m_enc_key);
    rc4::FastHash64 dechash(m_hash_key);
    const size_t words = body_word_count(m_model_length);
    for (size_t i = 0; i < words; ++i) {
        // Removing the encryption-key stream yields the obfuscated word whose
        // hash is the "hash value after hashing" of encryption step 4.
        dechash.feed(load_word(src, i) ^ enc_stream.next64());
    }

    // protected = plain ^ after_hashing ^ next key-stream word, so XOR-ing the
    // same two terms back in recovers the plain hash value.
    uint64_t hashvalue = read_trailing_hash(src, m_model_length);
    hashvalue ^= dechash.get() ^ enc_stream.next64();

    m_state.hash_stream.reset(hashvalue);
    m_state.enc_stream.reset(m_enc_key);
}

/*!
 * \brief Decrypt the input using the state prepared by init_rc4_state().
 *
 * Each byte is XOR-ed with the next byte of m_state.hash_stream and of
 * m_state.enc_stream.
 *
 * NOTE(review): the result has the same length as the input. It therefore
 * also contains the zero padding added by encrypt_model and the (meaningless)
 * decrypted trailing hash word. Presumably the model loader ignores trailing
 * bytes; confirm with callers.
 */
std::vector<uint8_t> RC4Impl::decrypt_model() {
    const auto* src = static_cast<const uint8_t*>(m_model_mem);
    std::vector<uint8_t> result(m_model_length, 0);
    for (size_t i = 0; i < m_model_length; ++i) {
        result[i] = static_cast<uint8_t>(
                src[i] ^ m_state.hash_stream.next8() ^ m_state.enc_stream.next8());
    }
    return result;
}

/*!
 * \brief Encrypt the model held in m_model_mem (m_model_length bytes).
 *
 * The basic idea is to calculate a 64-bit hash from the model and append it
 * to the end of the output. The basic requirement is that changing any
 * byte, including the hash value, destroys every byte of the whole model.
 *
 * Output layout: the model zero-padded to a multiple of 8 bytes and
 * encrypted, followed by the 8-byte protected hash (native byte order).
 *
 * Encryption:
 *
 * 1. Calculate a 64-bit hash, called the plain hash value, from the padded
 *    model.
 * 2. Initialize an RC4 stream with the plain hash value.
 * 3. Obfuscate the model body with the RC4 stream defined in step 2.
 * 4. Calculate the hash value of the obfuscated model, called the hash value
 *    after hashing.
 * 5. Encrypt the model body with an RC4 stream made from the encryption key.
 * 6. Bit-xor the hash value after hashing with the plain hash value; the
 *    result is called the mixed hash.
 * 7. Encrypt the mixed hash with the RC4 stream defined in step 5; the
 *    result is called the protected hash.
 * 8. Append the protected hash to the output.
 *
 * Decryption:
 * 1. Decrypt the model body with an RC4 stream made from the encryption key,
 *    which reverses steps 5 and 7 of encryption and yields the mixed hash.
 * 2. Calculate the hash value of the decrypted model body, which equals the
 *    hash value after hashing from step 4 of encryption.
 * 3. Bit-xor the hash value after hashing with the mixed hash to get the
 *    plain hash value, which reverses step 6 of encryption.
 * 4. Un-obfuscate the model body with the plain hash value, which reverses
 *    step 3 of encryption.
 *
 * Rationale:
 * 1. If any byte in the model body is broken, the hash value after hashing
 *    in decryption step 2 is broken. Hence the plain hash value in step 3 is
 *    also broken, and finally the whole model body is broken in step 4.
 * 2. If the protected hash is broken, the plain hash value in step 3 is
 *    broken, and finally the whole model body is broken.
 */
std::vector<uint8_t> RC4Impl::encrypt_model() {
    std::vector<uint8_t> pad_model = make_padded_copy(m_model_mem, m_model_length);
    uint8_t* const bytes = pad_model.data();
    const size_t len = pad_model.size() / kWordSize;

    // Step 1: the plain hash value of the padded model.
    rc4::FastHash64 plainhash(m_hash_key);
    for (size_t i = 0; i < len; ++i) {
        plainhash.feed(load_word(bytes, i));
    }
    const uint64_t plainhash_value = plainhash.get();

    // Steps 2-5 in one pass: obfuscate with the plain-hash stream, hash the
    // obfuscated word, then wrap it with the encryption-key stream.
    rc4::RC4RandStream hash_enc(plainhash_value);
    rc4::RC4RandStream outmost_enc(m_enc_key);
    rc4::FastHash64 afterhashenc_hash(m_hash_key);
    for (size_t i = 0; i < len; ++i) {
        const uint64_t obfuscated = load_word(bytes, i) ^ hash_enc.next64();
        afterhashenc_hash.feed(obfuscated);
        store_word(bytes, i, obfuscated ^ outmost_enc.next64());
    }

    // Steps 6-8: mix both hashes, protect the result with the next
    // encryption-key stream word, and append it.
    const uint64_t protected_hash =
            plainhash_value ^ afterhashenc_hash.get() ^ outmost_enc.next64();
    append_word(pad_model, protected_hash);
    return pad_model;
}

/*!
 * \brief Read the input stream once in order to verify its checksum and
 *        initialize the decryption state.
 *
 * The trailing 64-bit word must equal the FastHash64 (keyed with m_hash_key)
 * of the still-encrypted body, as produced by
 * SimpleFastRC4Impl::encrypt_model. Otherwise, or if the input is shorter
 * than the trailing hash, this throws via LITE_THROW. On success both state
 * streams are reset to their keys.
 */
void SimpleFastRC4Impl::init_sfrc4_state() {
    check_has_trailing_hash(m_model_length);
    const auto* src = static_cast<const uint8_t*>(m_model_mem);

    // The checksum covers the encrypted bytes, so it can be verified without
    // decrypting anything.
    rc4::FastHash64 dechash(m_hash_key);
    const size_t words = body_word_count(m_model_length);
    for (size_t i = 0; i < words; ++i) {
        dechash.feed(load_word(src, i));
    }

    if (read_trailing_hash(src, m_model_length) != dechash.get()) {
        LITE_THROW(
                "The checksum of the file cannot be verified. The file may "
                "be encrypted in the wrong algorithm or different keys.");
    }

    m_state.hash_stream.reset(m_hash_key);
    m_state.enc_stream.reset(m_enc_key);
}

/*!
 * \brief Decrypt the input using the state prepared by init_sfrc4_state().
 *
 * Each byte is XOR-ed with the next byte of m_state.enc_stream.
 *
 * NOTE(review): like RC4Impl::decrypt_model, the result keeps the input
 * length, including the zero padding and the trailing checksum word.
 * Presumably callers ignore trailing bytes; confirm.
 */
std::vector<uint8_t> SimpleFastRC4Impl::decrypt_model() {
    const auto* src = static_cast<const uint8_t*>(m_model_mem);
    std::vector<uint8_t> result(m_model_length, 0);
    for (size_t i = 0; i < m_model_length; ++i) {
        result[i] = static_cast<uint8_t>(src[i] ^ m_state.enc_stream.next8());
    }
    return result;
}

/*!
 * \brief Encrypt the model held in m_model_mem (m_model_length bytes).
 *
 * The model is zero-padded to a multiple of 8 bytes, then XOR-ed word by
 * word with an RC4 stream made from the encryption key. The FastHash64
 * (keyed with m_hash_key) of the encrypted words is appended as an 8-byte
 * checksum (native byte order).
 */
std::vector<uint8_t> SimpleFastRC4Impl::encrypt_model() {
    std::vector<uint8_t> pad_model = make_padded_copy(m_model_mem, m_model_length);
    uint8_t* const bytes = pad_model.data();
    const size_t len = pad_model.size() / kWordSize;

    rc4::RC4RandStream out_enc(m_enc_key);
    rc4::FastHash64 enchash(m_hash_key);
    for (size_t i = 0; i < len; ++i) {
        const uint64_t encrypted = load_word(bytes, i) ^ out_enc.next64();
        store_word(bytes, i, encrypted);
        enchash.feed(encrypted);
    }

    append_word(pad_model, enchash.get());
    return pad_model;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}