diff options
Diffstat (limited to '')
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_schannel.cpp | 544 |
1 files changed, 544 insertions, 0 deletions
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp new file mode 100644 index 000000000..d8074339a --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp @@ -0,0 +1,544 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include <mutex> + +#include "common/error.h" +#include "common/fs/file.h" +#include "common/hex_util.h" +#include "common/string_util.h" + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +namespace { + +// These includes are inside the namespace to avoid a conflict on MinGW where +// the headers define an enum containing Network and Service as enumerators +// (which clash with the correspondingly named namespaces). +#define SECURITY_WIN32 +#include <schnlsp.h> +#include <security.h> +#include <wincrypt.h> + +std::once_flag one_time_init_flag; +bool one_time_init_success = false; + +SCHANNEL_CRED schannel_cred{}; +CredHandle cred_handle; + +static void OneTimeInit() { + schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; + schannel_cred.dwFlags = + SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols + SCH_CRED_AUTO_CRED_VALIDATION | // validate certs + SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate + // ^ I'm assuming that nobody would want to connect Yuzu to a + // service that requires some OS-provided corporate client + // certificate, and presenting one to some arbitrary server + // might be a privacy concern? Who knows, though. + + const SECURITY_STATUS ret = + AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, + nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); + if (ret != SEC_E_OK) { + // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. + LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", + Common::NativeErrorToString(ret)); + return; + } + + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " + "keys; not logging keys!"); + // Not fatal. + } + + one_time_init_success = true; +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSchannel final : public SSLConnectionBackend { +public: + Result Init() { + std::call_once(one_time_init_flag, OneTimeInit); + + if (!one_time_init_success) { + LOG_ERROR( + Service_SSL, + "Can't create SSL connection because Schannel one-time initialization failed"); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { + socket = std::move(socket_in); + } + + Result SetHostName(const std::string& hostname_in) override { + hostname = hostname_in; + return ResultSuccess; + } + + Result DoHandshake() override { + while (1) { + Result r; + switch (handshake_state) { + case HandshakeState::Initial: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::ContinueNeeded: + case HandshakeState::IncompleteMessage: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || + (r = FillCiphertextReadBuf()) != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); + return ResultInternalError; + } + if ((r = CallInitializeSecurityContext()) != ResultSuccess) { + return r; + } + // CallInitializeSecurityContext updated `handshake_state`. + continue; + case HandshakeState::DoneAfterFlush: + if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { + return r; + } + handshake_state = HandshakeState::Connected; + return ResultSuccess; + case HandshakeState::Connected: + LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); + return ResultInternalError; + case HandshakeState::Error: + return ResultInternalError; + } + } + } + + Result FillCiphertextReadBuf() { + const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; + read_buf_fill_size = 0; + // This unnecessarily zeroes the buffer; oh well. + const size_t offset = ciphertext_read_buf.size(); + ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); + ciphertext_read_buf.resize(offset + fill_size, 0); + const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); + const auto [actual, err] = socket->Recv(0, read_span); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= fill_size); + ciphertext_read_buf.resize(offset + actual); + return ResultSuccess; + case Network::Errno::AGAIN: + ciphertext_read_buf.resize(offset); + return ResultWouldBlock; + default: + ciphertext_read_buf.resize(offset); + LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); + return ResultInternalError; + } + } + + // Returns success if the write buffer has been completely emptied. + Result FlushCiphertextWriteBuf() { + while (!ciphertext_write_buf.empty()) { + const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); + switch (err) { + case Network::Errno::SUCCESS: + ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); + ciphertext_write_buf.erase(ciphertext_write_buf.begin(), + ciphertext_write_buf.begin() + actual); + break; + case Network::Errno::AGAIN: + return ResultWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); + return ResultInternalError; + } + } + return ResultSuccess; + } + + Result CallInitializeSecurityContext() { + const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | + ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | + ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | + ISC_REQ_USE_SUPPLIED_CREDS; + unsigned long attr; + // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel + std::array<SecBuffer, 2> input_buffers{{ + // only used if `initial_call_done` + { + // [0] + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = ciphertext_read_buf.data(), + }, + { + // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is + // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the + // whole buffer wasn't used) + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }, + }}; + std::array<SecBuffer, 2> output_buffers{{ + { + .cbBuffer = 0, + .BufferType = SECBUFFER_TOKEN, + .pvBuffer = nullptr, + }, // [0] + { + .cbBuffer = 0, + .BufferType = SECBUFFER_ALERT, + .pvBuffer = nullptr, + }, // [1] + }}; + SecBufferDesc input_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(input_buffers.size()), + .pBuffers = input_buffers.data(), + }; + SecBufferDesc output_desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(output_buffers.size()), + .pBuffers = output_buffers.data(), + }; + ASSERT_OR_EXECUTE_MSG( + input_buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + + bool initial_call_done = handshake_state != HandshakeState::Initial; + if (initial_call_done) { + LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", + ciphertext_read_buf.size()); + } + + const SECURITY_STATUS ret = + InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, + // Caller ensured we have set a hostname: + const_cast<char*>(hostname.value().c_str()), req, + 0, // Reserved1 + 0, // TargetDataRep not used with Schannel + initial_call_done ? &input_desc : nullptr, + 0, // Reserved2 + initial_call_done ? nullptr : &ctxt, &output_desc, &attr, + nullptr); // ptsExpiry + + if (output_buffers[0].pvBuffer) { + const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), + output_buffers[0].cbBuffer); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); + FreeContextBuffer(output_buffers[0].pvBuffer); + } + + if (output_buffers[1].pvBuffer) { + const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), + output_buffers[1].cbBuffer); + // The documentation doesn't explain what format this data is in. + LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), + Common::HexToString(span)); + } + + switch (ret) { + case SEC_I_CONTINUE_NEEDED: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); + if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { + LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); + ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - input_buffers[1].cbBuffer); + } else { + ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + handshake_state = HandshakeState::ContinueNeeded; + return ResultSuccess; + case SEC_E_INCOMPLETE_MESSAGE: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); + ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); + read_buf_fill_size = input_buffers[1].cbBuffer; + handshake_state = HandshakeState::IncompleteMessage; + return ResultSuccess; + case SEC_E_OK: + LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); + ciphertext_read_buf.clear(); + handshake_state = HandshakeState::DoneAfterFlush; + return GrabStreamSizes(); + default: + LOG_ERROR(Service_SSL, + "InitializeSecurityContext failed (probably certificate/protocol issue): {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + } + + Result GrabStreamSizes() { + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", + Common::NativeErrorToString(ret)); + handshake_state = HandshakeState::Error; + return ResultInternalError; + } + return ResultSuccess; + } + + ResultVal<size_t> Read(std::span<u8> data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0 || got_read_eof) { + return size_t(0); + } + while (1) { + if (!cleartext_read_buf.empty()) { + const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); + std::memcpy(data.data(), cleartext_read_buf.data(), read_size); + cleartext_read_buf.erase(cleartext_read_buf.begin(), + cleartext_read_buf.begin() + read_size); + return read_size; + } + if (!ciphertext_read_buf.empty()) { + SecBuffer empty{ + .cbBuffer = 0, + .BufferType = SECBUFFER_EMPTY, + .pvBuffer = nullptr, + }; + std::array<SecBuffer, 5> buffers{{ + { + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = ciphertext_read_buf.data(), + }, + empty, + empty, + empty, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[0].cbBuffer == ciphertext_read_buf.size(), + { return ResultInternalError; }, "read buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + SECURITY_STATUS ret = + DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); + switch (ret) { + case SEC_E_OK: + ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, + { return ResultInternalError; }); + ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, + { return ResultInternalError; }); + cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), + static_cast<u8*>(buffers[1].pvBuffer) + + buffers[1].cbBuffer); + if (buffers[3].BufferType == SECBUFFER_EXTRA) { + ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - buffers[3].cbBuffer); + } else { + ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); + ciphertext_read_buf.clear(); + } + continue; + case SEC_E_INCOMPLETE_MESSAGE: + break; + case SEC_I_CONTEXT_EXPIRED: + // Server hung up by sending close_notify. + got_read_eof = true; + return size_t(0); + default: + LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + } + const Result r = FillCiphertextReadBuf(); + if (r != ResultSuccess) { + return r; + } + if (ciphertext_read_buf.empty()) { + got_read_eof = true; + return size_t(0); + } + } + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + if (handshake_state != HandshakeState::Connected) { + LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); + return ResultInternalError; + } + if (data.size() == 0) { + return size_t(0); + } + data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); + if (!cleartext_write_buf.empty()) { + // Already in the middle of a write. It wouldn't make sense to not + // finish sending the entire buffer since TLS has + // header/MAC/padding/etc. + if (data.size() != cleartext_write_buf.size() || + std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { + LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); + return ResultInternalError; + } + return WriteAlreadyEncryptedData(); + } else { + cleartext_write_buf.assign(data.begin(), data.end()); + } + + std::vector<u8> header_buf(stream_sizes.cbHeader, 0); + std::vector<u8> tmp_data_buf = cleartext_write_buf; + std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); + + std::array<SecBuffer, 3> buffers{{ + { + .cbBuffer = stream_sizes.cbHeader, + .BufferType = SECBUFFER_STREAM_HEADER, + .pvBuffer = header_buf.data(), + }, + { + .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), + .BufferType = SECBUFFER_DATA, + .pvBuffer = tmp_data_buf.data(), + }, + { + .cbBuffer = stream_sizes.cbTrailer, + .BufferType = SECBUFFER_STREAM_TRAILER, + .pvBuffer = trailer_buf.data(), + }, + }}; + ASSERT_OR_EXECUTE_MSG( + buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, + "temp buffer too large"); + SecBufferDesc desc{ + .ulVersion = SECBUFFER_VERSION, + .cBuffers = static_cast<unsigned long>(buffers.size()), + .pBuffers = buffers.data(), + }; + + const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); + return ResultInternalError; + } + ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), + header_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), + tmp_data_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), + trailer_buf.end()); + return WriteAlreadyEncryptedData(); + } + + ResultVal<size_t> WriteAlreadyEncryptedData() { + const Result r = FlushCiphertextWriteBuf(); + if (r != ResultSuccess) { + return r; + } + // write buf is empty + const size_t cleartext_bytes_written = cleartext_write_buf.size(); + cleartext_write_buf.clear(); + return cleartext_bytes_written; + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + PCCERT_CONTEXT returned_cert = nullptr; + const SECURITY_STATUS ret = + QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); + if (ret != SEC_E_OK) { + LOG_ERROR(Service_SSL, + "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", + Common::NativeErrorToString(ret)); + return ResultInternalError; + } + PCCERT_CONTEXT some_cert = nullptr; + std::vector<std::vector<u8>> certs; + while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { + certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), + static_cast<u8*>(some_cert->pbCertEncoded) + + some_cert->cbCertEncoded); + } + std::reverse(certs.begin(), + certs.end()); // Windows returns certs in reverse order from what we want + CertFreeCertificateContext(returned_cert); + return certs; + } + + ~SSLConnectionBackendSchannel() { + if (handshake_state != HandshakeState::Initial) { + DeleteSecurityContext(&ctxt); + } + } + + enum class HandshakeState { + // Haven't called anything yet. + Initial, + // `SEC_I_CONTINUE_NEEDED` was returned by + // `InitializeSecurityContext`; must finish sending data (if any) in + // the write buffer, then read at least one byte before calling + // `InitializeSecurityContext` again. + ContinueNeeded, + // `SEC_E_INCOMPLETE_MESSAGE` was returned by + // `InitializeSecurityContext`; hopefully the write buffer is empty; + // must read at least one byte before calling + // `InitializeSecurityContext` again. + IncompleteMessage, + // `SEC_E_OK` was returned by `InitializeSecurityContext`; must + // finish sending data in the write buffer before having `DoHandshake` + // report success. + DoneAfterFlush, + // We finished the above and are now connected. At this point, writing + // and reading are separate 'state machines' represented by the + // nonemptiness of the ciphertext and cleartext read and write buffers. + Connected, + // Another error was returned and we shouldn't allow initialization + // to continue. + Error, + } handshake_state = HandshakeState::Initial; + + CtxtHandle ctxt; + SecPkgContext_StreamSizes stream_sizes; + + std::shared_ptr<Network::SocketBase> socket; + std::optional<std::string> hostname; + + std::vector<u8> ciphertext_read_buf; + std::vector<u8> ciphertext_write_buf; + std::vector<u8> cleartext_read_buf; + std::vector<u8> cleartext_write_buf; + + bool got_read_eof = false; + size_t read_buf_fill_size = 0; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendSchannel>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL |