summaryrefslogtreecommitdiffstats
path: root/src/core/hw/aes/ccm.cpp
blob: dc7035ab62ebb58e3f2e868976490326ed0645a9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Copyright 2017 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.

#include <algorithm>
#include <cryptopp/aes.h>
#include <cryptopp/ccm.h>
#include <cryptopp/cryptlib.h>
#include <cryptopp/filters.h>
#include "common/alignment.h"
#include "common/logging/log.h"
#include "core/hw/aes/ccm.h"
#include "core/hw/aes/key.h"

namespace HW {
namespace AES {

namespace {

// The 3DS implements a non-standard flavour of AES-CCM. Instead of re-implementing CCM, we derive
// from the stock Crypto++ CCM mode and override only the step where the 3DS differs.
using CryptoPP::AES;
using CryptoPP::CCM_Base;
using CryptoPP::CCM_Final;
using CryptoPP::lword;

/**
 * Crypto++ AES-CCM mode (MAC size CCM_MAC_SIZE) adjusted to match the 3DS.
 * @tparam T_IsEncryption true for the encrypting direction, false for the decrypting one.
 */
template <bool T_IsEncryption>
struct CCM_3DSVariant_Final : CCM_Final<AES, CCM_MAC_SIZE, T_IsEncryption> {
    void UncheckedSpecifyDataLengths(lword header_length, lword message_length,
                                     lword footer_length) override {
        // When generating B0 for authentication, the 3DS uses the message size rounded up to a
        // whole AES block rather than the real size, so hand the aligned size to the base class.
        CCM_Base::UncheckedSpecifyDataLengths(
            header_length, Common::AlignUp(message_length, AES_BLOCK_SIZE), footer_length);
        // Afterwards, put the real (unaligned) message size back into the base class state.
        CCM_Base::m_messageLength = message_length;
    }
};

/// Bundles both directions of the 3DS CCM mode under the Encryption/Decryption names that
/// Crypto++'s own CCM<> mode exposes.
struct CCM_3DSVariant {
    using Decryption = CCM_3DSVariant_Final<false>;
    using Encryption = CCM_3DSVariant_Final<true>;
};

} // namespace

/**
 * Encrypts pdata and appends a MAC using the 3DS variant of AES-CCM.
 * @param pdata   plaintext to encrypt
 * @param nonce   CCM nonce (CCM_NONCE_SIZE bytes are used)
 * @param slot_id key slot whose normal key is used; an error is logged if it is not available
 * @return ciphertext followed by the CCM_MAC_SIZE-byte MAC. If Crypto++ throws, the error is
 *         logged and the buffer is returned anyway; its contents may then be zero-filled or only
 *         partially written.
 */
std::vector<u8> EncryptSignCCM(const std::vector<u8>& pdata, const CCMNonce& nonce,
                               size_t slot_id) {
    if (!IsNormalKeyAvailable(slot_id)) {
        // slot_id is a size_t: it must be formatted with %zu, not %d.
        LOG_ERROR(HW_AES, "Key slot %zu not available. Will use zero key.", slot_id);
    }
    const AESKey normal = GetNormalKey(slot_id);
    std::vector<u8> cipher(pdata.size() + CCM_MAC_SIZE);

    try {
        CCM_3DSVariant::Encryption e;
        e.SetKeyWithIV(normal.data(), AES_BLOCK_SIZE, nonce.data(), CCM_NONCE_SIZE);
        // CCM must know the data lengths before processing; there is no header or footer data.
        e.SpecifyDataLengths(0, pdata.size(), 0);
        // Crypto++ pipeline: the source takes ownership of the attached filter and sink.
        CryptoPP::ArraySource as(pdata.data(), pdata.size(), true,
                                 new CryptoPP::AuthenticatedEncryptionFilter(
                                     e, new CryptoPP::ArraySink(cipher.data(), cipher.size())));
    } catch (const CryptoPP::Exception& e) {
        LOG_ERROR(HW_AES, "FAILED with: %s", e.what());
    }
    return cipher;
}

/**
 * Decrypts data produced by the 3DS variant of AES-CCM and verifies its MAC.
 * @param cipher  ciphertext followed by a CCM_MAC_SIZE-byte MAC
 * @param nonce   CCM nonce (CCM_NONCE_SIZE bytes are used)
 * @param slot_id key slot whose normal key is used; an error is logged if it is not available
 * @return the decrypted plaintext, or an empty vector if the input is shorter than the MAC,
 *         the MAC does not verify, or Crypto++ reports an error.
 */
std::vector<u8> DecryptVerifyCCM(const std::vector<u8>& cipher, const CCMNonce& nonce,
                                 size_t slot_id) {
    if (!IsNormalKeyAvailable(slot_id)) {
        // slot_id is a size_t: it must be formatted with %zu, not %d.
        LOG_ERROR(HW_AES, "Key slot %zu not available. Will use zero key.", slot_id);
    }
    // Without this guard, cipher.size() - CCM_MAC_SIZE would wrap around to a huge size_t and the
    // allocation below would throw outside the try block.
    if (cipher.size() < CCM_MAC_SIZE) {
        LOG_ERROR(HW_AES, "Cipher size %zu is smaller than the MAC size", cipher.size());
        return {};
    }
    const AESKey normal = GetNormalKey(slot_id);
    const std::size_t pdata_size = cipher.size() - CCM_MAC_SIZE;
    std::vector<u8> pdata(pdata_size);

    try {
        CCM_3DSVariant::Decryption d;
        d.SetKeyWithIV(normal.data(), AES_BLOCK_SIZE, nonce.data(), CCM_NONCE_SIZE);
        // CCM must know the data lengths before processing; there is no header or footer data.
        d.SpecifyDataLengths(0, pdata_size, 0);
        CryptoPP::AuthenticatedDecryptionFilter df(
            d, new CryptoPP::ArraySink(pdata.data(), pdata_size));
        // Redirector feeds df without taking ownership, so df can be queried afterwards.
        CryptoPP::ArraySource as(cipher.data(), cipher.size(), true, new CryptoPP::Redirector(df));
        if (!df.GetLastResult()) {
            LOG_ERROR(HW_AES, "FAILED");
            return {};
        }
    } catch (const CryptoPP::Exception& e) {
        LOG_ERROR(HW_AES, "FAILED with: %s", e.what());
        return {};
    }
    return pdata;
}

} // namespace AES
} // namespace HW