diff options
Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_securetransport.cpp | 219 |
1 files changed, 219 insertions, 0 deletions
diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp new file mode 100644 index 000000000..be40a5aeb --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#include "core/hle/service/ssl/ssl_backend.h" +#include "core/internal_network/network.h" +#include "core/internal_network/sockets.h" + +#include <mutex> + +#include <Security/SecureTransport.h> + +// SecureTransport has been deprecated in its entirety in favor of +// Network.framework, but that does not allow layering TLS on top of an +// arbitrary socket. +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + +namespace { + +template <typename T> +struct CFReleaser { + T ptr; + + YUZU_NON_COPYABLE(CFReleaser); + constexpr CFReleaser() : ptr(nullptr) {} + constexpr CFReleaser(T ptr) : ptr(ptr) {} + constexpr operator T() { + return ptr; + } + ~CFReleaser() { + if (ptr) { + CFRelease(ptr); + } + } +}; + +std::string CFStringToString(CFStringRef cfstr) { + CFReleaser<CFDataRef> cfdata( + CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); + ASSERT_OR_EXECUTE(cfdata, { return "???"; }); + return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), + CFDataGetLength(cfdata)); +} + +std::string OSStatusToString(OSStatus status) { + CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); + if (!cfstr) { + return "[unknown error]"; + } + return CFStringToString(cfstr); +} + +} // namespace + +namespace Service::SSL { + +class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { +public: + Result Init() { + static std::once_flag once_flag; + std::call_once(once_flag, []() { + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " + "support exporting keys; not logging keys!"); + // Not fatal. + } + }); + + context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); + if (!context) { + LOG_ERROR(Service_SSL, "SSLCreateContext failed"); + return ResultInternalError; + } + + OSStatus status; + if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || + (status = SSLSetConnection(context, this))) { + LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", + OSStatusToString(status)); + return ResultInternalError; + } + + return ResultSuccess; + } + + void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { + socket = std::move(in_socket); + } + + Result SetHostName(const std::string& hostname) override { + OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); + if (status) { + LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + return ResultSuccess; + } + + Result DoHandshake() override { + OSStatus status = SSLHandshake(context); + return HandleReturn("SSLHandshake", 0, status).Code(); + } + + ResultVal<size_t> Read(std::span<u8> data) override { + size_t actual; + OSStatus status = SSLRead(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLRead", actual, status); + } + + ResultVal<size_t> Write(std::span<const u8> data) override { + size_t actual; + OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); + ; + return HandleReturn("SSLWrite", actual, status); + } + + ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { + switch (status) { + case 0: + return actual; + case errSSLWouldBlock: + return ResultWouldBlock; + default: { + std::string reason; + if (got_read_eof) { + reason = "server hung up"; + } else { + reason = OSStatusToString(status); + } + LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); + return ResultInternalError; + } + } + } + + ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { + CFReleaser<SecTrustRef> trust; + OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); + if (status) { + LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); + return ResultInternalError; + } + std::vector<std::vector<u8>> ret; + for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { + SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); + CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); + ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); + const u8* ptr = CFDataGetBytePtr(data); + ret.emplace_back(ptr, ptr + CFDataGetLength(data)); + } + return ret; + } + + static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { + return ReadOrWriteCallback(connection, data, dataLength, true); + } + + static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, + size_t* dataLength) { + return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); + } + + static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, + bool is_read) { + auto self = + static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); + ASSERT_OR_EXECUTE_MSG( + self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", + is_read ? "read" : "write"); + + // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are + // expected to read/write the full requested dataLength or return an + // error, so we have to add a loop ourselves. + size_t requested_len = *dataLength; + size_t offset = 0; + while (offset < requested_len) { + std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); + auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); + LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, + actual, cur.size(), static_cast<s32>(err)); + switch (err) { + case Network::Errno::SUCCESS: + offset += actual; + if (actual == 0) { + ASSERT(is_read); + self->got_read_eof = true; + return errSecEndOfData; + } + break; + case Network::Errno::AGAIN: + *dataLength = offset; + return errSSLWouldBlock; + default: + LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", + is_read ? "recv" : "send", err); + return errSecIO; + } + } + ASSERT(offset == requested_len); + return 0; + } + +private: + CFReleaser<SSLContextRef> context = nullptr; + bool got_read_eof = false; + + std::shared_ptr<Network::SocketBase> socket; +}; + +ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { + auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); + const Result res = conn->Init(); + if (res.IsFailure()) { + return res; + } + return conn; +} + +} // namespace Service::SSL |