summaryrefslogblamecommitdiffstats
path: root/src/core/hle/service/ssl/ssl_backend_schannel.cpp
blob: a1d6a186ef736504afa5914024a7b01293c9cb9b (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14













                                                               




                                                                             
                      
                    
                     
 


                                   




                                                    



                                                                                           



                                                                 
 
                               








                                                                                                   





                                                                                                  





















                                                                                               

                                                                             

     

                                                                 





                                   
                                      




                                                                             
                                                                           






                                                                       
                                                  





                                                                                          
                                                                           




                                                                       
                                                            










                                                                                      

                                                                                
                                                         
                                                         
                                                                                         


                                                                                         


                                                             
                                                        

                                   
                                               

                                    
                                               






                                                                                  

                                                                             

                                         


                                                                                   











                                                                                      



                                                                                     





                                                                                                       
                                                                                   
                                              
                                                       




                                                                                                 
                              
                                              
                                    



                                                 
                              
                                              
                                    

                     
                              
                                              
                                    












                                                                          
                                                                    

                                                                      
                                                                            

                                                                                     
                                                  

         
                                   
                                                                                         
                                                                                
                                                                                        



                                                                                 
                                                                                                


                                                             

                                                                              
                                                                                              



                                                          

                                                                              









                                                                                         


                                                                                                 

                                                                       
                                            
             
                                                             



                                                                                            

                                                                


                                                                            

                                                             




                                                                                                   
                                                    




                                       
                                   
                                                                                   


                                                                                                 
                                                    





                                                         
                                                           


                                                                                        
                                               


                             




                                                                                          

                                 
                                               




                                                  

                                                  
                                                                                           
                                                     
                                                               
                      


                          

                                      
                                                                      






                                                                              
                                                                                        







                                                                                        


                                                                                     
                                                                   


                                                                                                   

                                                                         
                                                    





                                                              
                                        






                                                                       
                                                     


                                     

                                              





                                                                
                                                           





                                                                                         

                                                                                             


                                                                               

                                                                                    




                                                                                                 
                                                                 

         


                                                               


                                          
                                                  








                                                                            
                                                   












                                                                                        
                                                                                                 



                                                                                                  





                                                                                     



                                                   
                                                   



                                 

                                                                          




                                                                       
                                   
                                                                                           



















                                                                                                 

                                                         


























                                                                               
                                                
 

                                           
 

                                                
 



                                         
 

                                  



                                                                               
                                    






                           
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later

#include "core/hle/service/ssl/ssl_backend.h"
#include "core/internal_network/network.h"
#include "core/internal_network/sockets.h"

#include "common/error.h"
#include "common/fs/file.h"
#include "common/hex_util.h"
#include "common/string_util.h"

#include <mutex>

namespace {

// These includes are inside the namespace to avoid a conflict on MinGW where
// the headers define an enum containing Network and Service as enumerators
// (which clash with the correspondingly named namespaces).
#define SECURITY_WIN32
#include <schnlsp.h>
#include <security.h>

std::once_flag one_time_init_flag;
bool one_time_init_success = false;

SCHANNEL_CRED schannel_cred{};
CredHandle cred_handle;

static void OneTimeInit() {
    schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
    schannel_cred.dwFlags =
        SCH_USE_STRONG_CRYPTO |         // don't allow insecure protocols
        SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
        SCH_CRED_NO_DEFAULT_CREDS;      // don't automatically present a client certificate
    // ^ I'm assuming that nobody would want to connect Yuzu to a
    // service that requires some OS-provided corporate client
    // certificate, and presenting one to some arbitrary server
    // might be a privacy concern?  Who knows, though.

    const SECURITY_STATUS ret =
        AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
                                 nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
    if (ret != SEC_E_OK) {
        // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
        LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
                  Common::NativeErrorToString(ret));
        return;
    }

    if (getenv("SSLKEYLOGFILE")) {
        LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
                                  "keys; not logging keys!");
        // Not fatal.
    }

    one_time_init_success = true;
}

} // namespace

namespace Service::SSL {

class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
public:
    Result Init() {
        std::call_once(one_time_init_flag, OneTimeInit);

        if (!one_time_init_success) {
            LOG_ERROR(
                Service_SSL,
                "Can't create SSL connection because Schannel one-time initialization failed");
            return ResultInternalError;
        }

        return ResultSuccess;
    }

