From 22342487e8fb851a9837db22408db56240aa6931 Mon Sep 17 00:00:00 2001
From: Zach Hilman <zachhilman@gmail.com>
Date: Sat, 28 Jul 2018 16:23:00 -0400
Subject: [PATCH] Extract mbedtls to cpp file

---
 src/common/file_util.cpp            |   2 +-
 src/core/CMakeLists.txt             |   1 +
 src/core/crypto/aes_util.cpp        | 102 ++++++++++++++++++++++++++-
 src/core/crypto/aes_util.h          | 104 ++++++----------------------
 src/core/file_sys/content_archive.h |   3 +-
 5 files changed, 126 insertions(+), 86 deletions(-)

diff --git a/src/common/file_util.cpp b/src/common/file_util.cpp
index 89004c3c0..a26702f54 100644
--- a/src/common/file_util.cpp
+++ b/src/common/file_util.cpp
@@ -739,7 +739,7 @@ const std::string& GetUserPath(UserPath path, const std::string& new_path) {
 std::string GetHactoolConfigurationPath() {
 #ifdef _WIN32
     char path[MAX_PATH];
-    if (SHGetFolderPathA(NULL, CSIDL_PROFILE, NULL, 0, path) != S_OK)
+    if (SHGetFolderPathA(nullptr, CSIDL_PROFILE, nullptr, 0, path) != S_OK)
         return "";
     std::string local_path = Common::StringFromFixedZeroTerminatedBuffer(path, MAX_PATH);
     return local_path + "\\.switch";
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index a69ee9146..528d96a58 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -12,6 +12,7 @@ add_library(core STATIC
     core_timing.h
     core_timing_util.cpp
     core_timing_util.h
+    crypto/aes_util.cpp
     crypto/aes_util.h
     crypto/encryption_layer.cpp
     crypto/encryption_layer.h
diff --git a/src/core/crypto/aes_util.cpp b/src/core/crypto/aes_util.cpp
index 46326cdec..a9646e52f 100644
--- a/src/core/crypto/aes_util.cpp
+++ b/src/core/crypto/aes_util.cpp
@@ -2,5 +2,103 @@
 // Licensed under GPLv2 or any later version
 // Refer to the license.txt file included.
 
-namespace Crypto {
-} // namespace Crypto
+#include "core/crypto/aes_util.h"
+#include "mbedtls/cipher.h"
+
+namespace Core::Crypto {
+static_assert(static_cast<size_t>(Mode::CTR) == static_cast<size_t>(MBEDTLS_CIPHER_AES_128_CTR), "CTR mode is incorrect.");
+static_assert(static_cast<size_t>(Mode::ECB) == static_cast<size_t>(MBEDTLS_CIPHER_AES_128_ECB), "ECB mode is incorrect.");
+static_assert(static_cast<size_t>(Mode::XTS) == static_cast<size_t>(MBEDTLS_CIPHER_AES_128_XTS), "XTS mode is incorrect.");
+
+template<typename Key, size_t KeySize>
+Crypto::AESCipher<Key, KeySize>::AESCipher(Key key, Mode mode) {
+    mbedtls_cipher_init(encryption_context.get());
+    mbedtls_cipher_init(decryption_context.get());
+
+    ASSERT_MSG((mbedtls_cipher_setup(
+            encryption_context.get(),
+            mbedtls_cipher_info_from_type(static_cast<mbedtls_cipher_type_t>(mode))) ||
+                mbedtls_cipher_setup(decryption_context.get(),
+                                     mbedtls_cipher_info_from_type(
+                                             static_cast<mbedtls_cipher_type_t>(mode)))) == 0,
+               "Failed to initialize mbedtls ciphers.");
+
+    ASSERT(
+            !mbedtls_cipher_setkey(encryption_context.get(), key.data(), KeySize * 8, MBEDTLS_ENCRYPT));
+    ASSERT(
+            !mbedtls_cipher_setkey(decryption_context.get(), key.data(), KeySize * 8, MBEDTLS_DECRYPT));
+    //"Failed to set key on mbedtls ciphers.");
+}
+
+template<typename Key, size_t KeySize>
+AESCipher<Key, KeySize>::~AESCipher() {
+    mbedtls_cipher_free(encryption_context.get());
+    mbedtls_cipher_free(decryption_context.get());
+}
+
+template<typename Key, size_t KeySize>
+void AESCipher<Key, KeySize>::SetIV(std::vector<u8> iv) {
+    ASSERT_MSG((mbedtls_cipher_set_iv(encryption_context.get(), iv.data(), iv.size()) ||
+                mbedtls_cipher_set_iv(decryption_context.get(), iv.data(), iv.size())) == 0,
+               "Failed to set IV on mbedtls ciphers.");
+}
+
+template<typename Key, size_t KeySize>
+void AESCipher<Key, KeySize>::Transcode(const u8* src, size_t size, u8* dest, Op op)  {
+    size_t written = 0;
+
+    const auto context = op == Op::Encrypt ? encryption_context.get() : decryption_context.get();
+
+    mbedtls_cipher_reset(context);
+
+    if (mbedtls_cipher_get_cipher_mode(context) == MBEDTLS_MODE_XTS) {
+        mbedtls_cipher_update(context, src, size,
+                              dest, &written);
+        if (written != size)
+            LOG_WARNING(Crypto, "Not all data was decrypted requested={:016X}, actual={:016X}.",
+                        size, written);
+    } else {
+        const auto block_size = mbedtls_cipher_get_block_size(context);
+
+        for (size_t offset = 0; offset < size; offset += block_size) {
+            auto length = std::min<size_t>(block_size, size - offset);
+            mbedtls_cipher_update(context, src + offset, length,
+                                  dest + offset, &written);
+            if (written != length)
+                LOG_WARNING(Crypto,
+                            "Not all data was decrypted requested={:016X}, actual={:016X}.",
+                            length, written);
+        }
+    }
+
+    mbedtls_cipher_finish(context, nullptr, nullptr);
+}
+
+template<typename Key, size_t KeySize>
+void AESCipher<Key, KeySize>::XTSTranscode(const u8* src, size_t size, u8* dest, size_t sector_id, size_t sector_size,
+                                           Op op) {
+    if (size % sector_size > 0) {
+        LOG_CRITICAL(Crypto, "Data size must be a multiple of sector size.");
+        return;
+    }
+
+    for (size_t i = 0; i < size; i += sector_size) {
+        SetIV(CalculateNintendoTweak(sector_id++));
+        Transcode<u8, u8>(src + i, sector_size,
+                          dest + i, op);
+    }
+}
+
+template<typename Key, size_t KeySize>
+std::vector<u8> AESCipher<Key, KeySize>::CalculateNintendoTweak(size_t sector_id) {
+    std::vector<u8> out(0x10);
+    for (size_t i = 0xF; i <= 0xF; --i) {
+        out[i] = sector_id & 0xFF;
+        sector_id >>= 8;
+    }
+    return out;
+}
+
+template class AESCipher<Key128>;
+template class AESCipher<Key256>;
+}
\ No newline at end of file
diff --git a/src/core/crypto/aes_util.h b/src/core/crypto/aes_util.h
index 9807b9234..5c09718b2 100644
--- a/src/core/crypto/aes_util.h
+++ b/src/core/crypto/aes_util.h
@@ -6,113 +6,53 @@
 
 #include "common/assert.h"
 #include "core/file_sys/vfs.h"
-#include "mbedtls/cipher.h"
 
-namespace Crypto {
+namespace Core::Crypto {
 
 enum class Mode {
-    CTR = MBEDTLS_CIPHER_AES_128_CTR,
-    ECB = MBEDTLS_CIPHER_AES_128_ECB,
-    XTS = MBEDTLS_CIPHER_AES_128_XTS,
+    CTR = 11,
+    ECB = 2,
+    XTS = 70,
 };
 
 enum class Op {
-    ENCRYPT,
-    DECRYPT,
+    Encrypt,
+    Decrypt,
 };
 
+struct mbedtls_cipher_context_t;
+
 template <typename Key, size_t KeySize = sizeof(Key)>
-struct AESCipher {
+class AESCipher {
     static_assert(std::is_same_v<Key, std::array<u8, KeySize>>, "Key must be std::array of u8.");
     static_assert(KeySize == 0x10 || KeySize == 0x20, "KeySize must be 128 or 256.");
 
-    AESCipher(Key key, Mode mode) {
-        mbedtls_cipher_init(&encryption_context);
-        mbedtls_cipher_init(&decryption_context);
+public:
+    AESCipher(Key key, Mode mode);
 
-        ASSERT_MSG((mbedtls_cipher_setup(
-                        &encryption_context,
-                        mbedtls_cipher_info_from_type(static_cast<mbedtls_cipher_type_t>(mode))) ||
-                    mbedtls_cipher_setup(&decryption_context,
-                                         mbedtls_cipher_info_from_type(
-                                             static_cast<mbedtls_cipher_type_t>(mode)))) == 0,
-                   "Failed to initialize mbedtls ciphers.");
+    ~AESCipher();
 
-        ASSERT(
-            !mbedtls_cipher_setkey(&encryption_context, key.data(), KeySize * 8, MBEDTLS_ENCRYPT));
-        ASSERT(
-            !mbedtls_cipher_setkey(&decryption_context, key.data(), KeySize * 8, MBEDTLS_DECRYPT));
-        //"Failed to set key on mbedtls ciphers.");
-    }
-
-    ~AESCipher() {
-        mbedtls_cipher_free(&encryption_context);
-        mbedtls_cipher_free(&decryption_context);
-    }
-
-    void SetIV(std::vector<u8> iv) {
-        ASSERT_MSG((mbedtls_cipher_set_iv(&encryption_context, iv.data(), iv.size()) ||
-                    mbedtls_cipher_set_iv(&decryption_context, iv.data(), iv.size())) == 0,
-                   "Failed to set IV on mbedtls ciphers.");
-    }
+    void SetIV(std::vector<u8> iv);
 
     template <typename Source, typename Dest>
     void Transcode(const Source* src, size_t size, Dest* dest, Op op) {
-        size_t written = 0;
-
-        const auto context = op == Op::ENCRYPT ? &encryption_context : &decryption_context;
-
-        mbedtls_cipher_reset(context);
-
-        if (mbedtls_cipher_get_cipher_mode(context) == MBEDTLS_MODE_XTS) {
-            mbedtls_cipher_update(context, reinterpret_cast<const u8*>(src), size,
-                                  reinterpret_cast<u8*>(dest), &written);
-            if (written != size)
-                LOG_WARNING(Crypto, "Not all data was decrypted requested={:016X}, actual={:016X}.",
-                            size, written);
-        } else {
-            const auto block_size = mbedtls_cipher_get_block_size(context);
-
-            for (size_t offset = 0; offset < size; offset += block_size) {
-                auto length = std::min<size_t>(block_size, size - offset);
-                mbedtls_cipher_update(context, reinterpret_cast<const u8*>(src) + offset, length,
-                                      reinterpret_cast<u8*>(dest) + offset, &written);
-                if (written != length)
-                    LOG_WARNING(Crypto,
-                                "Not all data was decrypted requested={:016X}, actual={:016X}.",
-                                length, written);
-            }
-        }
-
-        mbedtls_cipher_finish(context, nullptr, nullptr);
+        Transcode(reinterpret_cast<const u8*>(src), size, reinterpret_cast<u8*>(dest), op);
     }
 
+    void Transcode(const u8* src, size_t size, u8* dest, Op op);
+
     template <typename Source, typename Dest>
     void XTSTranscode(const Source* src, size_t size, Dest* dest, size_t sector_id,
                       size_t sector_size, Op op) {
-        if (size % sector_size > 0) {
-            LOG_CRITICAL(Crypto, "Data size must be a multiple of sector size.");
-            return;
-        }
-
-        for (size_t i = 0; i < size; i += sector_size) {
-            SetIV(CalculateNintendoTweak(sector_id++));
-            Transcode<u8, u8>(reinterpret_cast<const u8*>(src) + i, sector_size,
-                              reinterpret_cast<u8*>(dest) + i, op);
-        }
+        XTSTranscode(reinterpret_cast<const u8*>(src), size, reinterpret_cast<u8*>(dest), sector_id, sector_size, op);
     }
 
+    void XTSTranscode(const u8* src, size_t size, u8* dest, size_t sector_id, size_t sector_size, Op op);
+
 private:
-    mbedtls_cipher_context_t encryption_context;
-    mbedtls_cipher_context_t decryption_context;
+    std::unique_ptr<mbedtls_cipher_context_t> encryption_context;
+    std::unique_ptr<mbedtls_cipher_context_t> decryption_context;
 
-    static std::vector<u8> CalculateNintendoTweak(size_t sector_id) {
-        std::vector<u8> out(0x10);
-        for (size_t i = 0xF; i <= 0xF; --i) {
-            out[i] = sector_id & 0xFF;
-            sector_id >>= 8;
-        }
-        return out;
-    }
+    static std::vector<u8> CalculateNintendoTweak(size_t sector_id);
 };
 } // namespace Crypto
diff --git a/src/core/file_sys/content_archive.h b/src/core/file_sys/content_archive.h
index 1e8d9c8ae..d9ad3bf7e 100644
--- a/src/core/file_sys/content_archive.h
+++ b/src/core/file_sys/content_archive.h
@@ -9,6 +9,7 @@
 #include <string>
 #include <vector>
 
+#include "core/loader/loader.h"
 #include "common/common_funcs.h"
 #include "common/common_types.h"
 #include "common/swap.h"
@@ -108,7 +109,7 @@ private:
 
     Crypto::Key128 GetKeyAreaKey(NCASectionCryptoType type);
 
-    VirtualFile Decrypt(NCASectionHeader header, VirtualFile in, size_t starting_offset);
+    VirtualFile Decrypt(NCASectionHeader header, VirtualFile in, u64 starting_offset);
 };
 
 } // namespace FileSys