    void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
        socket = std::move(socket_in);
    }

    Result SetHostName(const std::string& hostname_in) override {
        hostname = hostname_in;
        return ResultSuccess;
    }

    Result DoHandshake() override {
        while (1) {
            Result r;
            switch (handshake_state) {
            case HandshakeState::Initial:
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
                    (r = CallInitializeSecurityContext()) != ResultSuccess) {
                    return r;
                }
                // CallInitializeSecurityContext updated `handshake_state`.
                continue;
            case HandshakeState::ContinueNeeded:
            case HandshakeState::IncompleteMessage:
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
                    (r = FillCiphertextReadBuf()) != ResultSuccess) {
                    return r;
                }
                if (ciphertext_read_buf.empty()) {
                    LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
                    return ResultInternalError;
                }
                if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
                    return r;
                }
                // CallInitializeSecurityContext updated `handshake_state`.
                continue;
            case HandshakeState::DoneAfterFlush:
                if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
                    return r;
                }
                handshake_state = HandshakeState::Connected;
                return ResultSuccess;
            case HandshakeState::Connected:
                LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
                return ResultInternalError;
            case HandshakeState::Error:
                return ResultInternalError;
            }
        }
    }

    Result FillCiphertextReadBuf() {
        const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
        read_buf_fill_size = 0;
        // This unnecessarily zeroes the buffer; oh well.
        const size_t offset = ciphertext_read_buf.size();
        ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
        ciphertext_read_buf.resize(offset + fill_size, 0);
        const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
        const auto [actual, err] = socket->Recv(0, read_span);
        switch (err) {
        case Network::Errno::SUCCESS:
            ASSERT(static_cast<size_t>(actual) <= fill_size);
            ciphertext_read_buf.resize(offset + actual);
            return ResultSuccess;
        case Network::Errno::AGAIN:
            ciphertext_read_buf.resize(offset);
            return ResultWouldBlock;
        default:
            ciphertext_read_buf.resize(offset);
            LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
            return ResultInternalError;
        }
    }

    // Returns success if the write buffer has been completely emptied.
    Result FlushCiphertextWriteBuf() {
        while (!ciphertext_write_buf.empty()) {
            const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
            switch (err) {
            case Network::Errno::SUCCESS:
                ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
                ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
                                           ciphertext_write_buf.begin() + actual);
                break;
            case Network::Errno::AGAIN:
                return ResultWouldBlock;
            default:
                LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
                return ResultInternalError;
            }
        }
        return ResultSuccess;
    }

    Result CallInitializeSecurityContext() {
        const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
                                  ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
                                  ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
                                  ISC_REQ_USE_SUPPLIED_CREDS;
        unsigned long attr;
        // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
        std::array<SecBuffer, 2> input_buffers{{
            // only used if `initial_call_done`
            {
                // [0]
                .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
                .BufferType = SECBUFFER_TOKEN,
                .pvBuffer = ciphertext_read_buf.data(),
            },
            {
                // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
                //     returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
                //     whole buffer wasn't used)
                .cbBuffer = 0,
                .BufferType = SECBUFFER_EMPTY,
                .pvBuffer = nullptr,
            },
        }};
        std::array<SecBuffer, 2> output_buffers{{
            {
                .cbBuffer = 0,
                .BufferType = SECBUFFER_TOKEN,
                .pvBuffer = nullptr,
            }, // [0]
            {
                .cbBuffer = 0,
                .BufferType = SECBUFFER_ALERT,
                .pvBuffer = nullptr,
            }, // [1]
        }};
        SecBufferDesc input_desc{
            .ulVersion = SECBUFFER_VERSION,
            .cBuffers = static_cast<unsigned long>(input_buffers.size()),
            .pBuffers = input_buffers.data(),
        };
        SecBufferDesc output_desc{
            .ulVersion = SECBUFFER_VERSION,
            .cBuffers = static_cast<unsigned long>(output_buffers.size()),
            .pBuffers = output_buffers.data(),
        };
        ASSERT_OR_EXECUTE_MSG(
            input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
            { return ResultInternalError; }, "read buffer too large");

        bool initial_call_done = handshake_state != HandshakeState::Initial;
        if (initial_call_done) {
            LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
                      ciphertext_read_buf.size());
        }

        const SECURITY_STATUS ret =
            InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
                                       // Caller ensured we have set a hostname:
                                       const_cast<char*>(hostname.value().c_str()), req,
                                       0, // Reserved1
                                       0, // TargetDataRep not used with Schannel
                                       initial_call_done ? &input_desc : nullptr,
                                       0, // Reserved2
                                       initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
                                       nullptr); // ptsExpiry

        if (output_buffers[0].pvBuffer) {
            const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
                                 output_buffers[0].cbBuffer);
            ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
            FreeContextBuffer(output_buffers[0].pvBuffer);
        }

        if (output_buffers[1].pvBuffer) {
            const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
                                 output_buffers[1].cbBuffer);
            // The documentation doesn't explain what format this data is in.
            LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
                      Common::HexToString(span));
        }

        switch (ret) {
        case SEC_I_CONTINUE_NEEDED:
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
            if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
                LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
                ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
                ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
                                          ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
            } else {
                ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
                ciphertext_read_buf.clear();
            }
            handshake_state = HandshakeState::ContinueNeeded;
            return ResultSuccess;
        case SEC_E_INCOMPLETE_MESSAGE:
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
            ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
            read_buf_fill_size = input_buffers[1].cbBuffer;
            handshake_state = HandshakeState::IncompleteMessage;
            return ResultSuccess;
        case SEC_E_OK:
            LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
            ciphertext_read_buf.clear();
            handshake_state = HandshakeState::DoneAfterFlush;
            return GrabStreamSizes();
        default:
            LOG_ERROR(Service_SSL,
                      "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
                      Common::NativeErrorToString(ret));
            handshake_state = HandshakeState::Error;
            return ResultInternalError;
        }
    }

    Result GrabStreamSizes() {
        const SECURITY_STATUS ret =
            QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
        if (ret != SEC_E_OK) {
            LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
                      Common::NativeErrorToString(ret));
            handshake_state = HandshakeState::Error;
            return ResultInternalError;
        }
        return ResultSuccess;
    }

    ResultVal<size_t> Read(std::span<u8> data) override {
        if (handshake_state != HandshakeState::Connected) {
            LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
            return ResultInternalError;
        }
        if (data.size() == 0 || got_read_eof) {
            return size_t(0);
        }
        while (1) {
            if (!cleartext_read_buf.empty()) {
                const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
                std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
                cleartext_read_buf.erase(cleartext_read_buf.begin(),
                                         cleartext_read_buf.begin() + read_size);
                return read_size;
            }
            if (!ciphertext_read_buf.empty()) {
                SecBuffer empty{
                    .cbBuffer = 0,
                    .BufferType = SECBUFFER_EMPTY,
                    .pvBuffer = nullptr,
                };
                std::array<SecBuffer, 5> buffers{{
                    {
                        .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
                        .BufferType = SECBUFFER_DATA,
                        .pvBuffer = ciphertext_read_buf.data(),
                    },
                    empty,
                    empty,
                    empty,
                }};
                ASSERT_OR_EXECUTE_MSG(
                    buffers[0].cbBuffer == ciphertext_read_buf.size(),
                    { return ResultInternalError; }, "read buffer too large");
                SecBufferDesc desc{
                    .ulVersion = SECBUFFER_VERSION,
                    .cBuffers = static_cast<unsigned long>(buffers.size()),
                    .pBuffers = buffers.data(),
                };
                SECURITY_STATUS ret =
                    DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
                switch (ret) {
                case SEC_E_OK:
                    ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
                                      { return ResultInternalError; });
                    ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
                                      { return ResultInternalError; });
                    ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
                                      { return ResultInternalError; });
                    cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
                                              static_cast<u8*>(buffers[1].pvBuffer) +
                                                  buffers[1].cbBuffer);
                    if (buffers[3].BufferType == SECBUFFER_EXTRA) {
                        ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
                        ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
                                                  ciphertext_read_buf.end() - buffers[3].cbBuffer);
                    } else {
                        ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
                        ciphertext_read_buf.clear();
                    }
                    continue;
                case SEC_E_INCOMPLETE_MESSAGE:
                    break;
                case SEC_I_CONTEXT_EXPIRED:
                    // Server hung up by sending close_notify.
                    got_read_eof = true;
                    return size_t(0);
                default:
                    LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
                              Common::NativeErrorToString(ret));
                    return ResultInternalError;
                }
            }
            const Result r = FillCiphertextReadBuf();
            if (r != ResultSuccess) {
                return r;
            }
            if (ciphertext_read_buf.empty()) {
                got_read_eof = true;
                return size_t(0);
            }
        }
    }

    ResultVal<size_t> Write(std::span<const u8> data) override {
        if (handshake_state != HandshakeState::Connected) {
            LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
            return ResultInternalError;
        }
        if (data.size() == 0) {
            return size_t(0);
        }
        data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
        if (!cleartext_write_buf.empty()) {
            // Already in the middle of a write.  It wouldn't make sense to not
            // finish sending the entire buffer since TLS has
            // header/MAC/padding/etc.
            if (data.size() != cleartext_write_buf.size() ||
                std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
                LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
                return ResultInternalError;
            }
            return WriteAlreadyEncryptedData();
        } else {
            cleartext_write_buf.assign(data.begin(), data.end());
        }

        std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
        std::vector<u8> tmp_data_buf = cleartext_write_buf;
        std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);

        std::array<SecBuffer, 3> buffers{{
            {
                .cbBuffer = stream_sizes.cbHeader,
                .BufferType = SECBUFFER_STREAM_HEADER,
                .pvBuffer = header_buf.data(),
            },
            {
                .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
                .BufferType = SECBUFFER_DATA,
                .pvBuffer = tmp_data_buf.data(),
            },
            {
                .cbBuffer = stream_sizes.cbTrailer,
                .BufferType = SECBUFFER_STREAM_TRAILER,
                .pvBuffer = trailer_buf.data(),
            },
        }};
        ASSERT_OR_EXECUTE_MSG(
            buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
            "temp buffer too large");
        SecBufferDesc desc{
            .ulVersion = SECBUFFER_VERSION,
            .cBuffers = static_cast<unsigned long>(buffers.size()),
            .pBuffers = buffers.data(),
        };

        const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
        if (ret != SEC_E_OK) {
            LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
            return ResultInternalError;
        }
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
                                    header_buf.end());
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
                                    tmp_data_buf.end());
        ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
                                    trailer_buf.end());
        return WriteAlreadyEncryptedData();
    }

    ResultVal<size_t> WriteAlreadyEncryptedData() {
        const Result r = FlushCiphertextWriteBuf();
        if (r != ResultSuccess) {
            return r;
        }
        // write buf is empty
        const size_t cleartext_bytes_written = cleartext_write_buf.size();
        cleartext_write_buf.clear();
        return cleartext_bytes_written;
    }

    ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
        PCCERT_CONTEXT returned_cert = nullptr;
        const SECURITY_STATUS ret =
            QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
        if (ret != SEC_E_OK) {
            LOG_ERROR(Service_SSL,
                      "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
                      Common::NativeErrorToString(ret));
            return ResultInternalError;
        }
        PCCERT_CONTEXT some_cert = nullptr;
        std::vector<std::vector<u8>> certs;
        while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
            certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
                               static_cast<u8*>(some_cert->pbCertEncoded) +
                                   some_cert->cbCertEncoded);
        }
        std::reverse(certs.begin(),
                     certs.end()); // Windows returns certs in reverse order from what we want
        CertFreeCertificateContext(returned_cert);
        return certs;
    }

    ~SSLConnectionBackendSchannel() {
        if (handshake_state != HandshakeState::Initial) {
            DeleteSecurityContext(&ctxt);
        }
    }

    enum class HandshakeState {
        // Haven't called anything yet.
        Initial,
        // `SEC_I_CONTINUE_NEEDED` was returned by
        // `InitializeSecurityContext`; must finish sending data (if any) in
        // the write buffer, then read at least one byte before calling
        // `InitializeSecurityContext` again.
        ContinueNeeded,
        // `SEC_E_INCOMPLETE_MESSAGE` was returned by
        // `InitializeSecurityContext`; hopefully the write buffer is empty;
        // must read at least one byte before calling
        // `InitializeSecurityContext` again.
        IncompleteMessage,
        // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
        // finish sending data in the write buffer before having `DoHandshake`
        // report success.
        DoneAfterFlush,
        // We finished the above and are now connected.  At this point, writing
        // and reading are separate 'state machines' represented by the
        // nonemptiness of the ciphertext and cleartext read and write buffers.
        Connected,
        // Another error was returned and we shouldn't allow initialization
        // to continue.
        Error,
    } handshake_state = HandshakeState::Initial;

    CtxtHandle ctxt;
    SecPkgContext_StreamSizes stream_sizes;

    std::shared_ptr<Network::SocketBase> socket;
    std::optional<std::string> hostname;

    std::vector<u8> ciphertext_read_buf;
    std::vector<u8> ciphertext_write_buf;
    std::vector<u8> cleartext_read_buf;
    std::vector<u8> cleartext_write_buf;

    bool got_read_eof = false;
    size_t read_buf_fill_size = 0;
};

// Factory for the Schannel backend: the connection object is only handed out when the
// process-wide Schannel initialization succeeded; otherwise that failure is returned.
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
    auto backend = std::make_unique<SSLConnectionBackendSchannel>();
    if (const Result init_result = backend->Init(); init_result.IsFailure()) {
        return init_result;
    }
    return backend;
}

} // namespace Service::SSL