summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/android/app/src/main/jni/native.cpp2
-rw-r--r--src/android/app/src/main/res/values/strings.xml2
-rw-r--r--src/common/demangle.cpp2
-rw-r--r--src/common/detached_tasks.cpp4
-rw-r--r--src/common/settings.cpp3
-rw-r--r--src/common/socket_types.h17
-rw-r--r--src/common/time_zone.cpp47
-rw-r--r--src/core/CMakeLists.txt18
-rw-r--r--src/core/arm/dynarmic/arm_dynarmic_32.cpp6
-rw-r--r--src/core/arm/dynarmic/arm_dynarmic_32.h6
-rw-r--r--src/core/arm/dynarmic/arm_dynarmic_64.cpp6
-rw-r--r--src/core/arm/dynarmic/arm_dynarmic_64.h6
-rw-r--r--src/core/arm/dynarmic/dynarmic_exclusive_monitor.h5
-rw-r--r--src/core/core.cpp8
-rw-r--r--src/core/core.h3
-rw-r--r--src/core/debugger/gdbstub.cpp13
-rw-r--r--src/core/hle/kernel/k_code_memory.cpp17
-rw-r--r--src/core/hle/kernel/k_page_table.h41
-rw-r--r--src/core/hle/kernel/k_process.cpp2
-rw-r--r--src/core/hle/kernel/k_process.h10
-rw-r--r--src/core/hle/kernel/k_shared_memory.cpp6
-rw-r--r--src/core/hle/kernel/k_thread.cpp4
-rw-r--r--src/core/hle/kernel/k_thread_local_page.cpp10
-rw-r--r--src/core/hle/kernel/kernel.cpp6
-rw-r--r--src/core/hle/kernel/physical_core.cpp8
-rw-r--r--src/core/hle/kernel/svc/svc_cache.cpp2
-rw-r--r--src/core/hle/kernel/svc/svc_code_memory.cpp14
-rw-r--r--src/core/hle/kernel/svc/svc_device_address_space.cpp6
-rw-r--r--src/core/hle/kernel/svc/svc_info.cpp16
-rw-r--r--src/core/hle/kernel/svc/svc_ipc.cpp11
-rw-r--r--src/core/hle/kernel/svc/svc_memory.cpp33
-rw-r--r--src/core/hle/kernel/svc/svc_physical_memory.cpp14
-rw-r--r--src/core/hle/kernel/svc/svc_process.cpp4
-rw-r--r--src/core/hle/kernel/svc/svc_process_memory.cpp34
-rw-r--r--src/core/hle/kernel/svc/svc_query_memory.cpp2
-rw-r--r--src/core/hle/kernel/svc/svc_shared_memory.cpp4
-rw-r--r--src/core/hle/kernel/svc/svc_synchronization.cpp11
-rw-r--r--src/core/hle/kernel/svc/svc_thread.cpp2
-rw-r--r--src/core/hle/kernel/svc/svc_transfer_memory.cpp2
-rw-r--r--src/core/hle/service/acc/acc.cpp3
-rw-r--r--src/core/hle/service/am/am.cpp4
-rw-r--r--src/core/hle/service/glue/ectx.cpp43
-rw-r--r--src/core/hle/service/glue/ectx.h3
-rw-r--r--src/core/hle/service/ldr/ldr.cpp20
-rw-r--r--src/core/hle/service/nfc/common/amiibo_crypto.cpp2
-rw-r--r--src/core/hle/service/nfc/common/amiibo_crypto.h2
-rw-r--r--src/core/hle/service/nfc/common/device.cpp3
-rw-r--r--src/core/hle/service/nfc/common/device_manager.cpp135
-rw-r--r--src/core/hle/service/nfc/common/device_manager.h34
-rw-r--r--src/core/hle/service/nfc/nfc_interface.cpp18
-rw-r--r--src/core/hle/service/nfc/nfc_result.h3
-rw-r--r--src/core/hle/service/nifm/nifm.cpp1
-rw-r--r--src/core/hle/service/nifm/nifm.h7
-rw-r--r--src/core/hle/service/nvdrv/devices/nvmap.cpp4
-rw-r--r--src/core/hle/service/sockets/bsd.cpp122
-rw-r--r--src/core/hle/service/sockets/bsd.h13
-rw-r--r--src/core/hle/service/sockets/nsd.cpp85
-rw-r--r--src/core/hle/service/sockets/nsd.h6
-rw-r--r--src/core/hle/service/sockets/sfdnsres.cpp404
-rw-r--r--src/core/hle/service/sockets/sfdnsres.h4
-rw-r--r--src/core/hle/service/sockets/sockets.h33
-rw-r--r--src/core/hle/service/sockets/sockets_translate.cpp152
-rw-r--r--src/core/hle/service/sockets/sockets_translate.h20
-rw-r--r--src/core/hle/service/ssl/ssl.cpp353
-rw-r--r--src/core/hle/service/ssl/ssl_backend.h45
-rw-r--r--src/core/hle/service/ssl/ssl_backend_none.cpp16
-rw-r--r--src/core/hle/service/ssl/ssl_backend_openssl.cpp351
-rw-r--r--src/core/hle/service/ssl/ssl_backend_schannel.cpp544
-rw-r--r--src/core/hle/service/ssl/ssl_backend_securetransport.cpp222
-rw-r--r--src/core/internal_network/network.cpp286
-rw-r--r--src/core/internal_network/network.h36
-rw-r--r--src/core/internal_network/socket_proxy.cpp23
-rw-r--r--src/core/internal_network/socket_proxy.h12
-rw-r--r--src/core/internal_network/sockets.h19
-rw-r--r--src/core/loader/deconstructed_rom_directory.cpp2
-rw-r--r--src/core/loader/kip.cpp2
-rw-r--r--src/core/loader/nro.cpp2
-rw-r--r--src/core/loader/nso.cpp2
-rw-r--r--src/core/memory.cpp121
-rw-r--r--src/core/memory.h13
-rw-r--r--src/core/memory/cheat_engine.cpp2
-rw-r--r--src/core/reporter.cpp4
-rw-r--r--src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp11
-rw-r--r--src/video_core/CMakeLists.txt6
-rw-r--r--src/video_core/buffer_cache/buffer_cache.h10
-rw-r--r--src/video_core/buffer_cache/buffer_cache_base.h2
-rw-r--r--src/video_core/renderer_base.cpp4
-rw-r--r--src/video_core/renderer_opengl/gl_graphics_pipeline.cpp15
-rw-r--r--src/video_core/renderer_opengl/gl_shader_cache.cpp8
-rw-r--r--src/video_core/renderer_vulkan/vk_buffer_cache.cpp33
-rw-r--r--src/video_core/renderer_vulkan/vk_pipeline_cache.cpp12
-rw-r--r--src/video_core/renderer_vulkan/vk_query_cache.cpp11
-rw-r--r--src/video_core/renderer_vulkan/vk_texture_cache.cpp6
-rw-r--r--src/video_core/vulkan_common/vma.cpp8
-rw-r--r--src/web_service/announce_room_json.cpp10
-rw-r--r--src/yuzu/game_list_worker.cpp8
96 files changed, 3010 insertions, 717 deletions
diff --git a/src/android/app/src/main/jni/native.cpp b/src/android/app/src/main/jni/native.cpp
index 8bc6a4a04..c23b2f19e 100644
--- a/src/android/app/src/main/jni/native.cpp
+++ b/src/android/app/src/main/jni/native.cpp
@@ -449,7 +449,7 @@ private:
loader->ReadTitle(entry.title);
loader->ReadIcon(entry.icon);
if (loader->GetFileType() == Loader::FileType::NRO) {
- jauto loader_nro = dynamic_cast<Loader::AppLoader_NRO*>(loader.get());
+ jauto loader_nro = reinterpret_cast<Loader::AppLoader_NRO*>(loader.get());
entry.isHomebrew = loader_nro->IsHomebrew();
} else {
entry.isHomebrew = false;
diff --git a/src/android/app/src/main/res/values/strings.xml b/src/android/app/src/main/res/values/strings.xml
index b963f0119..bfdebd35b 100644
--- a/src/android/app/src/main/res/values/strings.xml
+++ b/src/android/app/src/main/res/values/strings.xml
@@ -232,7 +232,7 @@
<!-- ROM loading errors -->
<string name="loader_error_encrypted">Your ROM is encrypted</string>
- <string name="loader_error_encrypted_roms_description"><![CDATA[Please follow the guides to redump your <a href="https://yuzu-emu.org/help/quickstart/#dumping-cartridge-games">game cartidges</a> or <a href="https://yuzu-emu.org/help/quickstart/#dumping-installed-titles-eshop">installed titles</a>.]]></string>
+ <string name="loader_error_encrypted_roms_description"><![CDATA[Please follow the guides to redump your <a href="https://yuzu-emu.org/help/quickstart/#dumping-physical-titles-game-cards">game cartidges</a> or <a href="https://yuzu-emu.org/help/quickstart/#dumping-digital-titles-eshop">installed titles</a>.]]></string>
<string name="loader_error_encrypted_keys_description"><![CDATA[Please ensure your <a href="https://yuzu-emu.org/help/quickstart/#dumping-prodkeys-and-titlekeys">prod.keys</a> file is installed so that games can be decrypted.]]></string>
<string name="loader_error_video_core">An error occurred initializing the video core</string>
<string name="loader_error_video_core_description">This is usually caused by an incompatible GPU driver. Installing a custom GPU driver may resolve this problem.</string>
diff --git a/src/common/demangle.cpp b/src/common/demangle.cpp
index 3310faf86..6e117cb41 100644
--- a/src/common/demangle.cpp
+++ b/src/common/demangle.cpp
@@ -23,7 +23,7 @@ std::string DemangleSymbol(const std::string& mangled) {
SCOPE_EXIT({ std::free(demangled); });
if (is_itanium(mangled)) {
- demangled = llvm::itaniumDemangle(mangled.c_str(), nullptr, nullptr, nullptr);
+ demangled = llvm::itaniumDemangle(mangled.c_str());
}
if (!demangled) {
diff --git a/src/common/detached_tasks.cpp b/src/common/detached_tasks.cpp
index da64848da..f2ed795cc 100644
--- a/src/common/detached_tasks.cpp
+++ b/src/common/detached_tasks.cpp
@@ -30,8 +30,8 @@ DetachedTasks::~DetachedTasks() {
void DetachedTasks::AddTask(std::function<void()> task) {
std::unique_lock lock{instance->mutex};
++instance->count;
- std::thread([task{std::move(task)}]() {
- task();
+ std::thread([task_{std::move(task)}]() {
+ task_();
std::unique_lock thread_lock{instance->mutex};
--instance->count;
std::notify_all_at_thread_exit(instance->cv, std::move(thread_lock));
diff --git a/src/common/settings.cpp b/src/common/settings.cpp
index 5972480e5..d4e55f988 100644
--- a/src/common/settings.cpp
+++ b/src/common/settings.cpp
@@ -26,7 +26,8 @@ std::string GetTimeZoneString() {
std::string location_name;
if (time_zone_index == 0) { // Auto
-#if __cpp_lib_chrono >= 201907L
+#if __cpp_lib_chrono >= 201907L && !defined(MINGW)
+ // Disabled for MinGW -- tzdb always returns Etc/UTC
try {
const struct std::chrono::tzdb& time_zone_data = std::chrono::get_tzdb();
const std::chrono::time_zone* current_zone = time_zone_data.current_zone();
diff --git a/src/common/socket_types.h b/src/common/socket_types.h
index 0a801a443..63824a5c4 100644
--- a/src/common/socket_types.h
+++ b/src/common/socket_types.h
@@ -3,17 +3,22 @@
#pragma once
+#include <optional>
+#include <string>
+
#include "common/common_types.h"
namespace Network {
/// Address families
enum class Domain : u8 {
- INET, ///< Address family for IPv4
+ Unspecified, ///< Represents 0, used in getaddrinfo hints
+ INET, ///< Address family for IPv4
};
/// Socket types
enum class Type {
+ Unspecified, ///< Represents 0, used in getaddrinfo hints
STREAM,
DGRAM,
RAW,
@@ -22,6 +27,7 @@ enum class Type {
/// Protocol values for sockets
enum class Protocol : u8 {
+ Unspecified, ///< Represents 0, usable in various places
ICMP,
TCP,
UDP,
@@ -48,4 +54,13 @@ constexpr u32 FLAG_MSG_PEEK = 0x2;
constexpr u32 FLAG_MSG_DONTWAIT = 0x80;
constexpr u32 FLAG_O_NONBLOCK = 0x800;
+/// Cross-platform addrinfo structure
+struct AddrInfo {
+ Domain family;
+ Type socket_type;
+ Protocol protocol;
+ SockAddrIn addr;
+ std::optional<std::string> canon_name;
+};
+
} // namespace Network
diff --git a/src/common/time_zone.cpp b/src/common/time_zone.cpp
index d8d7896c6..69e728a9d 100644
--- a/src/common/time_zone.cpp
+++ b/src/common/time_zone.cpp
@@ -4,13 +4,13 @@
#include <chrono>
#include <exception>
#include <iomanip>
+#include <map>
#include <sstream>
#include <stdexcept>
#include <fmt/chrono.h>
#include <fmt/core.h>
#include "common/logging/log.h"
-#include "common/settings.h"
#include "common/time_zone.h"
namespace Common::TimeZone {
@@ -33,32 +33,29 @@ std::string GetDefaultTimeZone() {
return "GMT";
}
-static std::string GetOsTimeZoneOffset() {
- const std::time_t t{std::time(nullptr)};
- const std::tm tm{*std::localtime(&t)};
-
- return fmt::format("{:%z}", tm);
-}
-
-static int ConvertOsTimeZoneOffsetToInt(const std::string& timezone) {
- try {
- return std::stoi(timezone);
- } catch (const std::invalid_argument&) {
- LOG_CRITICAL(Common, "invalid_argument with {}!", timezone);
- return 0;
- } catch (const std::out_of_range&) {
- LOG_CRITICAL(Common, "out_of_range with {}!", timezone);
- return 0;
- }
+// Results are not comparable to seconds since Epoch
+static std::time_t TmSpecToSeconds(const struct std::tm& spec) {
+ const int year = spec.tm_year - 1; // Years up to now
+ const int leap_years = year / 4 - year / 100;
+ std::time_t cumulative = spec.tm_year;
+ cumulative = cumulative * 365 + leap_years + spec.tm_yday; // Years to days
+ cumulative = cumulative * 24 + spec.tm_hour; // Days to hours
+ cumulative = cumulative * 60 + spec.tm_min; // Hours to minutes
+ cumulative = cumulative * 60 + spec.tm_sec; // Minutes to seconds
+ return cumulative;
}
std::chrono::seconds GetCurrentOffsetSeconds() {
- const int offset{ConvertOsTimeZoneOffsetToInt(GetOsTimeZoneOffset())};
+ const std::time_t t{std::time(nullptr)};
+ const std::tm local{*std::localtime(&t)};
+ const std::tm gmt{*std::gmtime(&t)};
- int seconds{(offset / 100) * 60 * 60}; // Convert hour component to seconds
- seconds += (offset % 100) * 60; // Convert minute component to seconds
+ // gmt_seconds is a different offset than time(nullptr)
+ const auto gmt_seconds = TmSpecToSeconds(gmt);
+ const auto local_seconds = TmSpecToSeconds(local);
+ const auto seconds_offset = local_seconds - gmt_seconds;
- return std::chrono::seconds{seconds};
+ return std::chrono::seconds{seconds_offset};
}
// Key is [Hours * 100 + Minutes], multiplied by 100 if DST
@@ -71,11 +68,6 @@ const static std::map<s64, const char*> off_timezones = {
};
std::string FindSystemTimeZone() {
-#if defined(MINGW)
- // MinGW has broken strftime -- https://sourceforge.net/p/mingw-w64/bugs/793/
- // e.g. fmt::format("{:%z}") -- returns "Eastern Daylight Time" when it should be "-0400"
- return timezones[0];
-#else
const s64 seconds = static_cast<s64>(GetCurrentOffsetSeconds().count());
const s64 minutes = seconds / 60;
@@ -97,7 +89,6 @@ std::string FindSystemTimeZone() {
}
}
return fmt::format("Etc/GMT{:s}{:d}", hours > 0 ? "-" : "+", std::abs(hours));
-#endif
}
} // namespace Common::TimeZone
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index 28cb6f86f..4b7395be8 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -723,6 +723,7 @@ add_library(core STATIC
hle/service/spl/spl_types.h
hle/service/ssl/ssl.cpp
hle/service/ssl/ssl.h
+ hle/service/ssl/ssl_backend.h
hle/service/time/clock_types.h
hle/service/time/ephemeral_network_system_clock_context_writer.h
hle/service/time/ephemeral_network_system_clock_core.h
@@ -864,6 +865,23 @@ if (ARCHITECTURE_x86_64 OR ARCHITECTURE_arm64)
target_link_libraries(core PRIVATE dynarmic::dynarmic)
endif()
+if(ENABLE_OPENSSL)
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_openssl.cpp)
+ target_link_libraries(core PRIVATE OpenSSL::SSL)
+elseif (APPLE)
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_securetransport.cpp)
+ target_link_libraries(core PRIVATE "-framework Security")
+elseif (WIN32)
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_schannel.cpp)
+ target_link_libraries(core PRIVATE crypt32 secur32)
+else()
+ target_sources(core PRIVATE
+ hle/service/ssl/ssl_backend_none.cpp)
+endif()
+
if (YUZU_USE_PRECOMPILED_HEADERS)
target_precompile_headers(core PRIVATE precompiled_headers.h)
endif()
diff --git a/src/core/arm/dynarmic/arm_dynarmic_32.cpp b/src/core/arm/dynarmic/arm_dynarmic_32.cpp
index 3b82fb73c..c97158a71 100644
--- a/src/core/arm/dynarmic/arm_dynarmic_32.cpp
+++ b/src/core/arm/dynarmic/arm_dynarmic_32.cpp
@@ -346,11 +346,11 @@ void ARM_Dynarmic_32::RewindBreakpointInstruction() {
}
ARM_Dynarmic_32::ARM_Dynarmic_32(System& system_, bool uses_wall_clock_,
- ExclusiveMonitor& exclusive_monitor_, std::size_t core_index_)
+ DynarmicExclusiveMonitor& exclusive_monitor_,
+ std::size_t core_index_)
: ARM_Interface{system_, uses_wall_clock_}, cb(std::make_unique<DynarmicCallbacks32>(*this)),
cp15(std::make_shared<DynarmicCP15>(*this)), core_index{core_index_},
- exclusive_monitor{dynamic_cast<DynarmicExclusiveMonitor&>(exclusive_monitor_)},
- null_jit{MakeJit(nullptr)}, jit{null_jit.get()} {}
+ exclusive_monitor{exclusive_monitor_}, null_jit{MakeJit(nullptr)}, jit{null_jit.get()} {}
ARM_Dynarmic_32::~ARM_Dynarmic_32() = default;
diff --git a/src/core/arm/dynarmic/arm_dynarmic_32.h b/src/core/arm/dynarmic/arm_dynarmic_32.h
index a990845cb..92fb3f836 100644
--- a/src/core/arm/dynarmic/arm_dynarmic_32.h
+++ b/src/core/arm/dynarmic/arm_dynarmic_32.h
@@ -12,7 +12,7 @@
#include "common/common_types.h"
#include "common/hash.h"
#include "core/arm/arm_interface.h"
-#include "core/arm/exclusive_monitor.h"
+#include "core/arm/dynarmic/dynarmic_exclusive_monitor.h"
namespace Core::Memory {
class Memory;
@@ -28,8 +28,8 @@ class System;
class ARM_Dynarmic_32 final : public ARM_Interface {
public:
- ARM_Dynarmic_32(System& system_, bool uses_wall_clock_, ExclusiveMonitor& exclusive_monitor_,
- std::size_t core_index_);
+ ARM_Dynarmic_32(System& system_, bool uses_wall_clock_,
+ DynarmicExclusiveMonitor& exclusive_monitor_, std::size_t core_index_);
~ARM_Dynarmic_32() override;
void SetPC(u64 pc) override;
diff --git a/src/core/arm/dynarmic/arm_dynarmic_64.cpp b/src/core/arm/dynarmic/arm_dynarmic_64.cpp
index bb97ed5bc..791d466ca 100644
--- a/src/core/arm/dynarmic/arm_dynarmic_64.cpp
+++ b/src/core/arm/dynarmic/arm_dynarmic_64.cpp
@@ -405,11 +405,11 @@ void ARM_Dynarmic_64::RewindBreakpointInstruction() {
}
ARM_Dynarmic_64::ARM_Dynarmic_64(System& system_, bool uses_wall_clock_,
- ExclusiveMonitor& exclusive_monitor_, std::size_t core_index_)
+ DynarmicExclusiveMonitor& exclusive_monitor_,
+ std::size_t core_index_)
: ARM_Interface{system_, uses_wall_clock_},
cb(std::make_unique<DynarmicCallbacks64>(*this)), core_index{core_index_},
- exclusive_monitor{dynamic_cast<DynarmicExclusiveMonitor&>(exclusive_monitor_)},
- null_jit{MakeJit(nullptr, 48)}, jit{null_jit.get()} {}
+ exclusive_monitor{exclusive_monitor_}, null_jit{MakeJit(nullptr, 48)}, jit{null_jit.get()} {}
ARM_Dynarmic_64::~ARM_Dynarmic_64() = default;
diff --git a/src/core/arm/dynarmic/arm_dynarmic_64.h b/src/core/arm/dynarmic/arm_dynarmic_64.h
index af2aa1f1c..2b88a08e2 100644
--- a/src/core/arm/dynarmic/arm_dynarmic_64.h
+++ b/src/core/arm/dynarmic/arm_dynarmic_64.h
@@ -11,7 +11,7 @@
#include "common/common_types.h"
#include "common/hash.h"
#include "core/arm/arm_interface.h"
-#include "core/arm/exclusive_monitor.h"
+#include "core/arm/dynarmic/dynarmic_exclusive_monitor.h"
namespace Core::Memory {
class Memory;
@@ -25,8 +25,8 @@ class System;
class ARM_Dynarmic_64 final : public ARM_Interface {
public:
- ARM_Dynarmic_64(System& system_, bool uses_wall_clock_, ExclusiveMonitor& exclusive_monitor_,
- std::size_t core_index_);
+ ARM_Dynarmic_64(System& system_, bool uses_wall_clock_,
+ DynarmicExclusiveMonitor& exclusive_monitor_, std::size_t core_index_);
~ARM_Dynarmic_64() override;
void SetPC(u64 pc) override;
diff --git a/src/core/arm/dynarmic/dynarmic_exclusive_monitor.h b/src/core/arm/dynarmic/dynarmic_exclusive_monitor.h
index 57e6dd0d0..fbfcd8d95 100644
--- a/src/core/arm/dynarmic/dynarmic_exclusive_monitor.h
+++ b/src/core/arm/dynarmic/dynarmic_exclusive_monitor.h
@@ -6,8 +6,6 @@
#include <dynarmic/interface/exclusive_monitor.h>
#include "common/common_types.h"
-#include "core/arm/dynarmic/arm_dynarmic_32.h"
-#include "core/arm/dynarmic/arm_dynarmic_64.h"
#include "core/arm/exclusive_monitor.h"
namespace Core::Memory {
@@ -16,6 +14,9 @@ class Memory;
namespace Core {
+class ARM_Dynarmic_32;
+class ARM_Dynarmic_64;
+
class DynarmicExclusiveMonitor final : public ExclusiveMonitor {
public:
explicit DynarmicExclusiveMonitor(Memory::Memory& memory_, std::size_t core_count_);
diff --git a/src/core/core.cpp b/src/core/core.cpp
index 9e3eb3795..48233d7c8 100644
--- a/src/core/core.cpp
+++ b/src/core/core.cpp
@@ -880,6 +880,14 @@ const FileSys::ContentProvider& System::GetContentProvider() const {
return *impl->content_provider;
}
+FileSys::ContentProviderUnion& System::GetContentProviderUnion() {
+ return *impl->content_provider;
+}
+
+const FileSys::ContentProviderUnion& System::GetContentProviderUnion() const {
+ return *impl->content_provider;
+}
+
Service::FileSystem::FileSystemController& System::GetFileSystemController() {
return impl->fs_controller;
}
diff --git a/src/core/core.h b/src/core/core.h
index 14b2f7785..c70ea1965 100644
--- a/src/core/core.h
+++ b/src/core/core.h
@@ -381,6 +381,9 @@ public:
[[nodiscard]] FileSys::ContentProvider& GetContentProvider();
[[nodiscard]] const FileSys::ContentProvider& GetContentProvider() const;
+ [[nodiscard]] FileSys::ContentProviderUnion& GetContentProviderUnion();
+ [[nodiscard]] const FileSys::ContentProviderUnion& GetContentProviderUnion() const;
+
[[nodiscard]] Service::FileSystem::FileSystemController& GetFileSystemController();
[[nodiscard]] const Service::FileSystem::FileSystemController& GetFileSystemController() const;
diff --git a/src/core/debugger/gdbstub.cpp b/src/core/debugger/gdbstub.cpp
index e2a13bbd2..0f839d5b4 100644
--- a/src/core/debugger/gdbstub.cpp
+++ b/src/core/debugger/gdbstub.cpp
@@ -261,10 +261,8 @@ void GDBStub::ExecuteCommand(std::string_view packet, std::vector<DebuggerAction
const size_t addr{static_cast<size_t>(strtoll(command.data(), nullptr, 16))};
const size_t size{static_cast<size_t>(strtoll(command.data() + sep, nullptr, 16))};
- if (system.ApplicationMemory().IsValidVirtualAddressRange(addr, size)) {
- std::vector<u8> mem(size);
- system.ApplicationMemory().ReadBlock(addr, mem.data(), size);
-
+ std::vector<u8> mem(size);
+ if (system.ApplicationMemory().ReadBlock(addr, mem.data(), size)) {
SendReply(Common::HexToString(mem));
} else {
SendReply(GDB_STUB_REPLY_ERR);
@@ -281,8 +279,7 @@ void GDBStub::ExecuteCommand(std::string_view packet, std::vector<DebuggerAction
const auto mem_substr{std::string_view(command).substr(mem_sep)};
const auto mem{Common::HexStringToVector(mem_substr, false)};
- if (system.ApplicationMemory().IsValidVirtualAddressRange(addr, size)) {
- system.ApplicationMemory().WriteBlock(addr, mem.data(), size);
+ if (system.ApplicationMemory().WriteBlock(addr, mem.data(), size)) {
system.InvalidateCpuInstructionCacheRange(addr, size);
SendReply(GDB_STUB_REPLY_OK);
} else {
@@ -556,7 +553,7 @@ void GDBStub::HandleQuery(std::string_view command) {
} else {
SendReply(fmt::format(
"TextSeg={:x}",
- GetInteger(system.ApplicationProcess()->PageTable().GetCodeRegionStart())));
+ GetInteger(system.ApplicationProcess()->GetPageTable().GetCodeRegionStart())));
}
} else if (command.starts_with("Xfer:libraries:read::")) {
Loader::AppLoader::Modules modules;
@@ -731,7 +728,7 @@ void GDBStub::HandleRcmd(const std::vector<u8>& command) {
std::string reply;
auto* process = system.ApplicationProcess();
- auto& page_table = process->PageTable();
+ auto& page_table = process->GetPageTable();
const char* commands = "Commands:\n"
" get fastmem\n"
diff --git a/src/core/hle/kernel/k_code_memory.cpp b/src/core/hle/kernel/k_code_memory.cpp
index 3583bee44..7454be55c 100644
--- a/src/core/hle/kernel/k_code_memory.cpp
+++ b/src/core/hle/kernel/k_code_memory.cpp
@@ -25,7 +25,7 @@ Result KCodeMemory::Initialize(Core::DeviceMemory& device_memory, KProcessAddres
m_owner = GetCurrentProcessPointer(m_kernel);
// Get the owner page table.
- auto& page_table = m_owner->PageTable();
+ auto& page_table = m_owner->GetPageTable();
// Construct the page group.
m_page_group.emplace(m_kernel, page_table.GetBlockInfoManager());
@@ -53,7 +53,7 @@ void KCodeMemory::Finalize() {
// Unlock.
if (!m_is_mapped && !m_is_owner_mapped) {
const size_t size = m_page_group->GetNumPages() * PageSize;
- m_owner->PageTable().UnlockForCodeMemory(m_address, size, *m_page_group);
+ m_owner->GetPageTable().UnlockForCodeMemory(m_address, size, *m_page_group);
}
// Close the page group.
@@ -75,7 +75,7 @@ Result KCodeMemory::Map(KProcessAddress address, size_t size) {
R_UNLESS(!m_is_mapped, ResultInvalidState);
// Map the memory.
- R_TRY(GetCurrentProcess(m_kernel).PageTable().MapPageGroup(
+ R_TRY(GetCurrentProcess(m_kernel).GetPageTable().MapPageGroup(
address, *m_page_group, KMemoryState::CodeOut, KMemoryPermission::UserReadWrite));
// Mark ourselves as mapped.
@@ -92,8 +92,8 @@ Result KCodeMemory::Unmap(KProcessAddress address, size_t size) {
KScopedLightLock lk(m_lock);
// Unmap the memory.
- R_TRY(GetCurrentProcess(m_kernel).PageTable().UnmapPageGroup(address, *m_page_group,
- KMemoryState::CodeOut));
+ R_TRY(GetCurrentProcess(m_kernel).GetPageTable().UnmapPageGroup(address, *m_page_group,
+ KMemoryState::CodeOut));
// Mark ourselves as unmapped.
m_is_mapped = false;
@@ -126,8 +126,8 @@ Result KCodeMemory::MapToOwner(KProcessAddress address, size_t size, Svc::Memory
}
// Map the memory.
- R_TRY(m_owner->PageTable().MapPageGroup(address, *m_page_group, KMemoryState::GeneratedCode,
- k_perm));
+ R_TRY(m_owner->GetPageTable().MapPageGroup(address, *m_page_group, KMemoryState::GeneratedCode,
+ k_perm));
// Mark ourselves as mapped.
m_is_owner_mapped = true;
@@ -143,7 +143,8 @@ Result KCodeMemory::UnmapFromOwner(KProcessAddress address, size_t size) {
KScopedLightLock lk(m_lock);
// Unmap the memory.
- R_TRY(m_owner->PageTable().UnmapPageGroup(address, *m_page_group, KMemoryState::GeneratedCode));
+ R_TRY(m_owner->GetPageTable().UnmapPageGroup(address, *m_page_group,
+ KMemoryState::GeneratedCode));
// Mark ourselves as unmapped.
m_is_owner_mapped = false;
diff --git a/src/core/hle/kernel/k_page_table.h b/src/core/hle/kernel/k_page_table.h
index 022d15f35..b9e8c6042 100644
--- a/src/core/hle/kernel/k_page_table.h
+++ b/src/core/hle/kernel/k_page_table.h
@@ -388,39 +388,6 @@ public:
constexpr size_t GetHeapSize() const {
return m_current_heap_end - m_heap_region_start;
}
- constexpr bool IsInsideAddressSpace(KProcessAddress address, size_t size) const {
- return m_address_space_start <= address && address + size - 1 <= m_address_space_end - 1;
- }
- constexpr bool IsOutsideAliasRegion(KProcessAddress address, size_t size) const {
- return m_alias_region_start > address || address + size - 1 > m_alias_region_end - 1;
- }
- constexpr bool IsOutsideStackRegion(KProcessAddress address, size_t size) const {
- return m_stack_region_start > address || address + size - 1 > m_stack_region_end - 1;
- }
- constexpr bool IsInvalidRegion(KProcessAddress address, size_t size) const {
- return address + size - 1 > GetAliasCodeRegionStart() + GetAliasCodeRegionSize() - 1;
- }
- constexpr bool IsInsideHeapRegion(KProcessAddress address, size_t size) const {
- return address + size > m_heap_region_start && m_heap_region_end > address;
- }
- constexpr bool IsInsideAliasRegion(KProcessAddress address, size_t size) const {
- return address + size > m_alias_region_start && m_alias_region_end > address;
- }
- constexpr bool IsOutsideASLRRegion(KProcessAddress address, size_t size) const {
- if (IsInvalidRegion(address, size)) {
- return true;
- }
- if (IsInsideHeapRegion(address, size)) {
- return true;
- }
- if (IsInsideAliasRegion(address, size)) {
- return true;
- }
- return {};
- }
- constexpr bool IsInsideASLRRegion(KProcessAddress address, size_t size) const {
- return !IsOutsideASLRRegion(address, size);
- }
constexpr size_t GetNumGuardPages() const {
return IsKernel() ? 1 : 4;
}
@@ -436,6 +403,14 @@ public:
return m_address_space_start <= addr && addr < addr + size &&
addr + size - 1 <= m_address_space_end - 1;
}
+ constexpr bool IsInAliasRegion(KProcessAddress addr, size_t size) const {
+ return this->Contains(addr, size) && m_alias_region_start <= addr &&
+ addr + size - 1 <= m_alias_region_end - 1;
+ }
+ constexpr bool IsInHeapRegion(KProcessAddress addr, size_t size) const {
+ return this->Contains(addr, size) && m_heap_region_start <= addr &&
+ addr + size - 1 <= m_heap_region_end - 1;
+ }
public:
static KVirtualAddress GetLinearMappedVirtualAddress(const KMemoryLayout& layout,
diff --git a/src/core/hle/kernel/k_process.cpp b/src/core/hle/kernel/k_process.cpp
index efe86ad27..44c7cb22f 100644
--- a/src/core/hle/kernel/k_process.cpp
+++ b/src/core/hle/kernel/k_process.cpp
@@ -38,7 +38,7 @@ namespace {
*/
void SetupMainThread(Core::System& system, KProcess& owner_process, u32 priority,
KProcessAddress stack_top) {
- const KProcessAddress entry_point = owner_process.PageTable().GetCodeRegionStart();
+ const KProcessAddress entry_point = owner_process.GetPageTable().GetCodeRegionStart();
ASSERT(owner_process.GetResourceLimit()->Reserve(LimitableResource::ThreadCountMax, 1));
KThread* thread = KThread::Create(system.Kernel());
diff --git a/src/core/hle/kernel/k_process.h b/src/core/hle/kernel/k_process.h
index 925981d06..c9b37e138 100644
--- a/src/core/hle/kernel/k_process.h
+++ b/src/core/hle/kernel/k_process.h
@@ -110,16 +110,6 @@ public:
ProcessType type, KResourceLimit* res_limit);
/// Gets a reference to the process' page table.
- KPageTable& PageTable() {
- return m_page_table;
- }
-
- /// Gets const a reference to the process' page table.
- const KPageTable& PageTable() const {
- return m_page_table;
- }
-
- /// Gets a reference to the process' page table.
KPageTable& GetPageTable() {
return m_page_table;
}
diff --git a/src/core/hle/kernel/k_shared_memory.cpp b/src/core/hle/kernel/k_shared_memory.cpp
index efb5699de..f713968f6 100644
--- a/src/core/hle/kernel/k_shared_memory.cpp
+++ b/src/core/hle/kernel/k_shared_memory.cpp
@@ -90,8 +90,8 @@ Result KSharedMemory::Map(KProcess& target_process, KProcessAddress address, std
R_UNLESS(map_perm == test_perm, ResultInvalidNewMemoryPermission);
}
- R_RETURN(target_process.PageTable().MapPageGroup(address, *m_page_group, KMemoryState::Shared,
- ConvertToKMemoryPermission(map_perm)));
+ R_RETURN(target_process.GetPageTable().MapPageGroup(
+ address, *m_page_group, KMemoryState::Shared, ConvertToKMemoryPermission(map_perm)));
}
Result KSharedMemory::Unmap(KProcess& target_process, KProcessAddress address,
@@ -100,7 +100,7 @@ Result KSharedMemory::Unmap(KProcess& target_process, KProcessAddress address,
R_UNLESS(m_size == unmap_size, ResultInvalidSize);
R_RETURN(
- target_process.PageTable().UnmapPageGroup(address, *m_page_group, KMemoryState::Shared));
+ target_process.GetPageTable().UnmapPageGroup(address, *m_page_group, KMemoryState::Shared));
}
} // namespace Kernel
diff --git a/src/core/hle/kernel/k_thread.cpp b/src/core/hle/kernel/k_thread.cpp
index 2a105a762..7df8fd7f7 100644
--- a/src/core/hle/kernel/k_thread.cpp
+++ b/src/core/hle/kernel/k_thread.cpp
@@ -302,12 +302,12 @@ Result KThread::InitializeServiceThread(Core::System& system, KThread* thread,
std::function<void()>&& func, s32 prio, s32 virt_core,
KProcess* owner) {
system.Kernel().GlobalSchedulerContext().AddThread(thread);
- std::function<void()> func2{[&system, func{std::move(func)}] {
+ std::function<void()> func2{[&system, func_{std::move(func)}] {
// Similar to UserModeThreadStarter.
system.Kernel().CurrentScheduler()->OnThreadStart();
// Run the guest function.
- func();
+ func_();
// Exit.
Svc::ExitThread(system);
diff --git a/src/core/hle/kernel/k_thread_local_page.cpp b/src/core/hle/kernel/k_thread_local_page.cpp
index b4a1e3cdb..2c45b4232 100644
--- a/src/core/hle/kernel/k_thread_local_page.cpp
+++ b/src/core/hle/kernel/k_thread_local_page.cpp
@@ -25,9 +25,9 @@ Result KThreadLocalPage::Initialize(KernelCore& kernel, KProcess* process) {
// Map the address in.
const auto phys_addr = kernel.System().DeviceMemory().GetPhysicalAddr(page_buf);
- R_TRY(m_owner->PageTable().MapPages(std::addressof(m_virt_addr), 1, PageSize, phys_addr,
- KMemoryState::ThreadLocal,
- KMemoryPermission::UserReadWrite));
+ R_TRY(m_owner->GetPageTable().MapPages(std::addressof(m_virt_addr), 1, PageSize, phys_addr,
+ KMemoryState::ThreadLocal,
+ KMemoryPermission::UserReadWrite));
// We succeeded.
page_buf_guard.Cancel();
@@ -37,11 +37,11 @@ Result KThreadLocalPage::Initialize(KernelCore& kernel, KProcess* process) {
Result KThreadLocalPage::Finalize() {
// Get the physical address of the page.
- const KPhysicalAddress phys_addr = m_owner->PageTable().GetPhysicalAddr(m_virt_addr);
+ const KPhysicalAddress phys_addr = m_owner->GetPageTable().GetPhysicalAddr(m_virt_addr);
ASSERT(phys_addr);
// Unmap the page.
- R_TRY(m_owner->PageTable().UnmapPages(this->GetAddress(), 1, KMemoryState::ThreadLocal));
+ R_TRY(m_owner->GetPageTable().UnmapPages(this->GetAddress(), 1, KMemoryState::ThreadLocal));
// Free the page.
KPageBuffer::Free(*m_kernel, KPageBuffer::FromPhysicalAddress(m_kernel->System(), phys_addr));
diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp
index f33600ca5..ebe7582c6 100644
--- a/src/core/hle/kernel/kernel.cpp
+++ b/src/core/hle/kernel/kernel.cpp
@@ -1089,15 +1089,15 @@ static std::jthread RunHostThreadFunc(KernelCore& kernel, KProcess* process,
KThread::Register(kernel, thread);
return std::jthread(
- [&kernel, thread, thread_name{std::move(thread_name)}, func{std::move(func)}] {
+ [&kernel, thread, thread_name_{std::move(thread_name)}, func_{std::move(func)}] {
// Set the thread name.
- Common::SetCurrentThreadName(thread_name.c_str());
+ Common::SetCurrentThreadName(thread_name_.c_str());
// Set the thread as current.
kernel.RegisterHostThread(thread);
// Run the callback.
- func();
+ func_();
// Close the thread.
// This will free the process if it is the last reference.
diff --git a/src/core/hle/kernel/physical_core.cpp b/src/core/hle/kernel/physical_core.cpp
index 2e0c36129..5ee869fa2 100644
--- a/src/core/hle/kernel/physical_core.cpp
+++ b/src/core/hle/kernel/physical_core.cpp
@@ -17,7 +17,9 @@ PhysicalCore::PhysicalCore(std::size_t core_index, Core::System& system, KSchedu
// a 32-bit instance of Dynarmic. This should be abstracted out to a CPU manager.
auto& kernel = system.Kernel();
m_arm_interface = std::make_unique<Core::ARM_Dynarmic_64>(
- system, kernel.IsMulticore(), kernel.GetExclusiveMonitor(), m_core_index);
+ system, kernel.IsMulticore(),
+ reinterpret_cast<Core::DynarmicExclusiveMonitor&>(kernel.GetExclusiveMonitor()),
+ m_core_index);
#else
#error Platform not supported yet.
#endif
@@ -31,7 +33,9 @@ void PhysicalCore::Initialize(bool is_64_bit) {
if (!is_64_bit) {
// We already initialized a 64-bit core, replace with a 32-bit one.
m_arm_interface = std::make_unique<Core::ARM_Dynarmic_32>(
- m_system, kernel.IsMulticore(), kernel.GetExclusiveMonitor(), m_core_index);
+ m_system, kernel.IsMulticore(),
+ reinterpret_cast<Core::DynarmicExclusiveMonitor&>(kernel.GetExclusiveMonitor()),
+ m_core_index);
}
#else
#error Platform not supported yet.
diff --git a/src/core/hle/kernel/svc/svc_cache.cpp b/src/core/hle/kernel/svc/svc_cache.cpp
index 082942dab..c2c8be10f 100644
--- a/src/core/hle/kernel/svc/svc_cache.cpp
+++ b/src/core/hle/kernel/svc/svc_cache.cpp
@@ -42,7 +42,7 @@ Result FlushProcessDataCache(Core::System& system, Handle process_handle, u64 ad
R_UNLESS(process.IsNotNull(), ResultInvalidHandle);
// Verify the region is within range.
- auto& page_table = process->PageTable();
+ auto& page_table = process->GetPageTable();
R_UNLESS(page_table.Contains(address, size), ResultInvalidCurrentMemory);
// Perform the operation.
diff --git a/src/core/hle/kernel/svc/svc_code_memory.cpp b/src/core/hle/kernel/svc/svc_code_memory.cpp
index 687baff82..bae4cb0cd 100644
--- a/src/core/hle/kernel/svc/svc_code_memory.cpp
+++ b/src/core/hle/kernel/svc/svc_code_memory.cpp
@@ -48,7 +48,7 @@ Result CreateCodeMemory(Core::System& system, Handle* out, u64 address, uint64_t
SCOPE_EXIT({ code_mem->Close(); });
// Verify that the region is in range.
- R_UNLESS(GetCurrentProcess(system.Kernel()).PageTable().Contains(address, size),
+ R_UNLESS(GetCurrentProcess(system.Kernel()).GetPageTable().Contains(address, size),
ResultInvalidCurrentMemory);
// Initialize the code memory.
@@ -92,7 +92,7 @@ Result ControlCodeMemory(Core::System& system, Handle code_memory_handle,
case CodeMemoryOperation::Map: {
// Check that the region is in range.
R_UNLESS(GetCurrentProcess(system.Kernel())
- .PageTable()
+ .GetPageTable()
.CanContain(address, size, KMemoryState::CodeOut),
ResultInvalidMemoryRegion);
@@ -105,7 +105,7 @@ Result ControlCodeMemory(Core::System& system, Handle code_memory_handle,
case CodeMemoryOperation::Unmap: {
// Check that the region is in range.
R_UNLESS(GetCurrentProcess(system.Kernel())
- .PageTable()
+ .GetPageTable()
.CanContain(address, size, KMemoryState::CodeOut),
ResultInvalidMemoryRegion);
@@ -117,8 +117,8 @@ Result ControlCodeMemory(Core::System& system, Handle code_memory_handle,
} break;
case CodeMemoryOperation::MapToOwner: {
// Check that the region is in range.
- R_UNLESS(code_mem->GetOwner()->PageTable().CanContain(address, size,
- KMemoryState::GeneratedCode),
+ R_UNLESS(code_mem->GetOwner()->GetPageTable().CanContain(address, size,
+ KMemoryState::GeneratedCode),
ResultInvalidMemoryRegion);
// Check the memory permission.
@@ -129,8 +129,8 @@ Result ControlCodeMemory(Core::System& system, Handle code_memory_handle,
} break;
case CodeMemoryOperation::UnmapFromOwner: {
// Check that the region is in range.
- R_UNLESS(code_mem->GetOwner()->PageTable().CanContain(address, size,
- KMemoryState::GeneratedCode),
+ R_UNLESS(code_mem->GetOwner()->GetPageTable().CanContain(address, size,
+ KMemoryState::GeneratedCode),
ResultInvalidMemoryRegion);
// Check the memory permission.
diff --git a/src/core/hle/kernel/svc/svc_device_address_space.cpp b/src/core/hle/kernel/svc/svc_device_address_space.cpp
index ec3143e67..42add9473 100644
--- a/src/core/hle/kernel/svc/svc_device_address_space.cpp
+++ b/src/core/hle/kernel/svc/svc_device_address_space.cpp
@@ -107,7 +107,7 @@ Result MapDeviceAddressSpaceByForce(Core::System& system, Handle das_handle, Han
R_UNLESS(process.IsNotNull(), ResultInvalidHandle);
// Validate that the process address is within range.
- auto& page_table = process->PageTable();
+ auto& page_table = process->GetPageTable();
R_UNLESS(page_table.Contains(process_address, size), ResultInvalidCurrentMemory);
// Map.
@@ -148,7 +148,7 @@ Result MapDeviceAddressSpaceAligned(Core::System& system, Handle das_handle, Han
R_UNLESS(process.IsNotNull(), ResultInvalidHandle);
// Validate that the process address is within range.
- auto& page_table = process->PageTable();
+ auto& page_table = process->GetPageTable();
R_UNLESS(page_table.Contains(process_address, size), ResultInvalidCurrentMemory);
// Map.
@@ -180,7 +180,7 @@ Result UnmapDeviceAddressSpace(Core::System& system, Handle das_handle, Handle p
R_UNLESS(process.IsNotNull(), ResultInvalidHandle);
// Validate that the process address is within range.
- auto& page_table = process->PageTable();
+ auto& page_table = process->GetPageTable();
R_UNLESS(page_table.Contains(process_address, size), ResultInvalidCurrentMemory);
R_RETURN(das->Unmap(std::addressof(page_table), process_address, size, device_address));
diff --git a/src/core/hle/kernel/svc/svc_info.cpp b/src/core/hle/kernel/svc/svc_info.cpp
index 445cdd87b..f99964028 100644
--- a/src/core/hle/kernel/svc/svc_info.cpp
+++ b/src/core/hle/kernel/svc/svc_info.cpp
@@ -54,35 +54,35 @@ Result GetInfo(Core::System& system, u64* result, InfoType info_id_type, Handle
R_SUCCEED();
case InfoType::AliasRegionAddress:
- *result = GetInteger(process->PageTable().GetAliasRegionStart());
+ *result = GetInteger(process->GetPageTable().GetAliasRegionStart());
R_SUCCEED();
case InfoType::AliasRegionSize:
- *result = process->PageTable().GetAliasRegionSize();
+ *result = process->GetPageTable().GetAliasRegionSize();
R_SUCCEED();
case InfoType::HeapRegionAddress:
- *result = GetInteger(process->PageTable().GetHeapRegionStart());
+ *result = GetInteger(process->GetPageTable().GetHeapRegionStart());
R_SUCCEED();
case InfoType::HeapRegionSize:
- *result = process->PageTable().GetHeapRegionSize();
+ *result = process->GetPageTable().GetHeapRegionSize();
R_SUCCEED();
case InfoType::AslrRegionAddress:
- *result = GetInteger(process->PageTable().GetAliasCodeRegionStart());
+ *result = GetInteger(process->GetPageTable().GetAliasCodeRegionStart());
R_SUCCEED();
case InfoType::AslrRegionSize:
- *result = process->PageTable().GetAliasCodeRegionSize();
+ *result = process->GetPageTable().GetAliasCodeRegionSize();
R_SUCCEED();
case InfoType::StackRegionAddress:
- *result = GetInteger(process->PageTable().GetStackRegionStart());
+ *result = GetInteger(process->GetPageTable().GetStackRegionStart());
R_SUCCEED();
case InfoType::StackRegionSize:
- *result = process->PageTable().GetStackRegionSize();
+ *result = process->GetPageTable().GetStackRegionSize();
R_SUCCEED();
case InfoType::TotalMemorySize:
diff --git a/src/core/hle/kernel/svc/svc_ipc.cpp b/src/core/hle/kernel/svc/svc_ipc.cpp
index bb94f6934..373ae7c8d 100644
--- a/src/core/hle/kernel/svc/svc_ipc.cpp
+++ b/src/core/hle/kernel/svc/svc_ipc.cpp
@@ -8,6 +8,7 @@
#include "core/hle/kernel/k_process.h"
#include "core/hle/kernel/k_server_session.h"
#include "core/hle/kernel/svc.h"
+#include "core/hle/kernel/svc_results.h"
namespace Kernel::Svc {
@@ -49,14 +50,10 @@ Result ReplyAndReceive(Core::System& system, s32* out_index, uint64_t handles_ad
// Copy user handles.
if (num_handles > 0) {
- // Ensure we can try to get the handles.
- R_UNLESS(GetCurrentMemory(kernel).IsValidVirtualAddressRange(
- handles_addr, static_cast<u64>(sizeof(Handle) * num_handles)),
- ResultInvalidPointer);
-
// Get the handles.
- GetCurrentMemory(kernel).ReadBlock(handles_addr, handles.data(),
- sizeof(Handle) * num_handles);
+ R_UNLESS(GetCurrentMemory(kernel).ReadBlock(handles_addr, handles.data(),
+ sizeof(Handle) * num_handles),
+ ResultInvalidPointer);
// Convert the handles to objects.
R_UNLESS(handle_table.GetMultipleObjects<KSynchronizationObject>(
diff --git a/src/core/hle/kernel/svc/svc_memory.cpp b/src/core/hle/kernel/svc/svc_memory.cpp
index 5dcb7f045..2cab74127 100644
--- a/src/core/hle/kernel/svc/svc_memory.cpp
+++ b/src/core/hle/kernel/svc/svc_memory.cpp
@@ -63,36 +63,13 @@ Result MapUnmapMemorySanityChecks(const KPageTable& manager, u64 dst_addr, u64 s
R_THROW(ResultInvalidCurrentMemory);
}
- if (!manager.IsInsideAddressSpace(src_addr, size)) {
+ if (!manager.Contains(src_addr, size)) {
LOG_ERROR(Kernel_SVC,
"Source is not within the address space, addr=0x{:016X}, size=0x{:016X}",
src_addr, size);
R_THROW(ResultInvalidCurrentMemory);
}
- if (manager.IsOutsideStackRegion(dst_addr, size)) {
- LOG_ERROR(Kernel_SVC,
- "Destination is not within the stack region, addr=0x{:016X}, size=0x{:016X}",
- dst_addr, size);
- R_THROW(ResultInvalidMemoryRegion);
- }
-
- if (manager.IsInsideHeapRegion(dst_addr, size)) {
- LOG_ERROR(Kernel_SVC,
- "Destination does not fit within the heap region, addr=0x{:016X}, "
- "size=0x{:016X}",
- dst_addr, size);
- R_THROW(ResultInvalidMemoryRegion);
- }
-
- if (manager.IsInsideAliasRegion(dst_addr, size)) {
- LOG_ERROR(Kernel_SVC,
- "Destination does not fit within the map region, addr=0x{:016X}, "
- "size=0x{:016X}",
- dst_addr, size);
- R_THROW(ResultInvalidMemoryRegion);
- }
-
R_SUCCEED();
}
@@ -112,7 +89,7 @@ Result SetMemoryPermission(Core::System& system, u64 address, u64 size, MemoryPe
R_UNLESS(IsValidSetMemoryPermission(perm), ResultInvalidNewMemoryPermission);
// Validate that the region is in range for the current process.
- auto& page_table = GetCurrentProcess(system.Kernel()).PageTable();
+ auto& page_table = GetCurrentProcess(system.Kernel()).GetPageTable();
R_UNLESS(page_table.Contains(address, size), ResultInvalidCurrentMemory);
// Set the memory attribute.
@@ -136,7 +113,7 @@ Result SetMemoryAttribute(Core::System& system, u64 address, u64 size, u32 mask,
R_UNLESS((mask | attr | SupportedMask) == SupportedMask, ResultInvalidCombination);
// Validate that the region is in range for the current process.
- auto& page_table{GetCurrentProcess(system.Kernel()).PageTable()};
+ auto& page_table{GetCurrentProcess(system.Kernel()).GetPageTable()};
R_UNLESS(page_table.Contains(address, size), ResultInvalidCurrentMemory);
// Set the memory attribute.
@@ -148,7 +125,7 @@ Result MapMemory(Core::System& system, u64 dst_addr, u64 src_addr, u64 size) {
LOG_TRACE(Kernel_SVC, "called, dst_addr=0x{:X}, src_addr=0x{:X}, size=0x{:X}", dst_addr,
src_addr, size);
- auto& page_table{GetCurrentProcess(system.Kernel()).PageTable()};
+ auto& page_table{GetCurrentProcess(system.Kernel()).GetPageTable()};
if (const Result result{MapUnmapMemorySanityChecks(page_table, dst_addr, src_addr, size)};
result.IsError()) {
@@ -163,7 +140,7 @@ Result UnmapMemory(Core::System& system, u64 dst_addr, u64 src_addr, u64 size) {
LOG_TRACE(Kernel_SVC, "called, dst_addr=0x{:X}, src_addr=0x{:X}, size=0x{:X}", dst_addr,
src_addr, size);
- auto& page_table{GetCurrentProcess(system.Kernel()).PageTable()};
+ auto& page_table{GetCurrentProcess(system.Kernel()).GetPageTable()};
if (const Result result{MapUnmapMemorySanityChecks(page_table, dst_addr, src_addr, size)};
result.IsError()) {
diff --git a/src/core/hle/kernel/svc/svc_physical_memory.cpp b/src/core/hle/kernel/svc/svc_physical_memory.cpp
index c2fbfb59a..d3545f232 100644
--- a/src/core/hle/kernel/svc/svc_physical_memory.cpp
+++ b/src/core/hle/kernel/svc/svc_physical_memory.cpp
@@ -16,7 +16,7 @@ Result SetHeapSize(Core::System& system, u64* out_address, u64 size) {
R_UNLESS(size < MainMemorySizeMax, ResultInvalidSize);
// Set the heap size.
- R_RETURN(GetCurrentProcess(system.Kernel()).PageTable().SetHeapSize(out_address, size));
+ R_RETURN(GetCurrentProcess(system.Kernel()).GetPageTable().SetHeapSize(out_address, size));
}
/// Maps memory at a desired address
@@ -44,21 +44,21 @@ Result MapPhysicalMemory(Core::System& system, u64 addr, u64 size) {
}
KProcess* const current_process{GetCurrentProcessPointer(system.Kernel())};
- auto& page_table{current_process->PageTable()};
+ auto& page_table{current_process->GetPageTable()};
if (current_process->GetSystemResourceSize() == 0) {
LOG_ERROR(Kernel_SVC, "System Resource Size is zero");
R_THROW(ResultInvalidState);
}
- if (!page_table.IsInsideAddressSpace(addr, size)) {
+ if (!page_table.Contains(addr, size)) {
LOG_ERROR(Kernel_SVC,
"Address is not within the address space, addr=0x{:016X}, size=0x{:016X}", addr,
size);
R_THROW(ResultInvalidMemoryRegion);
}
- if (page_table.IsOutsideAliasRegion(addr, size)) {
+ if (!page_table.IsInAliasRegion(addr, size)) {
LOG_ERROR(Kernel_SVC,
"Address is not within the alias region, addr=0x{:016X}, size=0x{:016X}", addr,
size);
@@ -93,21 +93,21 @@ Result UnmapPhysicalMemory(Core::System& system, u64 addr, u64 size) {
}
KProcess* const current_process{GetCurrentProcessPointer(system.Kernel())};
- auto& page_table{current_process->PageTable()};
+ auto& page_table{current_process->GetPageTable()};
if (current_process->GetSystemResourceSize() == 0) {
LOG_ERROR(Kernel_SVC, "System Resource Size is zero");
R_THROW(ResultInvalidState);
}
- if (!page_table.IsInsideAddressSpace(addr, size)) {
+ if (!page_table.Contains(addr, size)) {
LOG_ERROR(Kernel_SVC,
"Address is not within the address space, addr=0x{:016X}, size=0x{:016X}", addr,
size);
R_THROW(ResultInvalidMemoryRegion);
}
- if (page_table.IsOutsideAliasRegion(addr, size)) {
+ if (!page_table.IsInAliasRegion(addr, size)) {
LOG_ERROR(Kernel_SVC,
"Address is not within the alias region, addr=0x{:016X}, size=0x{:016X}", addr,
size);
diff --git a/src/core/hle/kernel/svc/svc_process.cpp b/src/core/hle/kernel/svc/svc_process.cpp
index 619ed16a3..caa8bee9a 100644
--- a/src/core/hle/kernel/svc/svc_process.cpp
+++ b/src/core/hle/kernel/svc/svc_process.cpp
@@ -66,8 +66,8 @@ Result GetProcessList(Core::System& system, s32* out_num_processes, u64 out_proc
auto& kernel = system.Kernel();
const auto total_copy_size = out_process_ids_size * sizeof(u64);
- if (out_process_ids_size > 0 && !GetCurrentProcess(kernel).PageTable().IsInsideAddressSpace(
- out_process_ids, total_copy_size)) {
+ if (out_process_ids_size > 0 &&
+ !GetCurrentProcess(kernel).GetPageTable().Contains(out_process_ids, total_copy_size)) {
LOG_ERROR(Kernel_SVC, "Address range outside address space. begin=0x{:016X}, end=0x{:016X}",
out_process_ids, out_process_ids + total_copy_size);
R_THROW(ResultInvalidCurrentMemory);
diff --git a/src/core/hle/kernel/svc/svc_process_memory.cpp b/src/core/hle/kernel/svc/svc_process_memory.cpp
index aee0f2f36..07cd48175 100644
--- a/src/core/hle/kernel/svc/svc_process_memory.cpp
+++ b/src/core/hle/kernel/svc/svc_process_memory.cpp
@@ -49,7 +49,7 @@ Result SetProcessMemoryPermission(Core::System& system, Handle process_handle, u
R_UNLESS(process.IsNotNull(), ResultInvalidHandle);
// Validate that the address is in range.
- auto& page_table = process->PageTable();
+ auto& page_table = process->GetPageTable();
R_UNLESS(page_table.Contains(address, size), ResultInvalidCurrentMemory);
// Set the memory permission.
@@ -77,8 +77,8 @@ Result MapProcessMemory(Core::System& system, u64 dst_address, Handle process_ha
R_UNLESS(src_process.IsNotNull(), ResultInvalidHandle);
// Get the page tables.
- auto& dst_pt = dst_process->PageTable();
- auto& src_pt = src_process->PageTable();
+ auto& dst_pt = dst_process->GetPageTable();
+ auto& src_pt = src_process->GetPageTable();
// Validate that the mapping is in range.
R_UNLESS(src_pt.Contains(src_address, size), ResultInvalidCurrentMemory);
@@ -118,8 +118,8 @@ Result UnmapProcessMemory(Core::System& system, u64 dst_address, Handle process_
R_UNLESS(src_process.IsNotNull(), ResultInvalidHandle);
// Get the page tables.
- auto& dst_pt = dst_process->PageTable();
- auto& src_pt = src_process->PageTable();
+ auto& dst_pt = dst_process->GetPageTable();
+ auto& src_pt = src_process->GetPageTable();
// Validate that the mapping is in range.
R_UNLESS(src_pt.Contains(src_address, size), ResultInvalidCurrentMemory);
@@ -178,8 +178,8 @@ Result MapProcessCodeMemory(Core::System& system, Handle process_handle, u64 dst
R_THROW(ResultInvalidHandle);
}
- auto& page_table = process->PageTable();
- if (!page_table.IsInsideAddressSpace(src_address, size)) {
+ auto& page_table = process->GetPageTable();
+ if (!page_table.Contains(src_address, size)) {
LOG_ERROR(Kernel_SVC,
"Source address range is not within the address space (src_address=0x{:016X}, "
"size=0x{:016X}).",
@@ -187,14 +187,6 @@ Result MapProcessCodeMemory(Core::System& system, Handle process_handle, u64 dst
R_THROW(ResultInvalidCurrentMemory);
}
- if (!page_table.IsInsideASLRRegion(dst_address, size)) {
- LOG_ERROR(Kernel_SVC,
- "Destination address range is not within the ASLR region (dst_address=0x{:016X}, "
- "size=0x{:016X}).",
- dst_address, size);
- R_THROW(ResultInvalidMemoryRegion);
- }
-
R_RETURN(page_table.MapCodeMemory(dst_address, src_address, size));
}
@@ -246,8 +238,8 @@ Result UnmapProcessCodeMemory(Core::System& system, Handle process_handle, u64 d
R_THROW(ResultInvalidHandle);
}
- auto& page_table = process->PageTable();
- if (!page_table.IsInsideAddressSpace(src_address, size)) {
+ auto& page_table = process->GetPageTable();
+ if (!page_table.Contains(src_address, size)) {
LOG_ERROR(Kernel_SVC,
"Source address range is not within the address space (src_address=0x{:016X}, "
"size=0x{:016X}).",
@@ -255,14 +247,6 @@ Result UnmapProcessCodeMemory(Core::System& system, Handle process_handle, u64 d
R_THROW(ResultInvalidCurrentMemory);
}
- if (!page_table.IsInsideASLRRegion(dst_address, size)) {
- LOG_ERROR(Kernel_SVC,
- "Destination address range is not within the ASLR region (dst_address=0x{:016X}, "
- "size=0x{:016X}).",
- dst_address, size);
- R_THROW(ResultInvalidMemoryRegion);
- }
-
R_RETURN(page_table.UnmapCodeMemory(dst_address, src_address, size,
KPageTable::ICacheInvalidationStrategy::InvalidateAll));
}
diff --git a/src/core/hle/kernel/svc/svc_query_memory.cpp b/src/core/hle/kernel/svc/svc_query_memory.cpp
index 4d9fcd25f..51af06e97 100644
--- a/src/core/hle/kernel/svc/svc_query_memory.cpp
+++ b/src/core/hle/kernel/svc/svc_query_memory.cpp
@@ -31,7 +31,7 @@ Result QueryProcessMemory(Core::System& system, uint64_t out_memory_info, PageIn
}
auto& current_memory{GetCurrentMemory(system.Kernel())};
- const auto memory_info{process->PageTable().QueryInfo(address).GetSvcMemoryInfo()};
+ const auto memory_info{process->GetPageTable().QueryInfo(address).GetSvcMemoryInfo()};
current_memory.WriteBlock(out_memory_info, std::addressof(memory_info), sizeof(memory_info));
diff --git a/src/core/hle/kernel/svc/svc_shared_memory.cpp b/src/core/hle/kernel/svc/svc_shared_memory.cpp
index a698596aa..012b1ae2b 100644
--- a/src/core/hle/kernel/svc/svc_shared_memory.cpp
+++ b/src/core/hle/kernel/svc/svc_shared_memory.cpp
@@ -43,7 +43,7 @@ Result MapSharedMemory(Core::System& system, Handle shmem_handle, u64 address, u
// Get the current process.
auto& process = GetCurrentProcess(system.Kernel());
- auto& page_table = process.PageTable();
+ auto& page_table = process.GetPageTable();
// Get the shared memory.
KScopedAutoObject shmem = process.GetHandleTable().GetObject<KSharedMemory>(shmem_handle);
@@ -73,7 +73,7 @@ Result UnmapSharedMemory(Core::System& system, Handle shmem_handle, u64 address,
// Get the current process.
auto& process = GetCurrentProcess(system.Kernel());
- auto& page_table = process.PageTable();
+ auto& page_table = process.GetPageTable();
// Get the shared memory.
KScopedAutoObject shmem = process.GetHandleTable().GetObject<KSharedMemory>(shmem_handle);
diff --git a/src/core/hle/kernel/svc/svc_synchronization.cpp b/src/core/hle/kernel/svc/svc_synchronization.cpp
index f02d03f30..366e8ed4a 100644
--- a/src/core/hle/kernel/svc/svc_synchronization.cpp
+++ b/src/core/hle/kernel/svc/svc_synchronization.cpp
@@ -7,6 +7,7 @@
#include "core/hle/kernel/k_process.h"
#include "core/hle/kernel/k_readable_event.h"
#include "core/hle/kernel/svc.h"
+#include "core/hle/kernel/svc_results.h"
namespace Kernel::Svc {
@@ -64,14 +65,10 @@ Result WaitSynchronization(Core::System& system, int32_t* out_index, u64 user_ha
// Copy user handles.
if (num_handles > 0) {
- // Ensure we can try to get the handles.
- R_UNLESS(GetCurrentMemory(kernel).IsValidVirtualAddressRange(
- user_handles, static_cast<u64>(sizeof(Handle) * num_handles)),
- ResultInvalidPointer);
-
// Get the handles.
- GetCurrentMemory(kernel).ReadBlock(user_handles, handles.data(),
- sizeof(Handle) * num_handles);
+ R_UNLESS(GetCurrentMemory(kernel).ReadBlock(user_handles, handles.data(),
+ sizeof(Handle) * num_handles),
+ ResultInvalidPointer);
// Convert the handles to objects.
R_UNLESS(handle_table.GetMultipleObjects<KSynchronizationObject>(
diff --git a/src/core/hle/kernel/svc/svc_thread.cpp b/src/core/hle/kernel/svc/svc_thread.cpp
index 36b94e6bf..92bcea72b 100644
--- a/src/core/hle/kernel/svc/svc_thread.cpp
+++ b/src/core/hle/kernel/svc/svc_thread.cpp
@@ -236,7 +236,7 @@ Result GetThreadList(Core::System& system, s32* out_num_threads, u64 out_thread_
const auto total_copy_size = out_thread_ids_size * sizeof(u64);
if (out_thread_ids_size > 0 &&
- !current_process->PageTable().IsInsideAddressSpace(out_thread_ids, total_copy_size)) {
+ !current_process->GetPageTable().Contains(out_thread_ids, total_copy_size)) {
LOG_ERROR(Kernel_SVC, "Address range outside address space. begin=0x{:016X}, end=0x{:016X}",
out_thread_ids, out_thread_ids + total_copy_size);
R_THROW(ResultInvalidCurrentMemory);
diff --git a/src/core/hle/kernel/svc/svc_transfer_memory.cpp b/src/core/hle/kernel/svc/svc_transfer_memory.cpp
index 82d469a37..7d94e7f09 100644
--- a/src/core/hle/kernel/svc/svc_transfer_memory.cpp
+++ b/src/core/hle/kernel/svc/svc_transfer_memory.cpp
@@ -55,7 +55,7 @@ Result CreateTransferMemory(Core::System& system, Handle* out, u64 address, u64
SCOPE_EXIT({ trmem->Close(); });
// Ensure that the region is in range.
- R_UNLESS(process.PageTable().Contains(address, size), ResultInvalidCurrentMemory);
+ R_UNLESS(process.GetPageTable().Contains(address, size), ResultInvalidCurrentMemory);
// Initialize the transfer memory.
R_TRY(trmem->Initialize(address, size, map_perm));
diff --git a/src/core/hle/service/acc/acc.cpp b/src/core/hle/service/acc/acc.cpp
index 6c29cb613..2632cd3ef 100644
--- a/src/core/hle/service/acc/acc.cpp
+++ b/src/core/hle/service/acc/acc.cpp
@@ -496,8 +496,9 @@ public:
void LoadIdTokenCache(HLERequestContext& ctx) {
LOG_WARNING(Service_ACC, "(STUBBED) called");
- IPC::ResponseBuilder rb{ctx, 2};
+ IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess);
+ rb.Push(0);
}
protected:
diff --git a/src/core/hle/service/am/am.cpp b/src/core/hle/service/am/am.cpp
index a2375508a..4f400d341 100644
--- a/src/core/hle/service/am/am.cpp
+++ b/src/core/hle/service/am/am.cpp
@@ -506,8 +506,8 @@ void ISelfController::SetHandlesRequestToDisplay(HLERequestContext& ctx) {
void ISelfController::SetIdleTimeDetectionExtension(HLERequestContext& ctx) {
IPC::RequestParser rp{ctx};
idle_time_detection_extension = rp.Pop<u32>();
- LOG_WARNING(Service_AM, "(STUBBED) called idle_time_detection_extension={}",
- idle_time_detection_extension);
+ LOG_DEBUG(Service_AM, "(STUBBED) called idle_time_detection_extension={}",
+ idle_time_detection_extension);
IPC::ResponseBuilder rb{ctx, 2};
rb.Push(ResultSuccess);
diff --git a/src/core/hle/service/glue/ectx.cpp b/src/core/hle/service/glue/ectx.cpp
index 1bd9314ae..6f71b62f3 100644
--- a/src/core/hle/service/glue/ectx.cpp
+++ b/src/core/hle/service/glue/ectx.cpp
@@ -2,13 +2,48 @@
// SPDX-License-Identifier: GPL-2.0-or-later
#include "core/hle/service/glue/ectx.h"
+#include "core/hle/service/ipc_helpers.h"
namespace Service::Glue {
+// This is nn::err::context::IContextRegistrar
+class IContextRegistrar : public ServiceFramework<IContextRegistrar> {
+public:
+ IContextRegistrar(Core::System& system_) : ServiceFramework{system_, "IContextRegistrar"} {
+ // clang-format off
+ static const FunctionInfo functions[] = {
+ {0, &IContextRegistrar::Complete, "Complete"},
+ };
+ // clang-format on
+
+ RegisterHandlers(functions);
+ }
+
+ ~IContextRegistrar() override = default;
+
+private:
+ void Complete(HLERequestContext& ctx) {
+ struct InputParameters {
+ u32 unk;
+ };
+ struct OutputParameters {
+ u32 unk;
+ };
+
+ IPC::RequestParser rp{ctx};
+ [[maybe_unused]] auto input = rp.PopRaw<InputParameters>();
+ [[maybe_unused]] auto value = ctx.ReadBuffer();
+
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(ResultSuccess);
+ rb.Push(0);
+ }
+};
+
ECTX_AW::ECTX_AW(Core::System& system_) : ServiceFramework{system_, "ectx:aw"} {
// clang-format off
static const FunctionInfo functions[] = {
- {0, nullptr, "CreateContextRegistrar"},
+ {0, &ECTX_AW::CreateContextRegistrar, "CreateContextRegistrar"},
{1, nullptr, "CommitContext"},
};
// clang-format on
@@ -18,4 +53,10 @@ ECTX_AW::ECTX_AW(Core::System& system_) : ServiceFramework{system_, "ectx:aw"} {
ECTX_AW::~ECTX_AW() = default;
+void ECTX_AW::CreateContextRegistrar(HLERequestContext& ctx) {
+ IPC::ResponseBuilder rb{ctx, 2, 0, 1};
+ rb.Push(ResultSuccess);
+ rb.PushIpcInterface<IContextRegistrar>(std::make_shared<IContextRegistrar>(system));
+}
+
} // namespace Service::Glue
diff --git a/src/core/hle/service/glue/ectx.h b/src/core/hle/service/glue/ectx.h
index a608de053..ffa74d8d3 100644
--- a/src/core/hle/service/glue/ectx.h
+++ b/src/core/hle/service/glue/ectx.h
@@ -15,6 +15,9 @@ class ECTX_AW final : public ServiceFramework<ECTX_AW> {
public:
explicit ECTX_AW(Core::System& system_);
~ECTX_AW() override;
+
+private:
+ void CreateContextRegistrar(HLERequestContext& ctx);
};
} // namespace Service::Glue
diff --git a/src/core/hle/service/ldr/ldr.cpp b/src/core/hle/service/ldr/ldr.cpp
index c42489ff9..055c0a2db 100644
--- a/src/core/hle/service/ldr/ldr.cpp
+++ b/src/core/hle/service/ldr/ldr.cpp
@@ -318,15 +318,15 @@ public:
return false;
}
- if (!page_table.IsInsideAddressSpace(out_addr, size)) {
+ if (!page_table.Contains(out_addr, size)) {
return false;
}
- if (page_table.IsInsideHeapRegion(out_addr, size)) {
+ if (page_table.IsInHeapRegion(out_addr, size)) {
return false;
}
- if (page_table.IsInsideAliasRegion(out_addr, size)) {
+ if (page_table.IsInAliasRegion(out_addr, size)) {
return false;
}
@@ -358,7 +358,7 @@ public:
}
ResultVal<VAddr> MapProcessCodeMemory(Kernel::KProcess* process, VAddr base_addr, u64 size) {
- auto& page_table{process->PageTable()};
+ auto& page_table{process->GetPageTable()};
VAddr addr{};
for (std::size_t retry = 0; retry < MAXIMUM_MAP_RETRIES; retry++) {
@@ -382,7 +382,7 @@ public:
ResultVal<VAddr> MapNro(Kernel::KProcess* process, VAddr nro_addr, std::size_t nro_size,
VAddr bss_addr, std::size_t bss_size, std::size_t size) {
for (std::size_t retry = 0; retry < MAXIMUM_MAP_RETRIES; retry++) {
- auto& page_table{process->PageTable()};
+ auto& page_table{process->GetPageTable()};
VAddr addr{};
CASCADE_RESULT(addr, MapProcessCodeMemory(process, nro_addr, nro_size));
@@ -437,12 +437,12 @@ public:
CopyCode(nro_addr + nro_header.segment_headers[DATA_INDEX].memory_offset, data_start,
nro_header.segment_headers[DATA_INDEX].memory_size);
- CASCADE_CODE(process->PageTable().SetProcessMemoryPermission(
+ CASCADE_CODE(process->GetPageTable().SetProcessMemoryPermission(
text_start, ro_start - text_start, Kernel::Svc::MemoryPermission::ReadExecute));
- CASCADE_CODE(process->PageTable().SetProcessMemoryPermission(
+ CASCADE_CODE(process->GetPageTable().SetProcessMemoryPermission(
ro_start, data_start - ro_start, Kernel::Svc::MemoryPermission::Read));
- return process->PageTable().SetProcessMemoryPermission(
+ return process->GetPageTable().SetProcessMemoryPermission(
data_start, bss_end_addr - data_start, Kernel::Svc::MemoryPermission::ReadWrite);
}
@@ -571,7 +571,7 @@ public:
Result UnmapNro(const NROInfo& info) {
// Each region must be unmapped separately to validate memory state
- auto& page_table{system.ApplicationProcess()->PageTable()};
+ auto& page_table{system.ApplicationProcess()->GetPageTable()};
if (info.bss_size != 0) {
CASCADE_CODE(page_table.UnmapCodeMemory(
@@ -643,7 +643,7 @@ public:
initialized = true;
current_map_addr =
- GetInteger(system.ApplicationProcess()->PageTable().GetAliasCodeRegionStart());
+ GetInteger(system.ApplicationProcess()->GetPageTable().GetAliasCodeRegionStart());
IPC::ResponseBuilder rb{ctx, 2};
rb.Push(ResultSuccess);
diff --git a/src/core/hle/service/nfc/common/amiibo_crypto.cpp b/src/core/hle/service/nfc/common/amiibo_crypto.cpp
index bc232c334..9556e9193 100644
--- a/src/core/hle/service/nfc/common/amiibo_crypto.cpp
+++ b/src/core/hle/service/nfc/common/amiibo_crypto.cpp
@@ -180,7 +180,7 @@ std::vector<u8> GenerateInternalKey(const InternalKey& key, const HashSeed& seed
}
void CryptoInit(CryptoCtx& ctx, mbedtls_md_context_t& hmac_ctx, const HmacKey& hmac_key,
- const std::vector<u8>& seed) {
+ std::span<const u8> seed) {
// Initialize context
ctx.used = false;
ctx.counter = 0;
diff --git a/src/core/hle/service/nfc/common/amiibo_crypto.h b/src/core/hle/service/nfc/common/amiibo_crypto.h
index 6a3e0841e..2cc0e4d51 100644
--- a/src/core/hle/service/nfc/common/amiibo_crypto.h
+++ b/src/core/hle/service/nfc/common/amiibo_crypto.h
@@ -75,7 +75,7 @@ std::vector<u8> GenerateInternalKey(const InternalKey& key, const HashSeed& seed
// Initializes mbedtls context
void CryptoInit(CryptoCtx& ctx, mbedtls_md_context_t& hmac_ctx, const HmacKey& hmac_key,
- const std::vector<u8>& seed);
+ std::span<const u8> seed);
// Feeds data to mbedtls context to generate the derived key
void CryptoStep(CryptoCtx& ctx, mbedtls_md_context_t& hmac_ctx, DrgbOutput& output);
diff --git a/src/core/hle/service/nfc/common/device.cpp b/src/core/hle/service/nfc/common/device.cpp
index 2d633b03f..49446bc42 100644
--- a/src/core/hle/service/nfc/common/device.cpp
+++ b/src/core/hle/service/nfc/common/device.cpp
@@ -34,8 +34,6 @@
#include "core/hle/service/nfc/mifare_result.h"
#include "core/hle/service/nfc/nfc_result.h"
#include "core/hle/service/time/time_manager.h"
-#include "core/hle/service/time/time_zone_content_manager.h"
-#include "core/hle/service/time/time_zone_types.h"
namespace Service::NFC {
NfcDevice::NfcDevice(Core::HID::NpadIdType npad_id_, Core::System& system_,
@@ -1486,6 +1484,7 @@ DeviceState NfcDevice::GetCurrentState() const {
}
Result NfcDevice::GetNpadId(Core::HID::NpadIdType& out_npad_id) const {
+ // TODO: This should get the npad id from nn::hid::system::GetXcdHandleForNpadWithNfc
out_npad_id = npad_id;
return ResultSuccess;
}
diff --git a/src/core/hle/service/nfc/common/device_manager.cpp b/src/core/hle/service/nfc/common/device_manager.cpp
index 562f3a28e..a71d26157 100644
--- a/src/core/hle/service/nfc/common/device_manager.cpp
+++ b/src/core/hle/service/nfc/common/device_manager.cpp
@@ -1,6 +1,8 @@
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
// SPDX-License-Identifier: GPL-3.0-or-later
+#include <algorithm>
+
#include "common/logging/log.h"
#include "core/core.h"
#include "core/hid/hid_types.h"
@@ -10,6 +12,7 @@
#include "core/hle/service/nfc/common/device_manager.h"
#include "core/hle/service/nfc/nfc_result.h"
#include "core/hle/service/time/clock_types.h"
+#include "core/hle/service/time/time_manager.h"
namespace Service::NFC {
@@ -51,22 +54,53 @@ Result DeviceManager::Finalize() {
return ResultSuccess;
}
-Result DeviceManager::ListDevices(std::vector<u64>& nfp_devices,
- std::size_t max_allowed_devices) const {
+Result DeviceManager::ListDevices(std::vector<u64>& nfp_devices, std::size_t max_allowed_devices,
+ bool skip_fatal_errors) const {
+ std::scoped_lock lock{mutex};
+ if (max_allowed_devices < 1) {
+ return ResultInvalidArgument;
+ }
+
+ Result result = IsNfcParameterSet();
+ if (result.IsError()) {
+ return result;
+ }
+
+ result = IsNfcEnabled();
+ if (result.IsError()) {
+ return result;
+ }
+
+ result = IsNfcInitialized();
+ if (result.IsError()) {
+ return result;
+ }
+
for (auto& device : devices) {
if (nfp_devices.size() >= max_allowed_devices) {
continue;
}
- if (device->GetCurrentState() != DeviceState::Unavailable) {
- nfp_devices.push_back(device->GetHandle());
+ if (skip_fatal_errors) {
+ constexpr u64 MinimumRecoveryTime = 60;
+ auto& standard_steady_clock{system.GetTimeManager().GetStandardSteadyClockCore()};
+ const u64 elapsed_time = standard_steady_clock.GetCurrentTimePoint(system).time_point -
+ time_since_last_error;
+
+ if (time_since_last_error != 0 && elapsed_time < MinimumRecoveryTime) {
+ continue;
+ }
}
+ if (device->GetCurrentState() == DeviceState::Unavailable) {
+ continue;
+ }
+ nfp_devices.push_back(device->GetHandle());
}
if (nfp_devices.empty()) {
return ResultDeviceNotFound;
}
- return ResultSuccess;
+ return result;
}
DeviceState DeviceManager::GetDeviceState(u64 device_handle) const {
@@ -79,10 +113,10 @@ DeviceState DeviceManager::GetDeviceState(u64 device_handle) const {
return device->GetCurrentState();
}
- return DeviceState::Unavailable;
+ return DeviceState::Finalized;
}
-Result DeviceManager::GetNpadId(u64 device_handle, Core::HID::NpadIdType& npad_id) const {
+Result DeviceManager::GetNpadId(u64 device_handle, Core::HID::NpadIdType& npad_id) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -128,7 +162,7 @@ Result DeviceManager::StopDetection(u64 device_handle) {
return result;
}
-Result DeviceManager::GetTagInfo(u64 device_handle, TagInfo& tag_info) const {
+Result DeviceManager::GetTagInfo(u64 device_handle, TagInfo& tag_info) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -142,24 +176,46 @@ Result DeviceManager::GetTagInfo(u64 device_handle, TagInfo& tag_info) const {
return result;
}
-Kernel::KReadableEvent& DeviceManager::AttachActivateEvent(u64 device_handle) const {
- std::scoped_lock lock{mutex};
-
+Result DeviceManager::AttachActivateEvent(Kernel::KReadableEvent** out_event,
+ u64 device_handle) const {
+ std::vector<u64> nfp_devices;
std::shared_ptr<NfcDevice> device = nullptr;
- GetDeviceFromHandle(device_handle, device, false);
+ Result result = ListDevices(nfp_devices, 9, false);
- // TODO: Return proper error code on failure
- return device->GetActivateEvent();
-}
+ if (result.IsSuccess()) {
+ result = CheckHandleOnList(device_handle, nfp_devices);
+ }
-Kernel::KReadableEvent& DeviceManager::AttachDeactivateEvent(u64 device_handle) const {
- std::scoped_lock lock{mutex};
+ if (result.IsSuccess()) {
+ result = GetDeviceFromHandle(device_handle, device, false);
+ }
+
+ if (result.IsSuccess()) {
+ *out_event = &device->GetActivateEvent();
+ }
+
+ return result;
+}
+Result DeviceManager::AttachDeactivateEvent(Kernel::KReadableEvent** out_event,
+ u64 device_handle) const {
+ std::vector<u64> nfp_devices;
std::shared_ptr<NfcDevice> device = nullptr;
- GetDeviceFromHandle(device_handle, device, false);
+ Result result = ListDevices(nfp_devices, 9, false);
- // TODO: Return proper error code on failure
- return device->GetDeactivateEvent();
+ if (result.IsSuccess()) {
+ result = CheckHandleOnList(device_handle, nfp_devices);
+ }
+
+ if (result.IsSuccess()) {
+ result = GetDeviceFromHandle(device_handle, device, false);
+ }
+
+ if (result.IsSuccess()) {
+ *out_event = &device->GetDeactivateEvent();
+ }
+
+ return result;
}
Result DeviceManager::ReadMifare(u64 device_handle,
@@ -253,7 +309,7 @@ Result DeviceManager::OpenApplicationArea(u64 device_handle, u32 access_id) {
return result;
}
-Result DeviceManager::GetApplicationArea(u64 device_handle, std::span<u8> data) const {
+Result DeviceManager::GetApplicationArea(u64 device_handle, std::span<u8> data) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -324,7 +380,7 @@ Result DeviceManager::CreateApplicationArea(u64 device_handle, u32 access_id,
return result;
}
-Result DeviceManager::GetRegisterInfo(u64 device_handle, NFP::RegisterInfo& register_info) const {
+Result DeviceManager::GetRegisterInfo(u64 device_handle, NFP::RegisterInfo& register_info) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -338,7 +394,7 @@ Result DeviceManager::GetRegisterInfo(u64 device_handle, NFP::RegisterInfo& regi
return result;
}
-Result DeviceManager::GetCommonInfo(u64 device_handle, NFP::CommonInfo& common_info) const {
+Result DeviceManager::GetCommonInfo(u64 device_handle, NFP::CommonInfo& common_info) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -352,7 +408,7 @@ Result DeviceManager::GetCommonInfo(u64 device_handle, NFP::CommonInfo& common_i
return result;
}
-Result DeviceManager::GetModelInfo(u64 device_handle, NFP::ModelInfo& model_info) const {
+Result DeviceManager::GetModelInfo(u64 device_handle, NFP::ModelInfo& model_info) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -399,7 +455,7 @@ Result DeviceManager::Format(u64 device_handle) {
return result;
}
-Result DeviceManager::GetAdminInfo(u64 device_handle, NFP::AdminInfo& admin_info) const {
+Result DeviceManager::GetAdminInfo(u64 device_handle, NFP::AdminInfo& admin_info) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -414,7 +470,7 @@ Result DeviceManager::GetAdminInfo(u64 device_handle, NFP::AdminInfo& admin_info
}
Result DeviceManager::GetRegisterInfoPrivate(u64 device_handle,
- NFP::RegisterInfoPrivate& register_info) const {
+ NFP::RegisterInfoPrivate& register_info) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -471,7 +527,7 @@ Result DeviceManager::DeleteApplicationArea(u64 device_handle) {
return result;
}
-Result DeviceManager::ExistsApplicationArea(u64 device_handle, bool& has_application_area) const {
+Result DeviceManager::ExistsApplicationArea(u64 device_handle, bool& has_application_area) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -485,7 +541,7 @@ Result DeviceManager::ExistsApplicationArea(u64 device_handle, bool& has_applica
return result;
}
-Result DeviceManager::GetAll(u64 device_handle, NFP::NfpData& nfp_data) const {
+Result DeviceManager::GetAll(u64 device_handle, NFP::NfpData& nfp_data) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -541,7 +597,7 @@ Result DeviceManager::BreakTag(u64 device_handle, NFP::BreakType break_type) {
return result;
}
-Result DeviceManager::ReadBackupData(u64 device_handle, std::span<u8> data) const {
+Result DeviceManager::ReadBackupData(u64 device_handle, std::span<u8> data) {
std::scoped_lock lock{mutex};
std::shared_ptr<NfcDevice> device = nullptr;
@@ -593,6 +649,19 @@ Result DeviceManager::WriteNtf(u64 device_handle, NFP::WriteType, std::span<cons
return result;
}
+Result DeviceManager::CheckHandleOnList(u64 device_handle,
+ const std::span<const u64> device_list) const {
+ if (device_list.size() < 1) {
+ return ResultDeviceNotFound;
+ }
+
+ if (std::find(device_list.begin(), device_list.end(), device_handle) != device_list.end()) {
+ return ResultSuccess;
+ }
+
+ return ResultDeviceNotFound;
+}
+
Result DeviceManager::GetDeviceFromHandle(u64 handle, std::shared_ptr<NfcDevice>& nfc_device,
bool check_state) const {
if (check_state) {
@@ -647,7 +716,7 @@ Result DeviceManager::GetDeviceHandle(u64 handle, std::shared_ptr<NfcDevice>& de
}
Result DeviceManager::VerifyDeviceResult(std::shared_ptr<NfcDevice> device,
- Result operation_result) const {
+ Result operation_result) {
if (operation_result.IsSuccess()) {
return operation_result;
}
@@ -669,6 +738,12 @@ Result DeviceManager::VerifyDeviceResult(std::shared_ptr<NfcDevice> device,
return device_state;
}
+ if (operation_result == ResultUnknown112 || operation_result == ResultUnknown114 ||
+ operation_result == ResultUnknown115) {
+ auto& standard_steady_clock{system.GetTimeManager().GetStandardSteadyClockCore()};
+ time_since_last_error = standard_steady_clock.GetCurrentTimePoint(system).time_point;
+ }
+
return operation_result;
}
diff --git a/src/core/hle/service/nfc/common/device_manager.h b/src/core/hle/service/nfc/common/device_manager.h
index c61ba0cf3..c9f038e32 100644
--- a/src/core/hle/service/nfc/common/device_manager.h
+++ b/src/core/hle/service/nfc/common/device_manager.h
@@ -27,15 +27,16 @@ public:
// Nfc device manager
Result Initialize();
Result Finalize();
- Result ListDevices(std::vector<u64>& nfp_devices, std::size_t max_allowed_devices) const;
+ Result ListDevices(std::vector<u64>& nfp_devices, std::size_t max_allowed_devices,
+ bool skip_fatal_errors) const;
DeviceState GetDeviceState(u64 device_handle) const;
- Result GetNpadId(u64 device_handle, Core::HID::NpadIdType& npad_id) const;
+ Result GetNpadId(u64 device_handle, Core::HID::NpadIdType& npad_id);
Kernel::KReadableEvent& AttachAvailabilityChangeEvent() const;
Result StartDetection(u64 device_handle, NfcProtocol tag_protocol);
Result StopDetection(u64 device_handle);
- Result GetTagInfo(u64 device_handle, NFP::TagInfo& tag_info) const;
- Kernel::KReadableEvent& AttachActivateEvent(u64 device_handle) const;
- Kernel::KReadableEvent& AttachDeactivateEvent(u64 device_handle) const;
+ Result GetTagInfo(u64 device_handle, NFP::TagInfo& tag_info);
+ Result AttachActivateEvent(Kernel::KReadableEvent** event, u64 device_handle) const;
+ Result AttachDeactivateEvent(Kernel::KReadableEvent** event, u64 device_handle) const;
Result ReadMifare(u64 device_handle,
const std::span<const MifareReadBlockParameter> read_parameters,
std::span<MifareReadBlockData> read_data);
@@ -48,28 +49,28 @@ public:
Result Mount(u64 device_handle, NFP::ModelType model_type, NFP::MountTarget mount_target);
Result Unmount(u64 device_handle);
Result OpenApplicationArea(u64 device_handle, u32 access_id);
- Result GetApplicationArea(u64 device_handle, std::span<u8> data) const;
+ Result GetApplicationArea(u64 device_handle, std::span<u8> data);
Result SetApplicationArea(u64 device_handle, std::span<const u8> data);
Result Flush(u64 device_handle);
Result Restore(u64 device_handle);
Result CreateApplicationArea(u64 device_handle, u32 access_id, std::span<const u8> data);
- Result GetRegisterInfo(u64 device_handle, NFP::RegisterInfo& register_info) const;
- Result GetCommonInfo(u64 device_handle, NFP::CommonInfo& common_info) const;
- Result GetModelInfo(u64 device_handle, NFP::ModelInfo& model_info) const;
+ Result GetRegisterInfo(u64 device_handle, NFP::RegisterInfo& register_info);
+ Result GetCommonInfo(u64 device_handle, NFP::CommonInfo& common_info);
+ Result GetModelInfo(u64 device_handle, NFP::ModelInfo& model_info);
u32 GetApplicationAreaSize() const;
Result RecreateApplicationArea(u64 device_handle, u32 access_id, std::span<const u8> data);
Result Format(u64 device_handle);
- Result GetAdminInfo(u64 device_handle, NFP::AdminInfo& admin_info) const;
- Result GetRegisterInfoPrivate(u64 device_handle, NFP::RegisterInfoPrivate& register_info) const;
+ Result GetAdminInfo(u64 device_handle, NFP::AdminInfo& admin_info);
+ Result GetRegisterInfoPrivate(u64 device_handle, NFP::RegisterInfoPrivate& register_info);
Result SetRegisterInfoPrivate(u64 device_handle, const NFP::RegisterInfoPrivate& register_info);
Result DeleteRegisterInfo(u64 device_handle);
Result DeleteApplicationArea(u64 device_handle);
- Result ExistsApplicationArea(u64 device_handle, bool& has_application_area) const;
- Result GetAll(u64 device_handle, NFP::NfpData& nfp_data) const;
+ Result ExistsApplicationArea(u64 device_handle, bool& has_application_area);
+ Result GetAll(u64 device_handle, NFP::NfpData& nfp_data);
Result SetAll(u64 device_handle, const NFP::NfpData& nfp_data);
Result FlushDebug(u64 device_handle);
Result BreakTag(u64 device_handle, NFP::BreakType break_type);
- Result ReadBackupData(u64 device_handle, std::span<u8> data) const;
+ Result ReadBackupData(u64 device_handle, std::span<u8> data);
Result WriteBackupData(u64 device_handle, std::span<const u8> data);
Result WriteNtf(u64 device_handle, NFP::WriteType, std::span<const u8> data);
@@ -78,17 +79,20 @@ private:
Result IsNfcParameterSet() const;
Result IsNfcInitialized() const;
+ Result CheckHandleOnList(u64 device_handle, std::span<const u64> device_list) const;
+
Result GetDeviceFromHandle(u64 handle, std::shared_ptr<NfcDevice>& device,
bool check_state) const;
Result GetDeviceHandle(u64 handle, std::shared_ptr<NfcDevice>& device) const;
- Result VerifyDeviceResult(std::shared_ptr<NfcDevice> device, Result operation_result) const;
+ Result VerifyDeviceResult(std::shared_ptr<NfcDevice> device, Result operation_result);
Result CheckDeviceState(std::shared_ptr<NfcDevice> device) const;
std::optional<std::shared_ptr<NfcDevice>> GetNfcDevice(u64 handle);
const std::optional<std::shared_ptr<NfcDevice>> GetNfcDevice(u64 handle) const;
bool is_initialized = false;
+ u64 time_since_last_error = 0;
mutable std::mutex mutex;
std::array<std::shared_ptr<NfcDevice>, 10> devices{};
diff --git a/src/core/hle/service/nfc/nfc_interface.cpp b/src/core/hle/service/nfc/nfc_interface.cpp
index e7ca7582e..179c7ba2c 100644
--- a/src/core/hle/service/nfc/nfc_interface.cpp
+++ b/src/core/hle/service/nfc/nfc_interface.cpp
@@ -79,7 +79,7 @@ void NfcInterface::ListDevices(HLERequestContext& ctx) {
const std::size_t max_allowed_devices = ctx.GetWriteBufferNumElements<u64>();
LOG_DEBUG(Service_NFC, "called");
- auto result = GetManager()->ListDevices(nfp_devices, max_allowed_devices);
+ auto result = GetManager()->ListDevices(nfp_devices, max_allowed_devices, true);
result = TranslateResultToServiceError(result);
if (result.IsError()) {
@@ -190,9 +190,13 @@ void NfcInterface::AttachActivateEvent(HLERequestContext& ctx) {
const auto device_handle{rp.Pop<u64>()};
LOG_DEBUG(Service_NFC, "called, device_handle={}", device_handle);
+ Kernel::KReadableEvent* out_event = nullptr;
+ auto result = GetManager()->AttachActivateEvent(&out_event, device_handle);
+ result = TranslateResultToServiceError(result);
+
IPC::ResponseBuilder rb{ctx, 2, 1};
- rb.Push(ResultSuccess);
- rb.PushCopyObjects(GetManager()->AttachActivateEvent(device_handle));
+ rb.Push(result);
+ rb.PushCopyObjects(out_event);
}
void NfcInterface::AttachDeactivateEvent(HLERequestContext& ctx) {
@@ -200,9 +204,13 @@ void NfcInterface::AttachDeactivateEvent(HLERequestContext& ctx) {
const auto device_handle{rp.Pop<u64>()};
LOG_DEBUG(Service_NFC, "called, device_handle={}", device_handle);
+ Kernel::KReadableEvent* out_event = nullptr;
+ auto result = GetManager()->AttachDeactivateEvent(&out_event, device_handle);
+ result = TranslateResultToServiceError(result);
+
IPC::ResponseBuilder rb{ctx, 2, 1};
- rb.Push(ResultSuccess);
- rb.PushCopyObjects(GetManager()->AttachDeactivateEvent(device_handle));
+ rb.Push(result);
+ rb.PushCopyObjects(out_event);
}
void NfcInterface::ReadMifare(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/nfc/nfc_result.h b/src/core/hle/service/nfc/nfc_result.h
index 715c0e80c..464b5fd69 100644
--- a/src/core/hle/service/nfc/nfc_result.h
+++ b/src/core/hle/service/nfc/nfc_result.h
@@ -17,7 +17,10 @@ constexpr Result ResultNfcNotInitialized(ErrorModule::NFC, 77);
constexpr Result ResultNfcDisabled(ErrorModule::NFC, 80);
constexpr Result ResultWriteAmiiboFailed(ErrorModule::NFC, 88);
constexpr Result ResultTagRemoved(ErrorModule::NFC, 97);
+constexpr Result ResultUnknown112(ErrorModule::NFC, 112);
constexpr Result ResultUnableToAccessBackupFile(ErrorModule::NFC, 113);
+constexpr Result ResultUnknown114(ErrorModule::NFC, 114);
+constexpr Result ResultUnknown115(ErrorModule::NFC, 115);
constexpr Result ResultRegistrationIsNotInitialized(ErrorModule::NFC, 120);
constexpr Result ResultApplicationAreaIsNotInitialized(ErrorModule::NFC, 128);
constexpr Result ResultCorruptedDataWithBackup(ErrorModule::NFC, 136);
diff --git a/src/core/hle/service/nifm/nifm.cpp b/src/core/hle/service/nifm/nifm.cpp
index 91d42853e..21b06d10b 100644
--- a/src/core/hle/service/nifm/nifm.cpp
+++ b/src/core/hle/service/nifm/nifm.cpp
@@ -7,6 +7,7 @@
#include "core/hle/service/kernel_helpers.h"
#include "core/hle/service/nifm/nifm.h"
#include "core/hle/service/server_manager.h"
+#include "network/network.h"
namespace {
diff --git a/src/core/hle/service/nifm/nifm.h b/src/core/hle/service/nifm/nifm.h
index 9b20e6823..ae99c4695 100644
--- a/src/core/hle/service/nifm/nifm.h
+++ b/src/core/hle/service/nifm/nifm.h
@@ -4,14 +4,15 @@
#pragma once
#include "core/hle/service/service.h"
-#include "network/network.h"
-#include "network/room.h"
-#include "network/room_member.h"
namespace Core {
class System;
}
+namespace Network {
+class RoomNetwork;
+}
+
namespace Service::NIFM {
void LoopProcess(Core::System& system);
diff --git a/src/core/hle/service/nvdrv/devices/nvmap.cpp b/src/core/hle/service/nvdrv/devices/nvmap.cpp
index e7f7e273b..968eaa175 100644
--- a/src/core/hle/service/nvdrv/devices/nvmap.cpp
+++ b/src/core/hle/service/nvdrv/devices/nvmap.cpp
@@ -128,7 +128,7 @@ NvResult nvmap::IocAlloc(std::span<const u8> input, std::span<u8> output) {
}
bool is_out_io{};
ASSERT(system.ApplicationProcess()
- ->PageTable()
+ ->GetPageTable()
.LockForMapDeviceAddressSpace(&is_out_io, handle_description->address,
handle_description->size,
Kernel::KMemoryPermission::None, true, false)
@@ -255,7 +255,7 @@ NvResult nvmap::IocFree(std::span<const u8> input, std::span<u8> output) {
if (auto freeInfo{file.FreeHandle(params.handle, false)}) {
if (freeInfo->can_unlock) {
ASSERT(system.ApplicationProcess()
- ->PageTable()
+ ->GetPageTable()
.UnlockForDeviceAddressSpace(freeInfo->address, freeInfo->size)
.IsSuccess());
}
diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp
index bce45d321..11f8efbac 100644
--- a/src/core/hle/service/sockets/bsd.cpp
+++ b/src/core/hle/service/sockets/bsd.cpp
@@ -20,6 +20,9 @@
#include "core/internal_network/sockets.h"
#include "network/network.h"
+using Common::Expected;
+using Common::Unexpected;
+
namespace Service::Sockets {
namespace {
@@ -265,16 +268,19 @@ void BSD::GetSockOpt(HLERequestContext& ctx) {
const u32 level = rp.Pop<u32>();
const auto optname = static_cast<OptName>(rp.Pop<u32>());
- LOG_WARNING(Service, "(STUBBED) called. fd={} level={} optname=0x{:x}", fd, level, optname);
-
std::vector<u8> optval(ctx.GetWriteBufferSize());
+ LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} len=0x{:x}", fd, level, optname,
+ optval.size());
+
+ const Errno err = GetSockOptImpl(fd, level, optname, optval);
+
ctx.WriteBuffer(optval);
IPC::ResponseBuilder rb{ctx, 5};
rb.Push(ResultSuccess);
- rb.Push<s32>(-1);
- rb.PushEnum(Errno::NOTCONN);
+ rb.Push<s32>(err == Errno::SUCCESS ? 0 : -1);
+ rb.PushEnum(err);
rb.Push<u32>(static_cast<u32>(optval.size()));
}
@@ -436,6 +442,31 @@ void BSD::Close(HLERequestContext& ctx) {
BuildErrnoResponse(ctx, CloseImpl(fd));
}
+void BSD::DuplicateSocket(HLERequestContext& ctx) {
+ struct InputParameters {
+ s32 fd;
+ u64 reserved;
+ };
+ static_assert(sizeof(InputParameters) == 0x10);
+
+ struct OutputParameters {
+ s32 ret;
+ Errno bsd_errno;
+ };
+ static_assert(sizeof(OutputParameters) == 0x8);
+
+ IPC::RequestParser rp{ctx};
+ auto input = rp.PopRaw<InputParameters>();
+
+ Expected<s32, Errno> res = DuplicateSocketImpl(input.fd);
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .ret = res.value_or(0),
+ .bsd_errno = res ? Errno::SUCCESS : res.error(),
+ });
+}
+
void BSD::EventFd(HLERequestContext& ctx) {
IPC::RequestParser rp{ctx};
const u64 initval = rp.Pop<u64>();
@@ -477,12 +508,12 @@ std::pair<s32, Errno> BSD::SocketImpl(Domain domain, Type type, Protocol protoco
auto room_member = room_network.GetRoomMember().lock();
if (room_member && room_member->IsConnected()) {
- descriptor.socket = std::make_unique<Network::ProxySocket>(room_network);
+ descriptor.socket = std::make_shared<Network::ProxySocket>(room_network);
} else {
- descriptor.socket = std::make_unique<Network::Socket>();
+ descriptor.socket = std::make_shared<Network::Socket>();
}
- descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(type, protocol));
+ descriptor.socket->Initialize(Translate(domain), Translate(type), Translate(protocol));
descriptor.is_connection_based = IsConnectionBased(type);
return {fd, Errno::SUCCESS};
@@ -528,7 +559,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
const std::optional<FileDescriptor>& descriptor = file_descriptors[pollfd.fd];
if (!descriptor) {
- LOG_ERROR(Service, "File descriptor handle={} is not allocated", pollfd.fd);
+ LOG_TRACE(Service, "File descriptor handle={} is not allocated", pollfd.fd);
pollfd.revents = PollEvents::Nval;
return {0, Errno::SUCCESS};
}
@@ -538,7 +569,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
std::transform(fds.begin(), fds.end(), host_pollfds.begin(), [this](PollFD pollfd) {
Network::PollFD result;
result.socket = file_descriptors[pollfd.fd]->socket.get();
- result.events = TranslatePollEventsToHost(pollfd.events);
+ result.events = Translate(pollfd.events);
result.revents = Network::PollEvents{};
return result;
});
@@ -547,7 +578,7 @@ std::pair<s32, Errno> BSD::PollImpl(std::vector<u8>& write_buffer, std::span<con
const size_t num = host_pollfds.size();
for (size_t i = 0; i < num; ++i) {
- fds[i].revents = TranslatePollEventsToGuest(host_pollfds[i].revents);
+ fds[i].revents = Translate(host_pollfds[i].revents);
}
std::memcpy(write_buffer.data(), fds.data(), length);
@@ -617,7 +648,8 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector<u8>& write_buffer) {
}
const SockAddrIn guest_addrin = Translate(addr_in);
- ASSERT(write_buffer.size() == sizeof(guest_addrin));
+ ASSERT(write_buffer.size() >= sizeof(guest_addrin));
+ write_buffer.resize(sizeof(guest_addrin));
std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
return Translate(bsd_errno);
}
@@ -633,7 +665,8 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer) {
}
const SockAddrIn guest_addrin = Translate(addr_in);
- ASSERT(write_buffer.size() == sizeof(guest_addrin));
+ ASSERT(write_buffer.size() >= sizeof(guest_addrin));
+ write_buffer.resize(sizeof(guest_addrin));
std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin));
return Translate(bsd_errno);
}
@@ -671,13 +704,47 @@ std::pair<s32, Errno> BSD::FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg) {
}
}
-Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
- UNIMPLEMENTED_IF(level != 0xffff); // SOL_SOCKET
+Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval) {
+ if (!IsFileDescriptorValid(fd)) {
+ return Errno::BADF;
+ }
+
+ if (level != static_cast<u32>(SocketLevel::SOCKET)) {
+ UNIMPLEMENTED_MSG("Unknown getsockopt level");
+ return Errno::SUCCESS;
+ }
+
+ Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
+
+ switch (optname) {
+ case OptName::ERROR_: {
+ auto [pending_err, getsockopt_err] = socket->GetPendingError();
+ if (getsockopt_err == Network::Errno::SUCCESS) {
+ Errno translated_pending_err = Translate(pending_err);
+ ASSERT_OR_EXECUTE_MSG(
+ optval.size() == sizeof(Errno), { return Errno::INVAL; },
+ "Incorrect getsockopt option size");
+ optval.resize(sizeof(Errno));
+ memcpy(optval.data(), &translated_pending_err, sizeof(Errno));
+ }
+ return Translate(getsockopt_err);
+ }
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
+ return Errno::SUCCESS;
+ }
+}
+Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) {
if (!IsFileDescriptorValid(fd)) {
return Errno::BADF;
}
+ if (level != static_cast<u32>(SocketLevel::SOCKET)) {
+ UNIMPLEMENTED_MSG("Unknown setsockopt level");
+ return Errno::SUCCESS;
+ }
+
Network::SocketBase* const socket = file_descriptors[fd]->socket.get();
if (optname == OptName::LINGER) {
@@ -711,6 +778,9 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con
return Translate(socket->SetSndTimeo(value));
case OptName::RCVTIMEO:
return Translate(socket->SetRcvTimeo(value));
+ case OptName::NOSIGPIPE:
+ LOG_WARNING(Service, "(STUBBED) setting NOSIGPIPE to {}", value);
+ return Errno::SUCCESS;
default:
UNIMPLEMENTED_MSG("Unimplemented optname={}", optname);
return Errno::SUCCESS;
@@ -841,6 +911,28 @@ Errno BSD::CloseImpl(s32 fd) {
return bsd_errno;
}
+Expected<s32, Errno> BSD::DuplicateSocketImpl(s32 fd) {
+ if (!IsFileDescriptorValid(fd)) {
+ return Unexpected(Errno::BADF);
+ }
+
+ const s32 new_fd = FindFreeFileDescriptorHandle();
+ if (new_fd < 0) {
+ LOG_ERROR(Service, "No more file descriptors available");
+ return Unexpected(Errno::MFILE);
+ }
+
+ file_descriptors[new_fd] = file_descriptors[fd];
+ return new_fd;
+}
+
+std::optional<std::shared_ptr<Network::SocketBase>> BSD::GetSocket(s32 fd) {
+ if (!IsFileDescriptorValid(fd)) {
+ return std::nullopt;
+ }
+ return file_descriptors[fd]->socket;
+}
+
s32 BSD::FindFreeFileDescriptorHandle() noexcept {
for (s32 fd = 0; fd < static_cast<s32>(file_descriptors.size()); ++fd) {
if (!file_descriptors[fd]) {
@@ -911,7 +1003,7 @@ BSD::BSD(Core::System& system_, const char* name)
{24, &BSD::Write, "Write"},
{25, &BSD::Read, "Read"},
{26, &BSD::Close, "Close"},
- {27, nullptr, "DuplicateSocket"},
+ {27, &BSD::DuplicateSocket, "DuplicateSocket"},
{28, nullptr, "GetResourceStatistics"},
{29, nullptr, "RecvMMsg"},
{30, nullptr, "SendMMsg"},
diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h
index 30ae9c140..430edb97c 100644
--- a/src/core/hle/service/sockets/bsd.h
+++ b/src/core/hle/service/sockets/bsd.h
@@ -8,6 +8,7 @@
#include <vector>
#include "common/common_types.h"
+#include "common/expected.h"
#include "common/socket_types.h"
#include "core/hle/service/service.h"
#include "core/hle/service/sockets/sockets.h"
@@ -29,12 +30,19 @@ public:
explicit BSD(Core::System& system_, const char* name);
~BSD() override;
+ // These methods are called from SSL; the first two are also called from
+ // this class for the corresponding IPC methods.
+ // On the real device, the SSL service makes IPC calls to this service.
+ Common::Expected<s32, Errno> DuplicateSocketImpl(s32 fd);
+ Errno CloseImpl(s32 fd);
+ std::optional<std::shared_ptr<Network::SocketBase>> GetSocket(s32 fd);
+
private:
/// Maximum number of file descriptors
static constexpr size_t MAX_FD = 128;
struct FileDescriptor {
- std::unique_ptr<Network::SocketBase> socket;
+ std::shared_ptr<Network::SocketBase> socket;
s32 flags = 0;
bool is_connection_based = false;
};
@@ -138,6 +146,7 @@ private:
void Write(HLERequestContext& ctx);
void Read(HLERequestContext& ctx);
void Close(HLERequestContext& ctx);
+ void DuplicateSocket(HLERequestContext& ctx);
void EventFd(HLERequestContext& ctx);
template <typename Work>
@@ -153,6 +162,7 @@ private:
Errno GetSockNameImpl(s32 fd, std::vector<u8>& write_buffer);
Errno ListenImpl(s32 fd, s32 backlog);
std::pair<s32, Errno> FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg);
+ Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector<u8>& optval);
Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval);
Errno ShutdownImpl(s32 fd, s32 how);
std::pair<s32, Errno> RecvImpl(s32 fd, u32 flags, std::vector<u8>& message);
@@ -161,7 +171,6 @@ private:
std::pair<s32, Errno> SendImpl(s32 fd, u32 flags, std::span<const u8> message);
std::pair<s32, Errno> SendToImpl(s32 fd, u32 flags, std::span<const u8> message,
std::span<const u8> addr);
- Errno CloseImpl(s32 fd);
s32 FindFreeFileDescriptorHandle() noexcept;
bool IsFileDescriptorValid(s32 fd) const noexcept;
diff --git a/src/core/hle/service/sockets/nsd.cpp b/src/core/hle/service/sockets/nsd.cpp
index 6491a73be..5dfcaabb1 100644
--- a/src/core/hle/service/sockets/nsd.cpp
+++ b/src/core/hle/service/sockets/nsd.cpp
@@ -1,22 +1,36 @@
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
+#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/sockets/nsd.h"
+#include "common/string_util.h"
+
namespace Service::Sockets {
+constexpr Result ResultOverflow{ErrorModule::NSD, 6};
+
+// This is nn::oe::ServerEnvironmentType
+enum class ServerEnvironmentType : u8 {
+ Dd,
+ Lp,
+ Sd,
+ Sp,
+ Dp,
+};
+
NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, name} {
// clang-format off
static const FunctionInfo functions[] = {
{5, nullptr, "GetSettingUrl"},
{10, nullptr, "GetSettingName"},
- {11, nullptr, "GetEnvironmentIdentifier"},
+ {11, &NSD::GetEnvironmentIdentifier, "GetEnvironmentIdentifier"},
{12, nullptr, "GetDeviceId"},
{13, nullptr, "DeleteSettings"},
{14, nullptr, "ImportSettings"},
{15, nullptr, "SetChangeEnvironmentIdentifierDisabled"},
- {20, nullptr, "Resolve"},
- {21, nullptr, "ResolveEx"},
+ {20, &NSD::Resolve, "Resolve"},
+ {21, &NSD::ResolveEx, "ResolveEx"},
{30, nullptr, "GetNasServiceSetting"},
{31, nullptr, "GetNasServiceSettingEx"},
{40, nullptr, "GetNasRequestFqdn"},
@@ -31,7 +45,7 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
{62, nullptr, "DeleteSaveDataOfFsForTest"},
{63, nullptr, "IsChangeEnvironmentIdentifierDisabled"},
{64, nullptr, "SetWithoutDomainExchangeFqdns"},
- {100, nullptr, "GetApplicationServerEnvironmentType"},
+ {100, &NSD::GetApplicationServerEnvironmentType, "GetApplicationServerEnvironmentType"},
{101, nullptr, "SetApplicationServerEnvironmentType"},
{102, nullptr, "DeleteApplicationServerEnvironmentType"},
};
@@ -40,6 +54,69 @@ NSD::NSD(Core::System& system_, const char* name) : ServiceFramework{system_, na
RegisterHandlers(functions);
}
+static ResultVal<std::string> ResolveImpl(const std::string& fqdn_in) {
+ // The real implementation makes various substitutions.
+ // For now we just return the string as-is, which is good enough when not
+ // connecting to real Nintendo servers.
+ LOG_WARNING(Service, "(STUBBED) called, fqdn_in={}", fqdn_in);
+ return fqdn_in;
+}
+
+static Result ResolveCommon(const std::string& fqdn_in, std::array<char, 0x100>& fqdn_out) {
+ const auto res = ResolveImpl(fqdn_in);
+ if (res.Failed()) {
+ return res.Code();
+ }
+ if (res->size() >= fqdn_out.size()) {
+ return ResultOverflow;
+ }
+ std::memcpy(fqdn_out.data(), res->c_str(), res->size() + 1);
+ return ResultSuccess;
+}
+
+void NSD::Resolve(HLERequestContext& ctx) {
+ const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
+
+ std::array<char, 0x100> fqdn_out{};
+ const Result res = ResolveCommon(fqdn_in, fqdn_out);
+
+ ctx.WriteBuffer(fqdn_out);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+}
+
+void NSD::ResolveEx(HLERequestContext& ctx) {
+ const std::string fqdn_in = Common::StringFromBuffer(ctx.ReadBuffer(0));
+
+ std::array<char, 0x100> fqdn_out;
+ const Result res = ResolveCommon(fqdn_in, fqdn_out);
+
+ if (res.IsError()) {
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ return;
+ }
+
+ ctx.WriteBuffer(fqdn_out);
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(ResultSuccess);
+ rb.Push(ResultSuccess);
+}
+
+void NSD::GetEnvironmentIdentifier(HLERequestContext& ctx) {
+ const std::string environment_identifier = "lp1";
+ ctx.WriteBuffer(environment_identifier);
+
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(ResultSuccess);
+}
+
+void NSD::GetApplicationServerEnvironmentType(HLERequestContext& ctx) {
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(ResultSuccess);
+ rb.Push(static_cast<u32>(ServerEnvironmentType::Lp));
+}
+
NSD::~NSD() = default;
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/nsd.h b/src/core/hle/service/sockets/nsd.h
index 5cc12b855..b0cfec507 100644
--- a/src/core/hle/service/sockets/nsd.h
+++ b/src/core/hle/service/sockets/nsd.h
@@ -15,6 +15,12 @@ class NSD final : public ServiceFramework<NSD> {
public:
explicit NSD(Core::System& system_, const char* name);
~NSD() override;
+
+private:
+ void Resolve(HLERequestContext& ctx);
+ void ResolveEx(HLERequestContext& ctx);
+ void GetEnvironmentIdentifier(HLERequestContext& ctx);
+ void GetApplicationServerEnvironmentType(HLERequestContext& ctx);
};
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.cpp b/src/core/hle/service/sockets/sfdnsres.cpp
index 132dd5797..22e4a6f49 100644
--- a/src/core/hle/service/sockets/sfdnsres.cpp
+++ b/src/core/hle/service/sockets/sfdnsres.cpp
@@ -10,39 +10,30 @@
#include "core/core.h"
#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/sockets/sfdnsres.h"
+#include "core/hle/service/sockets/sockets.h"
+#include "core/hle/service/sockets/sockets_translate.h"
+#include "core/internal_network/network.h"
#include "core/memory.h"
-#ifdef _WIN32
-#include <ws2tcpip.h>
-#elif YUZU_UNIX
-#include <arpa/inet.h>
-#include <netdb.h>
-#include <netinet/in.h>
-#include <sys/socket.h>
-#ifndef EAI_NODATA
-#define EAI_NODATA EAI_NONAME
-#endif
-#endif
-
namespace Service::Sockets {
SFDNSRES::SFDNSRES(Core::System& system_) : ServiceFramework{system_, "sfdnsres"} {
static const FunctionInfo functions[] = {
{0, nullptr, "SetDnsAddressesPrivateRequest"},
{1, nullptr, "GetDnsAddressPrivateRequest"},
- {2, nullptr, "GetHostByNameRequest"},
+ {2, &SFDNSRES::GetHostByNameRequest, "GetHostByNameRequest"},
{3, nullptr, "GetHostByAddrRequest"},
{4, nullptr, "GetHostStringErrorRequest"},
- {5, nullptr, "GetGaiStringErrorRequest"},
+ {5, &SFDNSRES::GetGaiStringErrorRequest, "GetGaiStringErrorRequest"},
{6, &SFDNSRES::GetAddrInfoRequest, "GetAddrInfoRequest"},
{7, nullptr, "GetNameInfoRequest"},
{8, nullptr, "RequestCancelHandleRequest"},
{9, nullptr, "CancelRequest"},
- {10, nullptr, "GetHostByNameRequestWithOptions"},
+ {10, &SFDNSRES::GetHostByNameRequestWithOptions, "GetHostByNameRequestWithOptions"},
{11, nullptr, "GetHostByAddrRequestWithOptions"},
{12, &SFDNSRES::GetAddrInfoRequestWithOptions, "GetAddrInfoRequestWithOptions"},
{13, nullptr, "GetNameInfoRequestWithOptions"},
- {14, nullptr, "ResolverSetOptionRequest"},
+ {14, &SFDNSRES::ResolverSetOptionRequest, "ResolverSetOptionRequest"},
{15, nullptr, "ResolverGetOptionRequest"},
};
RegisterHandlers(functions);
@@ -59,188 +50,299 @@ enum class NetDbError : s32 {
NoData = 4,
};
-static NetDbError AddrInfoErrorToNetDbError(s32 result) {
- // Best effort guess to map errors
+static NetDbError GetAddrInfoErrorToNetDbError(GetAddrInfoError result) {
+ // These combinations have been verified on console (but are not
+ // exhaustive).
switch (result) {
- case 0:
+ case GetAddrInfoError::SUCCESS:
return NetDbError::Success;
- case EAI_AGAIN:
+ case GetAddrInfoError::AGAIN:
return NetDbError::TryAgain;
- case EAI_NODATA:
- return NetDbError::NoData;
+ case GetAddrInfoError::NODATA:
+ return NetDbError::HostNotFound;
+ case GetAddrInfoError::SERVICE:
+ return NetDbError::Success;
default:
return NetDbError::HostNotFound;
}
}
-static std::vector<u8> SerializeAddrInfo(const addrinfo* addrinfo, s32 result_code,
+static Errno GetAddrInfoErrorToErrno(GetAddrInfoError result) {
+ // These combinations have been verified on console (but are not
+ // exhaustive).
+ switch (result) {
+ case GetAddrInfoError::SUCCESS:
+ // Note: Sometimes a successful lookup sets errno to EADDRNOTAVAIL for
+ // some reason, but that doesn't seem useful to implement.
+ return Errno::SUCCESS;
+ case GetAddrInfoError::AGAIN:
+ return Errno::SUCCESS;
+ case GetAddrInfoError::NODATA:
+ return Errno::SUCCESS;
+ case GetAddrInfoError::SERVICE:
+ return Errno::INVAL;
+ default:
+ return Errno::SUCCESS;
+ }
+}
+
+template <typename T>
+static void Append(std::vector<u8>& vec, T t) {
+ const size_t offset = vec.size();
+ vec.resize(offset + sizeof(T));
+ std::memcpy(vec.data() + offset, &t, sizeof(T));
+}
+
+static void AppendNulTerminated(std::vector<u8>& vec, std::string_view str) {
+ const size_t offset = vec.size();
+ vec.resize(offset + str.size() + 1);
+ std::memmove(vec.data() + offset, str.data(), str.size());
+}
+
+// We implement gethostbyname using the host's getaddrinfo rather than the
+// host's gethostbyname, because it simplifies portability: e.g., getaddrinfo
+// behaves the same on Unix and Windows, unlike gethostbyname where Windows
+// doesn't implement h_errno.
+static std::vector<u8> SerializeAddrInfoAsHostEnt(const std::vector<Network::AddrInfo>& vec,
+ std::string_view host) {
+
+ std::vector<u8> data;
+ // h_name: use the input hostname (append nul-terminated)
+ AppendNulTerminated(data, host);
+ // h_aliases: leave empty
+
+ Append<u32_be>(data, 0); // count of h_aliases
+ // (If the count were nonzero, the aliases would be appended as nul-terminated here.)
+ Append<u16_be>(data, static_cast<u16>(Domain::INET)); // h_addrtype
+ Append<u16_be>(data, sizeof(Network::IPv4Address)); // h_length
+ // h_addr_list:
+ size_t count = vec.size();
+ ASSERT(count <= UINT32_MAX);
+ Append<u32_be>(data, static_cast<uint32_t>(count));
+ for (const Network::AddrInfo& addrinfo : vec) {
+ // On the Switch, this is passed through htonl despite already being
+ // big-endian, so it ends up as little-endian.
+ Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip));
+
+ LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
+ Network::IPv4AddressToString(addrinfo.addr.ip));
+ }
+ return data;
+}
+
+static std::pair<u32, GetAddrInfoError> GetHostByNameRequestImpl(HLERequestContext& ctx) {
+ struct InputParameters {
+ u8 use_nsd_resolve;
+ u32 cancel_handle;
+ u64 process_id;
+ };
+ static_assert(sizeof(InputParameters) == 0x10);
+
+ IPC::RequestParser rp{ctx};
+ const auto parameters = rp.PopRaw<InputParameters>();
+
+ LOG_WARNING(
+ Service,
+ "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
+ parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
+
+ const auto host_buffer = ctx.ReadBuffer(0);
+ const std::string host = Common::StringFromBuffer(host_buffer);
+ // For now, ignore options, which are in input buffer 1 for GetHostByNameRequestWithOptions.
+
+ auto res = Network::GetAddressInfo(host, /*service*/ std::nullopt);
+ if (!res.has_value()) {
+ return {0, Translate(res.error())};
+ }
+
+ const std::vector<u8> data = SerializeAddrInfoAsHostEnt(res.value(), host);
+ const u32 data_size = static_cast<u32>(data.size());
+ ctx.WriteBuffer(data, 0);
+
+ return {data_size, GetAddrInfoError::SUCCESS};
+}
+
+void SFDNSRES::GetHostByNameRequest(HLERequestContext& ctx) {
+ auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
+
+ struct OutputParameters {
+ NetDbError netdb_error;
+ Errno bsd_errno;
+ u32 data_size;
+ };
+ static_assert(sizeof(OutputParameters) == 0xc);
+
+ IPC::ResponseBuilder rb{ctx, 5};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ .data_size = data_size,
+ });
+}
+
+void SFDNSRES::GetHostByNameRequestWithOptions(HLERequestContext& ctx) {
+ auto [data_size, emu_gai_err] = GetHostByNameRequestImpl(ctx);
+
+ struct OutputParameters {
+ u32 data_size;
+ NetDbError netdb_error;
+ Errno bsd_errno;
+ };
+ static_assert(sizeof(OutputParameters) == 0xc);
+
+ IPC::ResponseBuilder rb{ctx, 5};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .data_size = data_size,
+ .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ });
+}
+
+static std::vector<u8> SerializeAddrInfo(const std::vector<Network::AddrInfo>& vec,
std::string_view host) {
// Adapted from
// https://github.com/switchbrew/libnx/blob/c5a9a909a91657a9818a3b7e18c9b91ff0cbb6e3/nx/source/runtime/resolver.c#L190
std::vector<u8> data;
- auto* current = addrinfo;
- while (current != nullptr) {
- struct SerializedResponseHeader {
- u32 magic;
- s32 flags;
- s32 family;
- s32 socket_type;
- s32 protocol;
- u32 address_length;
- };
- static_assert(sizeof(SerializedResponseHeader) == 0x18,
- "Response header size must be 0x18 bytes");
-
- constexpr auto header_size = sizeof(SerializedResponseHeader);
- const auto addr_size =
- current->ai_addr && current->ai_addrlen > 0 ? current->ai_addrlen : 4;
- const auto canonname_size = current->ai_canonname ? strlen(current->ai_canonname) + 1 : 1;
-
- const auto last_size = data.size();
- data.resize(last_size + header_size + addr_size + canonname_size);
-
- // Header in network byte order
- SerializedResponseHeader header{};
-
- constexpr auto HEADER_MAGIC = 0xBEEFCAFE;
- header.magic = htonl(HEADER_MAGIC);
- header.family = htonl(current->ai_family);
- header.flags = htonl(current->ai_flags);
- header.socket_type = htonl(current->ai_socktype);
- header.protocol = htonl(current->ai_protocol);
- header.address_length = current->ai_addr ? htonl((u32)current->ai_addrlen) : 0;
-
- auto* header_ptr = data.data() + last_size;
- std::memcpy(header_ptr, &header, header_size);
-
- if (header.address_length == 0) {
- std::memset(header_ptr + header_size, 0, 4);
+ for (const Network::AddrInfo& addrinfo : vec) {
+ // serialized addrinfo:
+ Append<u32_be>(data, 0xBEEFCAFE); // magic
+ Append<u32_be>(data, 0); // ai_flags
+ Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.family))); // ai_family
+ Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.socket_type))); // ai_socktype
+ Append<u32_be>(data, static_cast<u32>(Translate(addrinfo.protocol))); // ai_protocol
+ Append<u32_be>(data, sizeof(SockAddrIn)); // ai_addrlen
+ // ^ *not* sizeof(SerializedSockAddrIn), not that it matters since they're the same size
+
+ // ai_addr:
+ Append<u16_be>(data, static_cast<u16>(Translate(addrinfo.addr.family))); // sin_family
+ // On the Switch, the following fields are passed through htonl despite
+ // already being big-endian, so they end up as little-endian.
+ Append<u16_le>(data, addrinfo.addr.portno); // sin_port
+ Append<u32_le>(data, Network::IPv4AddressToInteger(addrinfo.addr.ip)); // sin_addr
+ data.resize(data.size() + 8, 0); // sin_zero
+
+ if (addrinfo.canon_name.has_value()) {
+ AppendNulTerminated(data, *addrinfo.canon_name);
} else {
- switch (current->ai_family) {
- case AF_INET: {
- struct SockAddrIn {
- s16 sin_family;
- u16 sin_port;
- u32 sin_addr;
- u8 sin_zero[8];
- };
-
- SockAddrIn serialized_addr{};
- const auto addr = *reinterpret_cast<sockaddr_in*>(current->ai_addr);
- serialized_addr.sin_port = htons(addr.sin_port);
- serialized_addr.sin_family = htons(addr.sin_family);
- serialized_addr.sin_addr = htonl(addr.sin_addr.s_addr);
- std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn));
-
- char addr_string_buf[64]{};
- inet_ntop(AF_INET, &addr.sin_addr, addr_string_buf, std::size(addr_string_buf));
- LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host, addr_string_buf);
- break;
- }
- case AF_INET6: {
- struct SockAddrIn6 {
- s16 sin6_family;
- u16 sin6_port;
- u32 sin6_flowinfo;
- u8 sin6_addr[16];
- u32 sin6_scope_id;
- };
-
- SockAddrIn6 serialized_addr{};
- const auto addr = *reinterpret_cast<sockaddr_in6*>(current->ai_addr);
- serialized_addr.sin6_family = htons(addr.sin6_family);
- serialized_addr.sin6_port = htons(addr.sin6_port);
- serialized_addr.sin6_flowinfo = htonl(addr.sin6_flowinfo);
- serialized_addr.sin6_scope_id = htonl(addr.sin6_scope_id);
- std::memcpy(serialized_addr.sin6_addr, &addr.sin6_addr,
- sizeof(SockAddrIn6::sin6_addr));
- std::memcpy(header_ptr + header_size, &serialized_addr, sizeof(SockAddrIn6));
-
- char addr_string_buf[64]{};
- inet_ntop(AF_INET6, &addr.sin6_addr, addr_string_buf, std::size(addr_string_buf));
- LOG_INFO(Service, "Resolved host '{}' to IPv6 address {}", host, addr_string_buf);
- break;
- }
- default:
- std::memcpy(header_ptr + header_size, current->ai_addr, addr_size);
- break;
- }
- }
- if (current->ai_canonname) {
- std::memcpy(header_ptr + addr_size, current->ai_canonname, canonname_size);
- } else {
- *(header_ptr + header_size + addr_size) = 0;
+ data.push_back(0);
}
- current = current->ai_next;
+ LOG_INFO(Service, "Resolved host '{}' to IPv4 address {}", host,
+ Network::IPv4AddressToString(addrinfo.addr.ip));
}
- // 4-byte sentinel value
- data.push_back(0);
- data.push_back(0);
- data.push_back(0);
- data.push_back(0);
+ data.resize(data.size() + 4, 0); // 4-byte sentinel value
return data;
}
-static std::pair<u32, s32> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
- struct Parameters {
+static std::pair<u32, GetAddrInfoError> GetAddrInfoRequestImpl(HLERequestContext& ctx) {
+ struct InputParameters {
u8 use_nsd_resolve;
- u32 unknown;
+ u32 cancel_handle;
u64 process_id;
};
+ static_assert(sizeof(InputParameters) == 0x10);
IPC::RequestParser rp{ctx};
- const auto parameters = rp.PopRaw<Parameters>();
+ const auto parameters = rp.PopRaw<InputParameters>();
+
+ LOG_WARNING(
+ Service,
+ "called with ignored parameters: use_nsd_resolve={}, cancel_handle={}, process_id={}",
+ parameters.use_nsd_resolve, parameters.cancel_handle, parameters.process_id);
- LOG_WARNING(Service,
- "called with ignored parameters: use_nsd_resolve={}, unknown={}, process_id={}",
- parameters.use_nsd_resolve, parameters.unknown, parameters.process_id);
+ // TODO: If use_nsd_resolve is true, pass the name through NSD::Resolve
+ // before looking up.
const auto host_buffer = ctx.ReadBuffer(0);
const std::string host = Common::StringFromBuffer(host_buffer);
- const auto service_buffer = ctx.ReadBuffer(1);
- const std::string service = Common::StringFromBuffer(service_buffer);
-
- addrinfo* addrinfo;
- // Pass null for hints. Serialized hints are also passed in a buffer, but are ignored for now
- s32 result_code = getaddrinfo(host.c_str(), service.c_str(), nullptr, &addrinfo);
+ std::optional<std::string> service = std::nullopt;
+ if (ctx.CanReadBuffer(1)) {
+ const std::span<const u8> service_buffer = ctx.ReadBuffer(1);
+ service = Common::StringFromBuffer(service_buffer);
+ }
- u32 data_size = 0;
- if (result_code == 0 && addrinfo != nullptr) {
- const std::vector<u8>& data = SerializeAddrInfo(addrinfo, result_code, host);
- data_size = static_cast<u32>(data.size());
- freeaddrinfo(addrinfo);
+ // Serialized hints are also passed in a buffer, but are ignored for now.
- ctx.WriteBuffer(data, 0);
+ auto res = Network::GetAddressInfo(host, service);
+ if (!res.has_value()) {
+ return {0, Translate(res.error())};
}
- return std::make_pair(data_size, result_code);
+ const std::vector<u8> data = SerializeAddrInfo(res.value(), host);
+ const u32 data_size = static_cast<u32>(data.size());
+ ctx.WriteBuffer(data, 0);
+
+ return {data_size, GetAddrInfoError::SUCCESS};
}
void SFDNSRES::GetAddrInfoRequest(HLERequestContext& ctx) {
- auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx);
+ auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
- IPC::ResponseBuilder rb{ctx, 4};
+ struct OutputParameters {
+ Errno bsd_errno;
+ GetAddrInfoError gai_error;
+ u32 data_size;
+ };
+ static_assert(sizeof(OutputParameters) == 0xc);
+
+ IPC::ResponseBuilder rb{ctx, 5};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ .gai_error = emu_gai_err,
+ .data_size = data_size,
+ });
+}
+
+void SFDNSRES::GetGaiStringErrorRequest(HLERequestContext& ctx) {
+ struct InputParameters {
+ GetAddrInfoError gai_errno;
+ };
+ IPC::RequestParser rp{ctx};
+ auto input = rp.PopRaw<InputParameters>();
+
+ const std::string result = Translate(input.gai_errno);
+ ctx.WriteBuffer(result);
+
+ IPC::ResponseBuilder rb{ctx, 2};
rb.Push(ResultSuccess);
- rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
- rb.Push(result_code); // errno
- rb.Push(data_size); // serialized size
}
void SFDNSRES::GetAddrInfoRequestWithOptions(HLERequestContext& ctx) {
// Additional options are ignored
- auto [data_size, result_code] = GetAddrInfoRequestImpl(ctx);
+ auto [data_size, emu_gai_err] = GetAddrInfoRequestImpl(ctx);
+
+ struct OutputParameters {
+ u32 data_size;
+ GetAddrInfoError gai_error;
+ NetDbError netdb_error;
+ Errno bsd_errno;
+ };
+ static_assert(sizeof(OutputParameters) == 0x10);
+
+ IPC::ResponseBuilder rb{ctx, 6};
+ rb.Push(ResultSuccess);
+ rb.PushRaw(OutputParameters{
+ .data_size = data_size,
+ .gai_error = emu_gai_err,
+ .netdb_error = GetAddrInfoErrorToNetDbError(emu_gai_err),
+ .bsd_errno = GetAddrInfoErrorToErrno(emu_gai_err),
+ });
+}
+
+void SFDNSRES::ResolverSetOptionRequest(HLERequestContext& ctx) {
+ LOG_WARNING(Service, "(STUBBED) called");
+
+ IPC::ResponseBuilder rb{ctx, 3};
- IPC::ResponseBuilder rb{ctx, 5};
rb.Push(ResultSuccess);
- rb.Push(data_size); // serialized size
- rb.Push(result_code); // errno
- rb.Push(static_cast<s32>(AddrInfoErrorToNetDbError(result_code))); // NetDBErrorCode
- rb.Push(0);
+ rb.Push<s32>(0); // bsd errno
}
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sfdnsres.h b/src/core/hle/service/sockets/sfdnsres.h
index 18e3cd60c..282ef9071 100644
--- a/src/core/hle/service/sockets/sfdnsres.h
+++ b/src/core/hle/service/sockets/sfdnsres.h
@@ -17,8 +17,12 @@ public:
~SFDNSRES() override;
private:
+ void GetHostByNameRequest(HLERequestContext& ctx);
+ void GetGaiStringErrorRequest(HLERequestContext& ctx);
+ void GetHostByNameRequestWithOptions(HLERequestContext& ctx);
void GetAddrInfoRequest(HLERequestContext& ctx);
void GetAddrInfoRequestWithOptions(HLERequestContext& ctx);
+ void ResolverSetOptionRequest(HLERequestContext& ctx);
};
} // namespace Service::Sockets
diff --git a/src/core/hle/service/sockets/sockets.h b/src/core/hle/service/sockets/sockets.h
index acd2dae7b..77426c46e 100644
--- a/src/core/hle/service/sockets/sockets.h
+++ b/src/core/hle/service/sockets/sockets.h
@@ -22,13 +22,35 @@ enum class Errno : u32 {
CONNRESET = 104,
NOTCONN = 107,
TIMEDOUT = 110,
+ INPROGRESS = 115,
+};
+
+enum class GetAddrInfoError : s32 {
+ SUCCESS = 0,
+ ADDRFAMILY = 1,
+ AGAIN = 2,
+ BADFLAGS = 3,
+ FAIL = 4,
+ FAMILY = 5,
+ MEMORY = 6,
+ NODATA = 7,
+ NONAME = 8,
+ SERVICE = 9,
+ SOCKTYPE = 10,
+ SYSTEM = 11,
+ BADHINTS = 12,
+ PROTOCOL = 13,
+ OVERFLOW_ = 14, // avoid name collision with Windows macro
+ OTHER = 15,
};
enum class Domain : u32 {
+ Unspecified = 0,
INET = 2,
};
enum class Type : u32 {
+ Unspecified = 0,
STREAM = 1,
DGRAM = 2,
RAW = 3,
@@ -36,12 +58,16 @@ enum class Type : u32 {
};
enum class Protocol : u32 {
- UNSPECIFIED = 0,
+ Unspecified = 0,
ICMP = 1,
TCP = 6,
UDP = 17,
};
+enum class SocketLevel : u32 {
+ SOCKET = 0xffff, // i.e. SOL_SOCKET
+};
+
enum class OptName : u32 {
REUSEADDR = 0x4,
KEEPALIVE = 0x8,
@@ -51,6 +77,8 @@ enum class OptName : u32 {
RCVBUF = 0x1002,
SNDTIMEO = 0x1005,
RCVTIMEO = 0x1006,
+ ERROR_ = 0x1007, // avoid name collision with Windows macro
+ NOSIGPIPE = 0x800, // at least according to libnx
};
enum class ShutdownHow : s32 {
@@ -80,6 +108,9 @@ enum class PollEvents : u16 {
Err = 1 << 3,
Hup = 1 << 4,
Nval = 1 << 5,
+ RdNorm = 1 << 6,
+ RdBand = 1 << 7,
+ WrBand = 1 << 8,
};
DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
diff --git a/src/core/hle/service/sockets/sockets_translate.cpp b/src/core/hle/service/sockets/sockets_translate.cpp
index 594e58f90..c1187209f 100644
--- a/src/core/hle/service/sockets/sockets_translate.cpp
+++ b/src/core/hle/service/sockets/sockets_translate.cpp
@@ -29,6 +29,8 @@ Errno Translate(Network::Errno value) {
return Errno::TIMEDOUT;
case Network::Errno::CONNRESET:
return Errno::CONNRESET;
+ case Network::Errno::INPROGRESS:
+ return Errno::INPROGRESS;
default:
UNIMPLEMENTED_MSG("Unimplemented errno={}", value);
return Errno::SUCCESS;
@@ -39,8 +41,88 @@ std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value) {
return {value.first, Translate(value.second)};
}
+GetAddrInfoError Translate(Network::GetAddrInfoError error) {
+ switch (error) {
+ case Network::GetAddrInfoError::SUCCESS:
+ return GetAddrInfoError::SUCCESS;
+ case Network::GetAddrInfoError::ADDRFAMILY:
+ return GetAddrInfoError::ADDRFAMILY;
+ case Network::GetAddrInfoError::AGAIN:
+ return GetAddrInfoError::AGAIN;
+ case Network::GetAddrInfoError::BADFLAGS:
+ return GetAddrInfoError::BADFLAGS;
+ case Network::GetAddrInfoError::FAIL:
+ return GetAddrInfoError::FAIL;
+ case Network::GetAddrInfoError::FAMILY:
+ return GetAddrInfoError::FAMILY;
+ case Network::GetAddrInfoError::MEMORY:
+ return GetAddrInfoError::MEMORY;
+ case Network::GetAddrInfoError::NODATA:
+ return GetAddrInfoError::NODATA;
+ case Network::GetAddrInfoError::NONAME:
+ return GetAddrInfoError::NONAME;
+ case Network::GetAddrInfoError::SERVICE:
+ return GetAddrInfoError::SERVICE;
+ case Network::GetAddrInfoError::SOCKTYPE:
+ return GetAddrInfoError::SOCKTYPE;
+ case Network::GetAddrInfoError::SYSTEM:
+ return GetAddrInfoError::SYSTEM;
+ case Network::GetAddrInfoError::BADHINTS:
+ return GetAddrInfoError::BADHINTS;
+ case Network::GetAddrInfoError::PROTOCOL:
+ return GetAddrInfoError::PROTOCOL;
+ case Network::GetAddrInfoError::OVERFLOW_:
+ return GetAddrInfoError::OVERFLOW_;
+ case Network::GetAddrInfoError::OTHER:
+ return GetAddrInfoError::OTHER;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented GetAddrInfoError={}", error);
+ return GetAddrInfoError::OTHER;
+ }
+}
+
+const char* Translate(GetAddrInfoError error) {
+ // https://android.googlesource.com/platform/bionic/+/085543106/libc/dns/net/getaddrinfo.c#254
+ switch (error) {
+ case GetAddrInfoError::SUCCESS:
+ return "Success";
+ case GetAddrInfoError::ADDRFAMILY:
+ return "Address family for hostname not supported";
+ case GetAddrInfoError::AGAIN:
+ return "Temporary failure in name resolution";
+ case GetAddrInfoError::BADFLAGS:
+ return "Invalid value for ai_flags";
+ case GetAddrInfoError::FAIL:
+ return "Non-recoverable failure in name resolution";
+ case GetAddrInfoError::FAMILY:
+ return "ai_family not supported";
+ case GetAddrInfoError::MEMORY:
+ return "Memory allocation failure";
+ case GetAddrInfoError::NODATA:
+ return "No address associated with hostname";
+ case GetAddrInfoError::NONAME:
+ return "hostname nor servname provided, or not known";
+ case GetAddrInfoError::SERVICE:
+ return "servname not supported for ai_socktype";
+ case GetAddrInfoError::SOCKTYPE:
+ return "ai_socktype not supported";
+ case GetAddrInfoError::SYSTEM:
+ return "System error returned in errno";
+ case GetAddrInfoError::BADHINTS:
+ return "Invalid value for hints";
+ case GetAddrInfoError::PROTOCOL:
+ return "Resolved protocol is unknown";
+ case GetAddrInfoError::OVERFLOW_:
+ return "Argument buffer overflow";
+ default:
+ return "Unknown error";
+ }
+}
+
Network::Domain Translate(Domain domain) {
switch (domain) {
+ case Domain::Unspecified:
+ return Network::Domain::Unspecified;
case Domain::INET:
return Network::Domain::INET;
default:
@@ -51,6 +133,8 @@ Network::Domain Translate(Domain domain) {
Domain Translate(Network::Domain domain) {
switch (domain) {
+ case Network::Domain::Unspecified:
+ return Domain::Unspecified;
case Network::Domain::INET:
return Domain::INET;
default:
@@ -61,39 +145,69 @@ Domain Translate(Network::Domain domain) {
Network::Type Translate(Type type) {
switch (type) {
+ case Type::Unspecified:
+ return Network::Type::Unspecified;
case Type::STREAM:
return Network::Type::STREAM;
case Type::DGRAM:
return Network::Type::DGRAM;
+ case Type::RAW:
+ return Network::Type::RAW;
+ case Type::SEQPACKET:
+ return Network::Type::SEQPACKET;
default:
UNIMPLEMENTED_MSG("Unimplemented type={}", type);
return Network::Type{};
}
}
-Network::Protocol Translate(Type type, Protocol protocol) {
+Type Translate(Network::Type type) {
+ switch (type) {
+ case Network::Type::Unspecified:
+ return Type::Unspecified;
+ case Network::Type::STREAM:
+ return Type::STREAM;
+ case Network::Type::DGRAM:
+ return Type::DGRAM;
+ case Network::Type::RAW:
+ return Type::RAW;
+ case Network::Type::SEQPACKET:
+ return Type::SEQPACKET;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented type={}", type);
+ return Type{};
+ }
+}
+
+Network::Protocol Translate(Protocol protocol) {
switch (protocol) {
- case Protocol::UNSPECIFIED:
- LOG_WARNING(Service, "Unspecified protocol, assuming protocol from type");
- switch (type) {
- case Type::DGRAM:
- return Network::Protocol::UDP;
- case Type::STREAM:
- return Network::Protocol::TCP;
- default:
- return Network::Protocol::TCP;
- }
+ case Protocol::Unspecified:
+ return Network::Protocol::Unspecified;
case Protocol::TCP:
return Network::Protocol::TCP;
case Protocol::UDP:
return Network::Protocol::UDP;
default:
UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
- return Network::Protocol::TCP;
+ return Network::Protocol::Unspecified;
+ }
+}
+
+Protocol Translate(Network::Protocol protocol) {
+ switch (protocol) {
+ case Network::Protocol::Unspecified:
+ return Protocol::Unspecified;
+ case Network::Protocol::TCP:
+ return Protocol::TCP;
+ case Network::Protocol::UDP:
+ return Protocol::UDP;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
+ return Protocol::Unspecified;
}
}
-Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
+Network::PollEvents Translate(PollEvents flags) {
Network::PollEvents result{};
const auto translate = [&result, &flags](PollEvents from, Network::PollEvents to) {
if (True(flags & from)) {
@@ -107,12 +221,15 @@ Network::PollEvents TranslatePollEventsToHost(PollEvents flags) {
translate(PollEvents::Err, Network::PollEvents::Err);
translate(PollEvents::Hup, Network::PollEvents::Hup);
translate(PollEvents::Nval, Network::PollEvents::Nval);
+ translate(PollEvents::RdNorm, Network::PollEvents::RdNorm);
+ translate(PollEvents::RdBand, Network::PollEvents::RdBand);
+ translate(PollEvents::WrBand, Network::PollEvents::WrBand);
UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
return result;
}
-PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
+PollEvents Translate(Network::PollEvents flags) {
PollEvents result{};
const auto translate = [&result, &flags](Network::PollEvents from, PollEvents to) {
if (True(flags & from)) {
@@ -127,13 +244,18 @@ PollEvents TranslatePollEventsToGuest(Network::PollEvents flags) {
translate(Network::PollEvents::Err, PollEvents::Err);
translate(Network::PollEvents::Hup, PollEvents::Hup);
translate(Network::PollEvents::Nval, PollEvents::Nval);
+ translate(Network::PollEvents::RdNorm, PollEvents::RdNorm);
+ translate(Network::PollEvents::RdBand, PollEvents::RdBand);
+ translate(Network::PollEvents::WrBand, PollEvents::WrBand);
UNIMPLEMENTED_IF_MSG((u16)flags != 0, "Unimplemented flags={}", (u16)flags);
return result;
}
Network::SockAddrIn Translate(SockAddrIn value) {
- ASSERT(value.len == 0 || value.len == sizeof(value));
+ // Note: 6 is incorrect, but can be passed by homebrew (because libnx sets
+ // sin_len to 6 when deserializing getaddrinfo results).
+ ASSERT(value.len == 0 || value.len == sizeof(value) || value.len == 6);
return {
.family = Translate(static_cast<Domain>(value.family)),
diff --git a/src/core/hle/service/sockets/sockets_translate.h b/src/core/hle/service/sockets/sockets_translate.h
index c93291d3e..bd6721fd3 100644
--- a/src/core/hle/service/sockets/sockets_translate.h
+++ b/src/core/hle/service/sockets/sockets_translate.h
@@ -17,6 +17,12 @@ Errno Translate(Network::Errno value);
/// Translate abstract return value errno pair to guest return value errno pair
std::pair<s32, Errno> Translate(std::pair<s32, Network::Errno> value);
+/// Translate abstract getaddrinfo error to guest getaddrinfo error
+GetAddrInfoError Translate(Network::GetAddrInfoError value);
+
+/// Translate guest error to string
+const char* Translate(GetAddrInfoError value);
+
/// Translate guest domain to abstract domain
Network::Domain Translate(Domain domain);
@@ -26,14 +32,20 @@ Domain Translate(Network::Domain domain);
/// Translate guest type to abstract type
Network::Type Translate(Type type);
+/// Translate abstract type to guest type
+Type Translate(Network::Type type);
+
/// Translate guest protocol to abstract protocol
-Network::Protocol Translate(Type type, Protocol protocol);
+Network::Protocol Translate(Protocol protocol);
-/// Translate abstract poll event flags to guest poll event flags
-Network::PollEvents TranslatePollEventsToHost(PollEvents flags);
+/// Translate abstract protocol to guest protocol
+Protocol Translate(Network::Protocol protocol);
/// Translate guest poll event flags to abstract poll event flags
-PollEvents TranslatePollEventsToGuest(Network::PollEvents flags);
+Network::PollEvents Translate(PollEvents flags);
+
+/// Translate abstract poll event flags to guest poll event flags
+PollEvents Translate(Network::PollEvents flags);
/// Translate guest socket address structure to abstract socket address structure
Network::SockAddrIn Translate(SockAddrIn value);
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp
index 2b99dd7ac..9c96f9763 100644
--- a/src/core/hle/service/ssl/ssl.cpp
+++ b/src/core/hle/service/ssl/ssl.cpp
@@ -1,10 +1,18 @@
// SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
+#include "common/string_util.h"
+
+#include "core/core.h"
#include "core/hle/service/ipc_helpers.h"
#include "core/hle/service/server_manager.h"
#include "core/hle/service/service.h"
+#include "core/hle/service/sm/sm.h"
+#include "core/hle/service/sockets/bsd.h"
#include "core/hle/service/ssl/ssl.h"
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
namespace Service::SSL {
@@ -20,6 +28,18 @@ enum class ContextOption : u32 {
CrlImportDateCheckEnable = 1,
};
+// This is nn::ssl::Connection::IoMode
+enum class IoMode : u32 {
+ Blocking = 1,
+ NonBlocking = 2,
+};
+
+// This is nn::ssl::sf::OptionType
+enum class OptionType : u32 {
+ DoNotCloseSocket = 0,
+ GetServerCertChain = 1,
+};
+
// This is nn::ssl::sf::SslVersion
struct SslVersion {
union {
@@ -34,35 +54,42 @@ struct SslVersion {
};
};
+struct SslContextSharedData {
+ u32 connection_count = 0;
+};
+
class ISslConnection final : public ServiceFramework<ISslConnection> {
public:
- explicit ISslConnection(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} {
+ explicit ISslConnection(Core::System& system_in, SslVersion ssl_version_in,
+ std::shared_ptr<SslContextSharedData>& shared_data_in,
+ std::unique_ptr<SSLConnectionBackend>&& backend_in)
+ : ServiceFramework{system_in, "ISslConnection"}, ssl_version{ssl_version_in},
+ shared_data{shared_data_in}, backend{std::move(backend_in)} {
// clang-format off
static const FunctionInfo functions[] = {
- {0, nullptr, "SetSocketDescriptor"},
- {1, nullptr, "SetHostName"},
- {2, nullptr, "SetVerifyOption"},
- {3, nullptr, "SetIoMode"},
+ {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"},
+ {1, &ISslConnection::SetHostName, "SetHostName"},
+ {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"},
+ {3, &ISslConnection::SetIoMode, "SetIoMode"},
{4, nullptr, "GetSocketDescriptor"},
{5, nullptr, "GetHostName"},
{6, nullptr, "GetVerifyOption"},
{7, nullptr, "GetIoMode"},
- {8, nullptr, "DoHandshake"},
- {9, nullptr, "DoHandshakeGetServerCert"},
- {10, nullptr, "Read"},
- {11, nullptr, "Write"},
- {12, nullptr, "Pending"},
+ {8, &ISslConnection::DoHandshake, "DoHandshake"},
+ {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"},
+ {10, &ISslConnection::Read, "Read"},
+ {11, &ISslConnection::Write, "Write"},
+ {12, &ISslConnection::Pending, "Pending"},
{13, nullptr, "Peek"},
{14, nullptr, "Poll"},
{15, nullptr, "GetVerifyCertError"},
{16, nullptr, "GetNeededServerCertBufferSize"},
- {17, nullptr, "SetSessionCacheMode"},
+ {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"},
{18, nullptr, "GetSessionCacheMode"},
{19, nullptr, "FlushSessionCache"},
{20, nullptr, "SetRenegotiationMode"},
{21, nullptr, "GetRenegotiationMode"},
- {22, nullptr, "SetOption"},
+ {22, &ISslConnection::SetOption, "SetOption"},
{23, nullptr, "GetOption"},
{24, nullptr, "GetVerifyCertErrors"},
{25, nullptr, "GetCipherInfo"},
@@ -80,21 +107,299 @@ public:
// clang-format on
RegisterHandlers(functions);
+
+ shared_data->connection_count++;
+ }
+
+ ~ISslConnection() {
+ shared_data->connection_count--;
+ if (fd_to_close.has_value()) {
+ const s32 fd = *fd_to_close;
+ if (!do_not_close_socket) {
+ LOG_ERROR(Service_SSL,
+ "do_not_close_socket was changed after setting socket; is this right?");
+ } else {
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ if (bsd) {
+ auto err = bsd->CloseImpl(fd);
+ if (err != Service::Sockets::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to close duplicated socket: {}", err);
+ }
+ }
+ }
+ }
}
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
+ std::unique_ptr<SSLConnectionBackend> backend;
+ std::optional<int> fd_to_close;
+ bool do_not_close_socket = false;
+ bool get_server_cert_chain = false;
+ std::shared_ptr<Network::SocketBase> socket;
+ bool did_set_host_name = false;
+ bool did_handshake = false;
+
+ ResultVal<s32> SetSocketDescriptorImpl(s32 fd) {
+ LOG_DEBUG(Service_SSL, "called, fd={}", fd);
+ ASSERT(!did_handshake);
+ auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u");
+ ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; });
+ s32 ret_fd;
+ // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor
+ if (do_not_close_socket) {
+ auto res = bsd->DuplicateSocketImpl(fd);
+ if (!res.has_value()) {
+ LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ fd = *res;
+ fd_to_close = fd;
+ ret_fd = fd;
+ } else {
+ ret_fd = -1;
+ }
+ std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd);
+ if (!sock.has_value()) {
+ LOG_ERROR(Service_SSL, "invalid socket fd {}", fd);
+ return ResultInvalidSocket;
+ }
+ socket = std::move(*sock);
+ backend->SetSocket(socket);
+ return ret_fd;
+ }
+
+ Result SetHostNameImpl(const std::string& hostname) {
+ LOG_DEBUG(Service_SSL, "called. hostname={}", hostname);
+ ASSERT(!did_handshake);
+ Result res = backend->SetHostName(hostname);
+ if (res == ResultSuccess) {
+ did_set_host_name = true;
+ }
+ return res;
+ }
+
+ Result SetVerifyOptionImpl(u32 option) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option);
+ return ResultSuccess;
+ }
+
+ Result SetIoModeImpl(u32 input_mode) {
+ auto mode = static_cast<IoMode>(input_mode);
+ ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking);
+ ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; });
+
+ const bool non_block = mode == IoMode::NonBlocking;
+ const Network::Errno error = socket->SetNonBlock(non_block);
+ if (error != Network::Errno::SUCCESS) {
+ LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block);
+ }
+ return ResultSuccess;
+ }
+
+ Result SetSessionCacheModeImpl(u32 mode) {
+ ASSERT(!did_handshake);
+ LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode);
+ return ResultSuccess;
+ }
+
+ Result DoHandshakeImpl() {
+ ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; });
+ ASSERT_OR_EXECUTE_MSG(
+ did_set_host_name, { return ResultInternalError; },
+ "Expected SetHostName before DoHandshake");
+ Result res = backend->DoHandshake();
+ did_handshake = res.IsSuccess();
+ return res;
+ }
+
+ std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) {
+ struct Header {
+ u64 magic;
+ u32 count;
+ u32 pad;
+ };
+ struct EntryHeader {
+ u32 size;
+ u32 offset;
+ };
+ if (!get_server_cert_chain) {
+ // Just return the first one, unencoded.
+ ASSERT_OR_EXECUTE_MSG(
+ !certs.empty(), { return {}; }, "Should be at least one server cert");
+ return certs[0];
+ }
+ std::vector<u8> ret;
+ Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0};
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1));
+ size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader);
+ for (auto& cert : certs) {
+ EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)};
+ data_offset += cert.size();
+ ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header),
+ reinterpret_cast<u8*>(&entry_header + 1));
+ }
+ for (auto& cert : certs) {
+ ret.insert(ret.end(), cert.begin(), cert.end());
+ }
+ return ret;
+ }
+
+ ResultVal<std::vector<u8>> ReadImpl(size_t size) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ std::vector<u8> res(size);
+ ResultVal<size_t> actual = backend->Read(res);
+ if (actual.Failed()) {
+ return actual.Code();
+ }
+ res.resize(*actual);
+ return res;
+ }
+
+ ResultVal<size_t> WriteImpl(std::span<const u8> data) {
+ ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; });
+ return backend->Write(data);
+ }
+
+ ResultVal<s32> PendingImpl() {
+ LOG_WARNING(Service_SSL, "(STUBBED) called.");
+ return 0;
+ }
+
+ void SetSocketDescriptor(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const s32 fd = rp.Pop<s32>();
+ const ResultVal<s32> res = SetSocketDescriptorImpl(fd);
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(-1));
+ }
+
+ void SetHostName(HLERequestContext& ctx) {
+ const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer());
+ const Result res = SetHostNameImpl(hostname);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetVerifyOption(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 option = rp.Pop<u32>();
+ const Result res = SetVerifyOptionImpl(option);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetIoMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetIoModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshake(HLERequestContext& ctx) {
+ const Result res = DoHandshakeImpl();
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void DoHandshakeGetServerCert(HLERequestContext& ctx) {
+ struct OutputParameters {
+ u32 certs_size;
+ u32 certs_count;
+ };
+ static_assert(sizeof(OutputParameters) == 0x8);
+
+ const Result res = DoHandshakeImpl();
+ OutputParameters out{};
+ if (res == ResultSuccess) {
+ auto certs = backend->GetServerCerts();
+ if (certs.Succeeded()) {
+ const std::vector<u8> certs_buf = SerializeServerCerts(*certs);
+ ctx.WriteBuffer(certs_buf);
+ out.certs_count = static_cast<u32>(certs->size());
+ out.certs_size = static_cast<u32>(certs_buf.size());
+ }
+ }
+ IPC::ResponseBuilder rb{ctx, 4};
+ rb.Push(res);
+ rb.PushRaw(out);
+ }
+
+ void Read(HLERequestContext& ctx) {
+ const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ if (res.Succeeded()) {
+ rb.Push(static_cast<u32>(res->size()));
+ ctx.WriteBuffer(*res);
+ } else {
+ rb.Push(static_cast<u32>(0));
+ }
+ }
+
+ void Write(HLERequestContext& ctx) {
+ const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer());
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push(static_cast<u32>(res.ValueOr(0)));
+ }
+
+ void Pending(HLERequestContext& ctx) {
+ const ResultVal<s32> res = PendingImpl();
+ IPC::ResponseBuilder rb{ctx, 3};
+ rb.Push(res.Code());
+ rb.Push<s32>(res.ValueOr(0));
+ }
+
+ void SetSessionCacheMode(HLERequestContext& ctx) {
+ IPC::RequestParser rp{ctx};
+ const u32 mode = rp.Pop<u32>();
+ const Result res = SetSessionCacheModeImpl(mode);
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(res);
+ }
+
+ void SetOption(HLERequestContext& ctx) {
+ struct Parameters {
+ OptionType option;
+ s32 value;
+ };
+ static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size");
+
+ IPC::RequestParser rp{ctx};
+ const auto parameters = rp.PopRaw<Parameters>();
+
+ switch (parameters.option) {
+ case OptionType::DoNotCloseSocket:
+ do_not_close_socket = static_cast<bool>(parameters.value);
+ break;
+ case OptionType::GetServerCertChain:
+ get_server_cert_chain = static_cast<bool>(parameters.value);
+ break;
+ default:
+ LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option,
+ parameters.value);
+ }
+
+ IPC::ResponseBuilder rb{ctx, 2};
+ rb.Push(ResultSuccess);
+ }
};
class ISslContext final : public ServiceFramework<ISslContext> {
public:
explicit ISslContext(Core::System& system_, SslVersion version)
- : ServiceFramework{system_, "ISslContext"}, ssl_version{version} {
+ : ServiceFramework{system_, "ISslContext"}, ssl_version{version},
+ shared_data{std::make_shared<SslContextSharedData>()} {
static const FunctionInfo functions[] = {
{0, &ISslContext::SetOption, "SetOption"},
{1, nullptr, "GetOption"},
{2, &ISslContext::CreateConnection, "CreateConnection"},
- {3, nullptr, "GetConnectionCount"},
+ {3, &ISslContext::GetConnectionCount, "GetConnectionCount"},
{4, &ISslContext::ImportServerPki, "ImportServerPki"},
{5, &ISslContext::ImportClientPki, "ImportClientPki"},
{6, nullptr, "RemoveServerPki"},
@@ -111,6 +416,7 @@ public:
private:
SslVersion ssl_version;
+ std::shared_ptr<SslContextSharedData> shared_data;
void SetOption(HLERequestContext& ctx) {
struct Parameters {
@@ -130,11 +436,24 @@ private:
}
void CreateConnection(HLERequestContext& ctx) {
- LOG_WARNING(Service_SSL, "(STUBBED) called");
+ LOG_WARNING(Service_SSL, "called");
+
+ auto backend_res = CreateSSLConnectionBackend();
IPC::ResponseBuilder rb{ctx, 2, 0, 1};
+ rb.Push(backend_res.Code());
+ if (backend_res.Succeeded()) {
+ rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data,
+ std::move(*backend_res));
+ }
+ }
+
+ void GetConnectionCount(HLERequestContext& ctx) {
+ LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count);
+
+ IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess);
- rb.PushIpcInterface<ISslConnection>(system, ssl_version);
+ rb.Push(shared_data->connection_count);
}
void ImportServerPki(HLERequestContext& ctx) {
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h
new file mode 100644
index 000000000..409f4367c
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend.h
@@ -0,0 +1,45 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#pragma once
+
+#include <memory>
+#include <span>
+#include <string>
+#include <vector>
+
+#include "common/common_types.h"
+
+#include "core/hle/result.h"
+
+namespace Network {
+class SocketBase;
+}
+
+namespace Service::SSL {
+
+constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103};
+constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106};
+constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205};
+constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up
+
+// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake,
+// with no way in the latter case to distinguish whether the client should poll
+// for read or write. The one official client I've seen handles this by always
+// polling for read (with a timeout).
+constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204};
+
+class SSLConnectionBackend {
+public:
+ virtual ~SSLConnectionBackend() {}
+ virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0;
+ virtual Result SetHostName(const std::string& hostname) = 0;
+ virtual Result DoHandshake() = 0;
+ virtual ResultVal<size_t> Read(std::span<u8> data) = 0;
+ virtual ResultVal<size_t> Write(std::span<const u8> data) = 0;
+ virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend();
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp
new file mode 100644
index 000000000..2f4f23c42
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_none.cpp
@@ -0,0 +1,16 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include "common/logging/log.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+
+namespace Service::SSL {
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ LOG_ERROR(Service_SSL,
+ "Can't create SSL connection because no SSL backend is available on this platform");
+ return ResultInternalError;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
new file mode 100644
index 000000000..6ca869dbf
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp
@@ -0,0 +1,351 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+#include <openssl/bio.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#include <openssl/x509.h>
+
+#include "common/fs/file.h"
+#include "common/hex_util.h"
+#include "common/string_util.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+using namespace Common::FS;
+
+namespace Service::SSL {
+
+// Import OpenSSL's `SSL` type into the namespace. This is needed because the
+// namespace is also named `SSL`.
+using ::SSL;
+
+namespace {
+
+std::once_flag one_time_init_flag;
+bool one_time_init_success = false;
+
+SSL_CTX* ssl_ctx;
+IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
+BIO_METHOD* bio_meth;
+
+Result CheckOpenSSLErrors();
+void OneTimeInit();
+void OneTimeInitLogFile();
+bool OneTimeInitBIO();
+
+} // namespace
+
+class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ std::call_once(one_time_init_flag, OneTimeInit);
+
+ if (!one_time_init_success) {
+ LOG_ERROR(Service_SSL,
+ "Can't create SSL connection because OpenSSL one-time initialization failed");
+ return ResultInternalError;
+ }
+
+ ssl = SSL_new(ssl_ctx);
+ if (!ssl) {
+ LOG_ERROR(Service_SSL, "SSL_new failed");
+ return CheckOpenSSLErrors();
+ }
+
+ SSL_set_connect_state(ssl);
+
+ bio = BIO_new(bio_meth);
+ if (!bio) {
+ LOG_ERROR(Service_SSL, "BIO_new failed");
+ return CheckOpenSSLErrors();
+ }
+
+ BIO_set_data(bio, this);
+ BIO_set_init(bio, 1);
+ SSL_set_bio(ssl, bio, bio);
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
+ socket = std::move(socket_in);
+ }
+
+ Result SetHostName(const std::string& hostname) override {
+ if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
+ LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname);
+ return CheckOpenSSLErrors();
+ }
+ if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
+ LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname);
+ return CheckOpenSSLErrors();
+ }
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ SSL_set_verify_result(ssl, X509_V_OK);
+ const int ret = SSL_do_handshake(ssl);
+ const long verify_result = SSL_get_verify_result(ssl);
+ if (verify_result != X509_V_OK) {
+ LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}",
+ X509_verify_cert_error_string(verify_result));
+ return CheckOpenSSLErrors();
+ }
+ if (ret <= 0) {
+ const int ssl_err = SSL_get_error(ssl, ret);
+ if (ssl_err == SSL_ERROR_ZERO_RETURN ||
+ (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) {
+ LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
+ return ResultInternalError;
+ }
+ }
+ return HandleReturn("SSL_do_handshake", 0, ret).Code();
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ size_t actual;
+ const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual);
+ return HandleReturn("SSL_read_ex", actual, ret);
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ size_t actual;
+ const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual);
+ return HandleReturn("SSL_write_ex", actual, ret);
+ }
+
+ ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) {
+ const int ssl_err = SSL_get_error(ssl, ret);
+ CheckOpenSSLErrors();
+ switch (ssl_err) {
+ case SSL_ERROR_NONE:
+ return actual;
+ case SSL_ERROR_ZERO_RETURN:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what);
+ // DoHandshake special-cases this, but for Read and Write:
+ return size_t(0);
+ case SSL_ERROR_WANT_READ:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what);
+ return ResultWouldBlock;
+ case SSL_ERROR_WANT_WRITE:
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what);
+ return ResultWouldBlock;
+ default:
+ if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) {
+ LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what);
+ return size_t(0);
+ }
+ LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err);
+ return ResultInternalError;
+ }
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl);
+ if (!chain) {
+ LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr");
+ return ResultInternalError;
+ }
+ std::vector<std::vector<u8>> ret;
+ int count = sk_X509_num(chain);
+ ASSERT(count >= 0);
+ for (int i = 0; i < count; i++) {
+ X509* x509 = sk_X509_value(chain, i);
+ ASSERT_OR_EXECUTE(x509 != nullptr, { continue; });
+ unsigned char* buf = nullptr;
+ int len = i2d_X509(x509, &buf);
+ ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; });
+ ret.emplace_back(buf, buf + len);
+ OPENSSL_free(buf);
+ }
+ return ret;
+ }
+
+ ~SSLConnectionBackendOpenSSL() {
+ // these are null-tolerant:
+ SSL_free(ssl);
+ BIO_free(bio);
+ }
+
+ static void KeyLogCallback(const SSL* ssl, const char* line) {
+ std::string str(line);
+ str.push_back('\n');
+ // Do this in a single WriteString for atomicity if multiple instances
+ // are running on different threads (though that can't currently
+ // happen).
+ if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) {
+ LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE");
+ }
+ LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line);
+ }
+
+ static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) {
+ auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "OpenSSL asked to send but we have no socket");
+ BIO_clear_retry_flags(bio);
+ auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ *actual_p = actual;
+ return 1;
+ case Network::Errno::AGAIN:
+ BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY);
+ return 0;
+ default:
+ LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
+ return -1;
+ }
+ }
+
+ static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) {
+ auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket");
+ BIO_clear_retry_flags(bio);
+ auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len});
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ *actual_p = actual;
+ if (actual == 0) {
+ self->got_read_eof = true;
+ }
+ return actual ? 1 : 0;
+ case Network::Errno::AGAIN:
+ BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY);
+ return 0;
+ default:
+ LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
+ return -1;
+ }
+ }
+
+ static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) {
+ switch (cmd) {
+ case BIO_CTRL_FLUSH:
+ // Nothing to flush.
+ return 1;
+ case BIO_CTRL_PUSH:
+ case BIO_CTRL_POP:
+#ifdef BIO_CTRL_GET_KTLS_SEND
+ case BIO_CTRL_GET_KTLS_SEND:
+ case BIO_CTRL_GET_KTLS_RECV:
+#endif
+ // We don't support these operations, but don't bother logging them
+ // as they're nothing unusual.
+ return 0;
+ default:
+ LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg);
+ return 0;
+ }
+ }
+
+ SSL* ssl = nullptr;
+ BIO* bio = nullptr;
+ bool got_read_eof = false;
+
+ std::shared_ptr<Network::SocketBase> socket;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendOpenSSL>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+namespace {
+
+Result CheckOpenSSLErrors() {
+ unsigned long rc;
+ const char* file;
+ int line;
+ const char* func;
+ const char* data;
+ int flags;
+#if OPENSSL_VERSION_NUMBER >= 0x30000000L
+ while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags)))
+#else
+ // Can't get function names from OpenSSL on this version, so use mine:
+ func = __func__;
+ while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags)))
+#endif
+ {
+ std::string msg;
+ msg.resize(1024, '\0');
+ ERR_error_string_n(rc, msg.data(), msg.size());
+ msg.resize(strlen(msg.data()), '\0');
+ if (flags & ERR_TXT_STRING) {
+ msg.append(" | ");
+ msg.append(data);
+ }
+ Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error,
+ Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}",
+ msg);
+ }
+ return ResultInternalError;
+}
+
+void OneTimeInit() {
+ ssl_ctx = SSL_CTX_new(TLS_client_method());
+ if (!ssl_ctx) {
+ LOG_ERROR(Service_SSL, "SSL_CTX_new failed");
+ CheckOpenSSLErrors();
+ return;
+ }
+
+ SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
+
+ if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) {
+ LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed");
+ CheckOpenSSLErrors();
+ return;
+ }
+
+ OneTimeInitLogFile();
+
+ if (!OneTimeInitBIO()) {
+ return;
+ }
+
+ one_time_init_success = true;
+}
+
+void OneTimeInitLogFile() {
+ const char* logfile = getenv("SSLKEYLOGFILE");
+ if (logfile) {
+ key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile,
+ FileShareFlag::ShareWriteOnly);
+ if (key_log_file.IsOpen()) {
+ SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback);
+ } else {
+ LOG_CRITICAL(Service_SSL,
+ "SSLKEYLOGFILE was set but file could not be opened; not logging keys!");
+ }
+ }
+}
+
+bool OneTimeInitBIO() {
+ bio_meth =
+ BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL");
+ if (!bio_meth ||
+ !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) ||
+ !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) ||
+ !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) {
+ LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD");
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
new file mode 100644
index 000000000..d8074339a
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp
@@ -0,0 +1,544 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+#include "common/error.h"
+#include "common/fs/file.h"
+#include "common/hex_util.h"
+#include "common/string_util.h"
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+namespace {
+
+// These includes are inside the namespace to avoid a conflict on MinGW where
+// the headers define an enum containing Network and Service as enumerators
+// (which clash with the correspondingly named namespaces).
+#define SECURITY_WIN32
+#include <schnlsp.h>
+#include <security.h>
+#include <wincrypt.h>
+
+std::once_flag one_time_init_flag;
+bool one_time_init_success = false;
+
+SCHANNEL_CRED schannel_cred{};
+CredHandle cred_handle;
+
+static void OneTimeInit() {
+ schannel_cred.dwVersion = SCHANNEL_CRED_VERSION;
+ schannel_cred.dwFlags =
+ SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
+ SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
+ SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate
+ // ^ I'm assuming that nobody would want to connect Yuzu to a
+ // service that requires some OS-provided corporate client
+ // certificate, and presenting one to some arbitrary server
+ // might be a privacy concern? Who knows, though.
+
+ const SECURITY_STATUS ret =
+ AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
+ nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
+ if (ret != SEC_E_OK) {
+ // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
+ LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}",
+ Common::NativeErrorToString(ret));
+ return;
+ }
+
+ if (getenv("SSLKEYLOGFILE")) {
+ LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting "
+ "keys; not logging keys!");
+ // Not fatal.
+ }
+
+ one_time_init_success = true;
+}
+
+} // namespace
+
+namespace Service::SSL {
+
+class SSLConnectionBackendSchannel final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ std::call_once(one_time_init_flag, OneTimeInit);
+
+ if (!one_time_init_success) {
+ LOG_ERROR(
+ Service_SSL,
+ "Can't create SSL connection because Schannel one-time initialization failed");
+ return ResultInternalError;
+ }
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override {
+ socket = std::move(socket_in);
+ }
+
+ Result SetHostName(const std::string& hostname_in) override {
+ hostname = hostname_in;
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ while (1) {
+ Result r;
+ switch (handshake_state) {
+ case HandshakeState::Initial:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
+ (r = CallInitializeSecurityContext()) != ResultSuccess) {
+ return r;
+ }
+ // CallInitializeSecurityContext updated `handshake_state`.
+ continue;
+ case HandshakeState::ContinueNeeded:
+ case HandshakeState::IncompleteMessage:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess ||
+ (r = FillCiphertextReadBuf()) != ResultSuccess) {
+ return r;
+ }
+ if (ciphertext_read_buf.empty()) {
+ LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up");
+ return ResultInternalError;
+ }
+ if ((r = CallInitializeSecurityContext()) != ResultSuccess) {
+ return r;
+ }
+ // CallInitializeSecurityContext updated `handshake_state`.
+ continue;
+ case HandshakeState::DoneAfterFlush:
+ if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) {
+ return r;
+ }
+ handshake_state = HandshakeState::Connected;
+ return ResultSuccess;
+ case HandshakeState::Connected:
+ LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook");
+ return ResultInternalError;
+ case HandshakeState::Error:
+ return ResultInternalError;
+ }
+ }
+ }
+
+ Result FillCiphertextReadBuf() {
+ const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096;
+ read_buf_fill_size = 0;
+ // This unnecessarily zeroes the buffer; oh well.
+ const size_t offset = ciphertext_read_buf.size();
+ ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
+ ciphertext_read_buf.resize(offset + fill_size, 0);
+ const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size);
+ const auto [actual, err] = socket->Recv(0, read_span);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ ASSERT(static_cast<size_t>(actual) <= fill_size);
+ ciphertext_read_buf.resize(offset + actual);
+ return ResultSuccess;
+ case Network::Errno::AGAIN:
+ ciphertext_read_buf.resize(offset);
+ return ResultWouldBlock;
+ default:
+ ciphertext_read_buf.resize(offset);
+ LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err);
+ return ResultInternalError;
+ }
+ }
+
+ // Returns success if the write buffer has been completely emptied.
+ Result FlushCiphertextWriteBuf() {
+ while (!ciphertext_write_buf.empty()) {
+ const auto [actual, err] = socket->Send(ciphertext_write_buf, 0);
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size());
+ ciphertext_write_buf.erase(ciphertext_write_buf.begin(),
+ ciphertext_write_buf.begin() + actual);
+ break;
+ case Network::Errno::AGAIN:
+ return ResultWouldBlock;
+ default:
+ LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err);
+ return ResultInternalError;
+ }
+ }
+ return ResultSuccess;
+ }
+
+ Result CallInitializeSecurityContext() {
+ const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
+ ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
+ ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
+ ISC_REQ_USE_SUPPLIED_CREDS;
+ unsigned long attr;
+ // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
+ std::array<SecBuffer, 2> input_buffers{{
+ // only used if `initial_call_done`
+ {
+ // [0]
+ .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
+ .BufferType = SECBUFFER_TOKEN,
+ .pvBuffer = ciphertext_read_buf.data(),
+ },
+ {
+ // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
+ // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
+ // whole buffer wasn't used)
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_EMPTY,
+ .pvBuffer = nullptr,
+ },
+ }};
+ std::array<SecBuffer, 2> output_buffers{{
+ {
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_TOKEN,
+ .pvBuffer = nullptr,
+ }, // [0]
+ {
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_ALERT,
+ .pvBuffer = nullptr,
+ }, // [1]
+ }};
+ SecBufferDesc input_desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(input_buffers.size()),
+ .pBuffers = input_buffers.data(),
+ };
+ SecBufferDesc output_desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(output_buffers.size()),
+ .pBuffers = output_buffers.data(),
+ };
+ ASSERT_OR_EXECUTE_MSG(
+ input_buffers[0].cbBuffer == ciphertext_read_buf.size(),
+ { return ResultInternalError; }, "read buffer too large");
+
+ bool initial_call_done = handshake_state != HandshakeState::Initial;
+ if (initial_call_done) {
+ LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext",
+ ciphertext_read_buf.size());
+ }
+
+ const SECURITY_STATUS ret =
+ InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr,
+ // Caller ensured we have set a hostname:
+ const_cast<char*>(hostname.value().c_str()), req,
+ 0, // Reserved1
+ 0, // TargetDataRep not used with Schannel
+ initial_call_done ? &input_desc : nullptr,
+ 0, // Reserved2
+ initial_call_done ? nullptr : &ctxt, &output_desc, &attr,
+ nullptr); // ptsExpiry
+
+ if (output_buffers[0].pvBuffer) {
+ const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
+ output_buffers[0].cbBuffer);
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end());
+ FreeContextBuffer(output_buffers[0].pvBuffer);
+ }
+
+ if (output_buffers[1].pvBuffer) {
+ const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
+ output_buffers[1].cbBuffer);
+ // The documentation doesn't explain what format this data is in.
+ LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
+ Common::HexToString(span));
+ }
+
+ switch (ret) {
+ case SEC_I_CONTINUE_NEEDED:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED");
+ if (input_buffers[1].BufferType == SECBUFFER_EXTRA) {
+ LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer);
+ ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size());
+ ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
+ ciphertext_read_buf.end() - input_buffers[1].cbBuffer);
+ } else {
+ ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY);
+ ciphertext_read_buf.clear();
+ }
+ handshake_state = HandshakeState::ContinueNeeded;
+ return ResultSuccess;
+ case SEC_E_INCOMPLETE_MESSAGE:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE");
+ ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING);
+ read_buf_fill_size = input_buffers[1].cbBuffer;
+ handshake_state = HandshakeState::IncompleteMessage;
+ return ResultSuccess;
+ case SEC_E_OK:
+ LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK");
+ ciphertext_read_buf.clear();
+ handshake_state = HandshakeState::DoneAfterFlush;
+ return GrabStreamSizes();
+ default:
+ LOG_ERROR(Service_SSL,
+ "InitializeSecurityContext failed (probably certificate/protocol issue): {}",
+ Common::NativeErrorToString(ret));
+ handshake_state = HandshakeState::Error;
+ return ResultInternalError;
+ }
+ }
+
+ Result GrabStreamSizes() {
+ const SECURITY_STATUS ret =
+ QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
+ Common::NativeErrorToString(ret));
+ handshake_state = HandshakeState::Error;
+ return ResultInternalError;
+ }
+ return ResultSuccess;
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ if (handshake_state != HandshakeState::Connected) {
+ LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake");
+ return ResultInternalError;
+ }
+ if (data.size() == 0 || got_read_eof) {
+ return size_t(0);
+ }
+ while (1) {
+ if (!cleartext_read_buf.empty()) {
+ const size_t read_size = std::min(cleartext_read_buf.size(), data.size());
+ std::memcpy(data.data(), cleartext_read_buf.data(), read_size);
+ cleartext_read_buf.erase(cleartext_read_buf.begin(),
+ cleartext_read_buf.begin() + read_size);
+ return read_size;
+ }
+ if (!ciphertext_read_buf.empty()) {
+ SecBuffer empty{
+ .cbBuffer = 0,
+ .BufferType = SECBUFFER_EMPTY,
+ .pvBuffer = nullptr,
+ };
+ std::array<SecBuffer, 5> buffers{{
+ {
+ .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()),
+ .BufferType = SECBUFFER_DATA,
+ .pvBuffer = ciphertext_read_buf.data(),
+ },
+ empty,
+ empty,
+ empty,
+ }};
+ ASSERT_OR_EXECUTE_MSG(
+ buffers[0].cbBuffer == ciphertext_read_buf.size(),
+ { return ResultInternalError; }, "read buffer too large");
+ SecBufferDesc desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(buffers.size()),
+ .pBuffers = buffers.data(),
+ };
+ SECURITY_STATUS ret =
+ DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr);
+ switch (ret) {
+ case SEC_E_OK:
+ ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER,
+ { return ResultInternalError; });
+ ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA,
+ { return ResultInternalError; });
+ ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER,
+ { return ResultInternalError; });
+ cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer),
+ static_cast<u8*>(buffers[1].pvBuffer) +
+ buffers[1].cbBuffer);
+ if (buffers[3].BufferType == SECBUFFER_EXTRA) {
+ ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size());
+ ciphertext_read_buf.erase(ciphertext_read_buf.begin(),
+ ciphertext_read_buf.end() - buffers[3].cbBuffer);
+ } else {
+ ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY);
+ ciphertext_read_buf.clear();
+ }
+ continue;
+ case SEC_E_INCOMPLETE_MESSAGE:
+ break;
+ case SEC_I_CONTEXT_EXPIRED:
+ // Server hung up by sending close_notify.
+ got_read_eof = true;
+ return size_t(0);
+ default:
+ LOG_ERROR(Service_SSL, "DecryptMessage failed: {}",
+ Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ }
+ const Result r = FillCiphertextReadBuf();
+ if (r != ResultSuccess) {
+ return r;
+ }
+ if (ciphertext_read_buf.empty()) {
+ got_read_eof = true;
+ return size_t(0);
+ }
+ }
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ if (handshake_state != HandshakeState::Connected) {
+ LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake");
+ return ResultInternalError;
+ }
+ if (data.size() == 0) {
+ return size_t(0);
+ }
+ data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage));
+ if (!cleartext_write_buf.empty()) {
+ // Already in the middle of a write. It wouldn't make sense to not
+ // finish sending the entire buffer since TLS has
+ // header/MAC/padding/etc.
+ if (data.size() != cleartext_write_buf.size() ||
+ std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) {
+ LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer");
+ return ResultInternalError;
+ }
+ return WriteAlreadyEncryptedData();
+ } else {
+ cleartext_write_buf.assign(data.begin(), data.end());
+ }
+
+ std::vector<u8> header_buf(stream_sizes.cbHeader, 0);
+ std::vector<u8> tmp_data_buf = cleartext_write_buf;
+ std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0);
+
+ std::array<SecBuffer, 3> buffers{{
+ {
+ .cbBuffer = stream_sizes.cbHeader,
+ .BufferType = SECBUFFER_STREAM_HEADER,
+ .pvBuffer = header_buf.data(),
+ },
+ {
+ .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()),
+ .BufferType = SECBUFFER_DATA,
+ .pvBuffer = tmp_data_buf.data(),
+ },
+ {
+ .cbBuffer = stream_sizes.cbTrailer,
+ .BufferType = SECBUFFER_STREAM_TRAILER,
+ .pvBuffer = trailer_buf.data(),
+ },
+ }};
+ ASSERT_OR_EXECUTE_MSG(
+ buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; },
+ "temp buffer too large");
+ SecBufferDesc desc{
+ .ulVersion = SECBUFFER_VERSION,
+ .cBuffers = static_cast<unsigned long>(buffers.size()),
+ .pBuffers = buffers.data(),
+ };
+
+ const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(),
+ header_buf.end());
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(),
+ tmp_data_buf.end());
+ ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(),
+ trailer_buf.end());
+ return WriteAlreadyEncryptedData();
+ }
+
+ ResultVal<size_t> WriteAlreadyEncryptedData() {
+ const Result r = FlushCiphertextWriteBuf();
+ if (r != ResultSuccess) {
+ return r;
+ }
+ // write buf is empty
+ const size_t cleartext_bytes_written = cleartext_write_buf.size();
+ cleartext_write_buf.clear();
+ return cleartext_bytes_written;
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ PCCERT_CONTEXT returned_cert = nullptr;
+ const SECURITY_STATUS ret =
+ QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
+ if (ret != SEC_E_OK) {
+ LOG_ERROR(Service_SSL,
+ "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}",
+ Common::NativeErrorToString(ret));
+ return ResultInternalError;
+ }
+ PCCERT_CONTEXT some_cert = nullptr;
+ std::vector<std::vector<u8>> certs;
+ while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) {
+ certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded),
+ static_cast<u8*>(some_cert->pbCertEncoded) +
+ some_cert->cbCertEncoded);
+ }
+ std::reverse(certs.begin(),
+ certs.end()); // Windows returns certs in reverse order from what we want
+ CertFreeCertificateContext(returned_cert);
+ return certs;
+ }
+
+ ~SSLConnectionBackendSchannel() {
+ if (handshake_state != HandshakeState::Initial) {
+ DeleteSecurityContext(&ctxt);
+ }
+ }
+
+ enum class HandshakeState {
+ // Haven't called anything yet.
+ Initial,
+ // `SEC_I_CONTINUE_NEEDED` was returned by
+ // `InitializeSecurityContext`; must finish sending data (if any) in
+ // the write buffer, then read at least one byte before calling
+ // `InitializeSecurityContext` again.
+ ContinueNeeded,
+ // `SEC_E_INCOMPLETE_MESSAGE` was returned by
+ // `InitializeSecurityContext`; hopefully the write buffer is empty;
+ // must read at least one byte before calling
+ // `InitializeSecurityContext` again.
+ IncompleteMessage,
+ // `SEC_E_OK` was returned by `InitializeSecurityContext`; must
+ // finish sending data in the write buffer before having `DoHandshake`
+ // report success.
+ DoneAfterFlush,
+ // We finished the above and are now connected. At this point, writing
+ // and reading are separate 'state machines' represented by the
+ // nonemptiness of the ciphertext and cleartext read and write buffers.
+ Connected,
+ // Another error was returned and we shouldn't allow initialization
+ // to continue.
+ Error,
+ } handshake_state = HandshakeState::Initial;
+
+ CtxtHandle ctxt;
+ SecPkgContext_StreamSizes stream_sizes;
+
+ std::shared_ptr<Network::SocketBase> socket;
+ std::optional<std::string> hostname;
+
+ std::vector<u8> ciphertext_read_buf;
+ std::vector<u8> ciphertext_write_buf;
+ std::vector<u8> cleartext_read_buf;
+ std::vector<u8> cleartext_write_buf;
+
+ bool got_read_eof = false;
+ size_t read_buf_fill_size = 0;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendSchannel>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/hle/service/ssl/ssl_backend_securetransport.cpp b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
new file mode 100644
index 000000000..b3083cbad
--- /dev/null
+++ b/src/core/hle/service/ssl/ssl_backend_securetransport.cpp
@@ -0,0 +1,222 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <mutex>
+
+// SecureTransport has been deprecated in its entirety in favor of
+// Network.framework, but that does not allow layering TLS on top of an
+// arbitrary socket.
+#if defined(__GNUC__) || defined(__clang__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#include <Security/SecureTransport.h>
+#pragma GCC diagnostic pop
+#endif
+
+#include "core/hle/service/ssl/ssl_backend.h"
+#include "core/internal_network/network.h"
+#include "core/internal_network/sockets.h"
+
+namespace {
+
+template <typename T>
+struct CFReleaser {
+ T ptr;
+
+ YUZU_NON_COPYABLE(CFReleaser);
+ constexpr CFReleaser() : ptr(nullptr) {}
+ constexpr CFReleaser(T ptr) : ptr(ptr) {}
+ constexpr operator T() {
+ return ptr;
+ }
+ ~CFReleaser() {
+ if (ptr) {
+ CFRelease(ptr);
+ }
+ }
+};
+
+std::string CFStringToString(CFStringRef cfstr) {
+ CFReleaser<CFDataRef> cfdata(
+ CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0));
+ ASSERT_OR_EXECUTE(cfdata, { return "???"; });
+ return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)),
+ CFDataGetLength(cfdata));
+}
+
+std::string OSStatusToString(OSStatus status) {
+ CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr));
+ if (!cfstr) {
+ return "[unknown error]";
+ }
+ return CFStringToString(cfstr);
+}
+
+} // namespace
+
+namespace Service::SSL {
+
+class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend {
+public:
+ Result Init() {
+ static std::once_flag once_flag;
+ std::call_once(once_flag, []() {
+ if (getenv("SSLKEYLOGFILE")) {
+ LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not "
+ "support exporting keys; not logging keys!");
+ // Not fatal.
+ }
+ });
+
+ context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType);
+ if (!context) {
+ LOG_ERROR(Service_SSL, "SSLCreateContext failed");
+ return ResultInternalError;
+ }
+
+ OSStatus status;
+ if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) ||
+ (status = SSLSetConnection(context, this))) {
+ LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}",
+ OSStatusToString(status));
+ return ResultInternalError;
+ }
+
+ return ResultSuccess;
+ }
+
+ void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override {
+ socket = std::move(in_socket);
+ }
+
+ Result SetHostName(const std::string& hostname) override {
+ OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size());
+ if (status) {
+ LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status));
+ return ResultInternalError;
+ }
+ return ResultSuccess;
+ }
+
+ Result DoHandshake() override {
+ OSStatus status = SSLHandshake(context);
+ return HandleReturn("SSLHandshake", 0, status).Code();
+ }
+
+ ResultVal<size_t> Read(std::span<u8> data) override {
+ size_t actual;
+ OSStatus status = SSLRead(context, data.data(), data.size(), &actual);
+ ;
+ return HandleReturn("SSLRead", actual, status);
+ }
+
+ ResultVal<size_t> Write(std::span<const u8> data) override {
+ size_t actual;
+ OSStatus status = SSLWrite(context, data.data(), data.size(), &actual);
+ ;
+ return HandleReturn("SSLWrite", actual, status);
+ }
+
+ ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) {
+ switch (status) {
+ case 0:
+ return actual;
+ case errSSLWouldBlock:
+ return ResultWouldBlock;
+ default: {
+ std::string reason;
+ if (got_read_eof) {
+ reason = "server hung up";
+ } else {
+ reason = OSStatusToString(status);
+ }
+ LOG_ERROR(Service_SSL, "{} failed: {}", what, reason);
+ return ResultInternalError;
+ }
+ }
+ }
+
+ ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
+ CFReleaser<SecTrustRef> trust;
+ OSStatus status = SSLCopyPeerTrust(context, &trust.ptr);
+ if (status) {
+ LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status));
+ return ResultInternalError;
+ }
+ std::vector<std::vector<u8>> ret;
+ for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) {
+ SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i);
+ CFReleaser<CFDataRef> data(SecCertificateCopyData(cert));
+ ASSERT_OR_EXECUTE(data, { return ResultInternalError; });
+ const u8* ptr = CFDataGetBytePtr(data);
+ ret.emplace_back(ptr, ptr + CFDataGetLength(data));
+ }
+ return ret;
+ }
+
+ static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) {
+ return ReadOrWriteCallback(connection, data, dataLength, true);
+ }
+
+ static OSStatus WriteCallback(SSLConnectionRef connection, const void* data,
+ size_t* dataLength) {
+ return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false);
+ }
+
+ static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength,
+ bool is_read) {
+ auto self =
+ static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection));
+ ASSERT_OR_EXECUTE_MSG(
+ self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket",
+ is_read ? "read" : "write");
+
+ // SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
+ // expected to read/write the full requested dataLength or return an
+ // error, so we have to add a loop ourselves.
+ size_t requested_len = *dataLength;
+ size_t offset = 0;
+ while (offset < requested_len) {
+ std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset);
+ auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0);
+ LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset,
+ actual, cur.size(), static_cast<s32>(err));
+ switch (err) {
+ case Network::Errno::SUCCESS:
+ offset += actual;
+ if (actual == 0) {
+ ASSERT(is_read);
+ self->got_read_eof = true;
+ return errSecEndOfData;
+ }
+ break;
+ case Network::Errno::AGAIN:
+ *dataLength = offset;
+ return errSSLWouldBlock;
+ default:
+ LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}",
+ is_read ? "recv" : "send", err);
+ return errSecIO;
+ }
+ }
+ ASSERT(offset == requested_len);
+ return 0;
+ }
+
+private:
+ CFReleaser<SSLContextRef> context = nullptr;
+ bool got_read_eof = false;
+
+ std::shared_ptr<Network::SocketBase> socket;
+};
+
+ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
+ auto conn = std::make_unique<SSLConnectionBackendSecureTransport>();
+ const Result res = conn->Init();
+ if (res.IsFailure()) {
+ return res;
+ }
+ return conn;
+}
+
+} // namespace Service::SSL
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp
index 75ac10a9c..28f89c599 100644
--- a/src/core/internal_network/network.cpp
+++ b/src/core/internal_network/network.cpp
@@ -27,6 +27,7 @@
#include "common/assert.h"
#include "common/common_types.h"
+#include "common/expected.h"
#include "common/logging/log.h"
#include "common/settings.h"
#include "core/internal_network/network.h"
@@ -97,6 +98,8 @@ bool EnableNonBlock(SOCKET fd, bool enable) {
Errno TranslateNativeError(int e) {
switch (e) {
+ case 0:
+ return Errno::SUCCESS;
case WSAEBADF:
return Errno::BADF;
case WSAEINVAL:
@@ -121,6 +124,8 @@ Errno TranslateNativeError(int e) {
return Errno::MSGSIZE;
case WSAETIMEDOUT:
return Errno::TIMEDOUT;
+ case WSAEINPROGRESS:
+ return Errno::INPROGRESS;
default:
UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
return Errno::OTHER;
@@ -195,6 +200,8 @@ bool EnableNonBlock(int fd, bool enable) {
Errno TranslateNativeError(int e) {
switch (e) {
+ case 0:
+ return Errno::SUCCESS;
case EBADF:
return Errno::BADF;
case EINVAL:
@@ -219,8 +226,10 @@ Errno TranslateNativeError(int e) {
return Errno::MSGSIZE;
case ETIMEDOUT:
return Errno::TIMEDOUT;
+ case EINPROGRESS:
+ return Errno::INPROGRESS;
default:
- UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
+ UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e));
return Errno::OTHER;
}
}
@@ -234,15 +243,84 @@ Errno GetAndLogLastError() {
int e = errno;
#endif
const Errno err = TranslateNativeError(e);
- if (err == Errno::AGAIN || err == Errno::TIMEDOUT) {
+ if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) {
+ // These happen during normal operation, so only log them at debug level.
+ LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
return err;
}
LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
return err;
}
-int TranslateDomain(Domain domain) {
+GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) {
+ switch (gai_err) {
+ case 0:
+ return GetAddrInfoError::SUCCESS;
+#ifdef EAI_ADDRFAMILY
+ case EAI_ADDRFAMILY:
+ return GetAddrInfoError::ADDRFAMILY;
+#endif
+ case EAI_AGAIN:
+ return GetAddrInfoError::AGAIN;
+ case EAI_BADFLAGS:
+ return GetAddrInfoError::BADFLAGS;
+ case EAI_FAIL:
+ return GetAddrInfoError::FAIL;
+ case EAI_FAMILY:
+ return GetAddrInfoError::FAMILY;
+ case EAI_MEMORY:
+ return GetAddrInfoError::MEMORY;
+ case EAI_NONAME:
+ return GetAddrInfoError::NONAME;
+ case EAI_SERVICE:
+ return GetAddrInfoError::SERVICE;
+ case EAI_SOCKTYPE:
+ return GetAddrInfoError::SOCKTYPE;
+ // These codes may not be defined on all systems:
+#ifdef EAI_SYSTEM
+ case EAI_SYSTEM:
+ return GetAddrInfoError::SYSTEM;
+#endif
+#ifdef EAI_BADHINTS
+ case EAI_BADHINTS:
+ return GetAddrInfoError::BADHINTS;
+#endif
+#ifdef EAI_PROTOCOL
+ case EAI_PROTOCOL:
+ return GetAddrInfoError::PROTOCOL;
+#endif
+#ifdef EAI_OVERFLOW
+ case EAI_OVERFLOW:
+ return GetAddrInfoError::OVERFLOW_;
+#endif
+ default:
+#ifdef EAI_NODATA
+ // This can't be a case statement because it would create a duplicate
+ // case on Windows where EAI_NODATA is an alias for EAI_NONAME.
+ if (gai_err == EAI_NODATA) {
+ return GetAddrInfoError::NODATA;
+ }
+#endif
+ return GetAddrInfoError::OTHER;
+ }
+}
+
+Domain TranslateDomainFromNative(int domain) {
+ switch (domain) {
+ case 0:
+ return Domain::Unspecified;
+ case AF_INET:
+ return Domain::INET;
+ default:
+ UNIMPLEMENTED_MSG("Unhandled domain={}", domain);
+ return Domain::INET;
+ }
+}
+
+int TranslateDomainToNative(Domain domain) {
switch (domain) {
+ case Domain::Unspecified:
+ return 0;
case Domain::INET:
return AF_INET;
default:
@@ -251,20 +329,58 @@ int TranslateDomain(Domain domain) {
}
}
-int TranslateType(Type type) {
+Type TranslateTypeFromNative(int type) {
+ switch (type) {
+ case 0:
+ return Type::Unspecified;
+ case SOCK_STREAM:
+ return Type::STREAM;
+ case SOCK_DGRAM:
+ return Type::DGRAM;
+ case SOCK_RAW:
+ return Type::RAW;
+ case SOCK_SEQPACKET:
+ return Type::SEQPACKET;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented type={}", type);
+ return Type::STREAM;
+ }
+}
+
+int TranslateTypeToNative(Type type) {
switch (type) {
+ case Type::Unspecified:
+ return 0;
case Type::STREAM:
return SOCK_STREAM;
case Type::DGRAM:
return SOCK_DGRAM;
+ case Type::RAW:
+ return SOCK_RAW;
default:
UNIMPLEMENTED_MSG("Unimplemented type={}", type);
return 0;
}
}
-int TranslateProtocol(Protocol protocol) {
+Protocol TranslateProtocolFromNative(int protocol) {
+ switch (protocol) {
+ case 0:
+ return Protocol::Unspecified;
+ case IPPROTO_TCP:
+ return Protocol::TCP;
+ case IPPROTO_UDP:
+ return Protocol::UDP;
+ default:
+ UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
+ return Protocol::Unspecified;
+ }
+}
+
+int TranslateProtocolToNative(Protocol protocol) {
switch (protocol) {
+ case Protocol::Unspecified:
+ return 0;
case Protocol::TCP:
return IPPROTO_TCP;
case Protocol::UDP:
@@ -275,21 +391,10 @@ int TranslateProtocol(Protocol protocol) {
}
}
-SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
- sockaddr_in input;
- std::memcpy(&input, &input_, sizeof(input));
-
+SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) {
SockAddrIn result;
- switch (input.sin_family) {
- case AF_INET:
- result.family = Domain::INET;
- break;
- default:
- UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family);
- result.family = Domain::INET;
- break;
- }
+ result.family = TranslateDomainFromNative(input.sin_family);
result.portno = ntohs(input.sin_port);
@@ -301,22 +406,33 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
short TranslatePollEvents(PollEvents events) {
short result = 0;
- if (True(events & PollEvents::In)) {
- events &= ~PollEvents::In;
- result |= POLLIN;
- }
- if (True(events & PollEvents::Pri)) {
- events &= ~PollEvents::Pri;
+ const auto translate = [&result, &events](PollEvents guest, short host) {
+ if (True(events & guest)) {
+ events &= ~guest;
+ result |= host;
+ }
+ };
+
+ translate(PollEvents::In, POLLIN);
+ translate(PollEvents::Pri, POLLPRI);
+ translate(PollEvents::Out, POLLOUT);
+ translate(PollEvents::Err, POLLERR);
+ translate(PollEvents::Hup, POLLHUP);
+ translate(PollEvents::Nval, POLLNVAL);
+ translate(PollEvents::RdNorm, POLLRDNORM);
+ translate(PollEvents::RdBand, POLLRDBAND);
+ translate(PollEvents::WrBand, POLLWRBAND);
+
#ifdef _WIN32
- LOG_WARNING(Service, "Winsock doesn't support POLLPRI");
-#else
- result |= POLLPRI;
+ short allowed_events = POLLRDBAND | POLLRDNORM | POLLWRNORM;
+ // Unlike poll on other OSes, WSAPoll will complain if any other flags are set on input.
+ if (result & ~allowed_events) {
+ LOG_DEBUG(Network,
+ "Removing WSAPoll input events 0x{:x} because Windows doesn't support them",
+ result & ~allowed_events);
+ }
+ result &= allowed_events;
#endif
- }
- if (True(events & PollEvents::Out)) {
- events &= ~PollEvents::Out;
- result |= POLLOUT;
- }
UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events);
@@ -337,6 +453,10 @@ PollEvents TranslatePollRevents(short revents) {
translate(POLLOUT, PollEvents::Out);
translate(POLLERR, PollEvents::Err);
translate(POLLHUP, PollEvents::Hup);
+ translate(POLLNVAL, PollEvents::Nval);
+ translate(POLLRDNORM, PollEvents::RdNorm);
+ translate(POLLRDBAND, PollEvents::RdBand);
+ translate(POLLWRBAND, PollEvents::WrBand);
UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents);
@@ -360,12 +480,51 @@ std::optional<IPv4Address> GetHostIPv4Address() {
return {};
}
- std::array<char, 16> ip_addr = {};
- ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) !=
- nullptr);
return TranslateIPv4(network_interface->ip_address);
}
+std::string IPv4AddressToString(IPv4Address ip_addr) {
+ std::array<char, INET_ADDRSTRLEN> buf = {};
+ ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data());
+ return std::string(buf.data());
+}
+
+u32 IPv4AddressToInteger(IPv4Address ip_addr) {
+ return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 |
+ static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]);
+}
+
+Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
+ const std::string& host, const std::optional<std::string>& service) {
+ addrinfo hints{};
+ hints.ai_family = AF_INET; // Switch only supports IPv4.
+ addrinfo* addrinfo;
+ s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr,
+ &hints, &addrinfo);
+ if (gai_err != 0) {
+ return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err));
+ }
+ std::vector<AddrInfo> ret;
+ for (auto* current = addrinfo; current; current = current->ai_next) {
+ // We should only get AF_INET results due to the hints value.
+ ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET &&
+ addrinfo->ai_addrlen == sizeof(sockaddr_in),
+ continue;);
+
+ AddrInfo& out = ret.emplace_back();
+ out.family = TranslateDomainFromNative(current->ai_family);
+ out.socket_type = TranslateTypeFromNative(current->ai_socktype);
+ out.protocol = TranslateProtocolFromNative(current->ai_protocol);
+ out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr),
+ current->ai_addrlen);
+ if (current->ai_canonname != nullptr) {
+ out.canon_name = current->ai_canonname;
+ }
+ }
+ freeaddrinfo(addrinfo);
+ return ret;
+}
+
std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) {
const size_t num = pollfds.size();
@@ -411,9 +570,21 @@ Socket::Socket(Socket&& rhs) noexcept {
}
template <typename T>
-Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
+std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_so, int option) {
+ T value{};
+ socklen_t len = sizeof(value);
+ const int result = getsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
+ if (result != SOCKET_ERROR) {
+ ASSERT(len == sizeof(value));
+ return {value, Errno::SUCCESS};
+ }
+ return {value, GetAndLogLastError()};
+}
+
+template <typename T>
+Errno Socket::SetSockOpt(SOCKET fd_so, int option, T value) {
const int result =
- setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
+ setsockopt(fd_so, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
if (result != SOCKET_ERROR) {
return Errno::SUCCESS;
}
@@ -421,7 +592,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
}
Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
- fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol));
+ fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type),
+ TranslateProtocolToNative(protocol));
if (fd != INVALID_SOCKET) {
return Errno::SUCCESS;
}
@@ -430,19 +602,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
}
std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() {
- sockaddr addr;
+ sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
- const SOCKET new_socket = accept(fd, &addr, &addrlen);
+ const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
if (new_socket == INVALID_SOCKET) {
return {AcceptResult{}, GetAndLogLastError()};
}
- ASSERT(addrlen == sizeof(sockaddr_in));
-
AcceptResult result{
.socket = std::make_unique<Socket>(new_socket),
- .sockaddr_in = TranslateToSockAddrIn(addr),
+ .sockaddr_in = TranslateToSockAddrIn(addr, addrlen),
};
return {std::move(result), Errno::SUCCESS};
@@ -458,25 +628,23 @@ Errno Socket::Connect(SockAddrIn addr_in) {
}
std::pair<SockAddrIn, Errno> Socket::GetPeerName() {
- sockaddr addr;
+ sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
- if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) {
+ if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
return {SockAddrIn{}, GetAndLogLastError()};
}
- ASSERT(addrlen == sizeof(sockaddr_in));
- return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
+ return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
}
std::pair<SockAddrIn, Errno> Socket::GetSockName() {
- sockaddr addr;
+ sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
- if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) {
+ if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
return {SockAddrIn{}, GetAndLogLastError()};
}
- ASSERT(addrlen == sizeof(sockaddr_in));
- return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
+ return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
}
Errno Socket::Bind(SockAddrIn addr) {
@@ -519,7 +687,7 @@ Errno Socket::Shutdown(ShutdownHow how) {
return GetAndLogLastError();
}
-std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
+std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) {
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -532,21 +700,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
return {-1, GetAndLogLastError()};
}
-std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) {
+std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
- sockaddr addr_in{};
+ sockaddr_in addr_in{};
socklen_t addrlen = sizeof(addr_in);
socklen_t* const p_addrlen = addr ? &addrlen : nullptr;
- sockaddr* const p_addr_in = addr ? &addr_in : nullptr;
+ sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr;
const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()),
static_cast<int>(message.size()), 0, p_addr_in, p_addrlen);
if (result != SOCKET_ERROR) {
if (addr) {
- ASSERT(addrlen == sizeof(addr_in));
- *addr = TranslateToSockAddrIn(addr_in);
+ *addr = TranslateToSockAddrIn(addr_in, addrlen);
}
return {static_cast<s32>(result), Errno::SUCCESS};
}
@@ -597,6 +764,11 @@ Errno Socket::Close() {
return Errno::SUCCESS;
}
+std::pair<Errno, Errno> Socket::GetPendingError() {
+ auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR);
+ return {TranslateNativeError(pending_err), getsockopt_err};
+}
+
Errno Socket::SetLinger(bool enable, u32 linger) {
return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger));
}
diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h
index 1e09a007a..badcb8369 100644
--- a/src/core/internal_network/network.h
+++ b/src/core/internal_network/network.h
@@ -5,6 +5,7 @@
#include <array>
#include <optional>
+#include <vector>
#include "common/common_funcs.h"
#include "common/common_types.h"
@@ -16,6 +17,11 @@
#include <netinet/in.h>
#endif
+namespace Common {
+template <typename T, typename E>
+class Expected;
+}
+
namespace Network {
class SocketBase;
@@ -36,6 +42,26 @@ enum class Errno {
NETUNREACH,
TIMEDOUT,
MSGSIZE,
+ INPROGRESS,
+ OTHER,
+};
+
+enum class GetAddrInfoError {
+ SUCCESS,
+ ADDRFAMILY,
+ AGAIN,
+ BADFLAGS,
+ FAIL,
+ FAMILY,
+ MEMORY,
+ NODATA,
+ NONAME,
+ SERVICE,
+ SOCKTYPE,
+ SYSTEM,
+ BADHINTS,
+ PROTOCOL,
+ OVERFLOW_,
OTHER,
};
@@ -49,6 +75,9 @@ enum class PollEvents : u16 {
Err = 1 << 3,
Hup = 1 << 4,
Nval = 1 << 5,
+ RdNorm = 1 << 6,
+ RdBand = 1 << 7,
+ WrBand = 1 << 8,
};
DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
@@ -82,4 +111,11 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) {
/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array
std::optional<IPv4Address> GetHostIPv4Address();
+std::string IPv4AddressToString(IPv4Address ip_addr);
+u32 IPv4AddressToInteger(IPv4Address ip_addr);
+
+// named to avoid name collision with Windows macro
+Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddressInfo(
+ const std::string& host, const std::optional<std::string>& service);
+
} // namespace Network
diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp
index 7a77171c2..ce0dee970 100644
--- a/src/core/internal_network/socket_proxy.cpp
+++ b/src/core/internal_network/socket_proxy.cpp
@@ -10,6 +10,7 @@
#include "core/internal_network/network.h"
#include "core/internal_network/network_interface.h"
#include "core/internal_network/socket_proxy.h"
+#include "network/network.h"
#if YUZU_UNIX
#include <sys/socket.h>
@@ -98,7 +99,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) {
return Errno::SUCCESS;
}
-std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
+std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) {
LOG_WARNING(Network, "(STUBBED) called");
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -106,7 +107,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
return {static_cast<s32>(0), Errno::SUCCESS};
}
-std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) {
+std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
ASSERT(flags == 0);
ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -140,8 +141,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message,
}
}
-std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message,
- SockAddrIn* addr, std::size_t max_length) {
+std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
+ std::size_t max_length) {
ProxyPacket& packet = received_packets.front();
if (addr) {
addr->family = Domain::INET;
@@ -153,10 +154,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
std::size_t read_bytes;
if (packet.data.size() > max_length) {
read_bytes = max_length;
- message.clear();
- std::copy(packet.data.begin(), packet.data.begin() + read_bytes,
- std::back_inserter(message));
- message.resize(max_length);
+ memcpy(message.data(), packet.data.data(), max_length);
if (protocol == Protocol::UDP) {
if (!peek) {
@@ -171,9 +169,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
}
} else {
read_bytes = packet.data.size();
- message.clear();
- std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message));
- message.resize(max_length);
+ memcpy(message.data(), packet.data.data(), read_bytes);
if (!peek) {
received_packets.pop();
}
@@ -293,6 +289,11 @@ Errno ProxySocket::SetNonBlock(bool enable) {
return Errno::SUCCESS;
}
+std::pair<Errno, Errno> ProxySocket::GetPendingError() {
+ LOG_DEBUG(Network, "(STUBBED) called");
+ return {Errno::SUCCESS, Errno::SUCCESS};
+}
+
bool ProxySocket::IsOpened() const {
return fd != INVALID_SOCKET;
}
diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h
index 6e991fa38..70500cf4a 100644
--- a/src/core/internal_network/socket_proxy.h
+++ b/src/core/internal_network/socket_proxy.h
@@ -10,10 +10,12 @@
#include "common/common_funcs.h"
#include "core/internal_network/sockets.h"
-#include "network/network.h"
+#include "network/room_member.h"
namespace Network {
+class RoomNetwork;
+
class ProxySocket : public SocketBase {
public:
explicit ProxySocket(RoomNetwork& room_network_) noexcept;
@@ -39,11 +41,11 @@ public:
Errno Shutdown(ShutdownHow how) override;
- std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override;
+ std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
- std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override;
+ std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
- std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr,
+ std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
std::size_t max_length);
std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
@@ -74,6 +76,8 @@ public:
template <typename T>
Errno SetSockOpt(SOCKET fd, int option, T value);
+ std::pair<Errno, Errno> GetPendingError() override;
+
bool IsOpened() const override;
private:
diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h
index 11e479e50..4ba51f62c 100644
--- a/src/core/internal_network/sockets.h
+++ b/src/core/internal_network/sockets.h
@@ -15,12 +15,13 @@
#include "common/common_types.h"
#include "core/internal_network/network.h"
-#include "network/network.h"
// TODO: C++20 Replace std::vector usages with std::span
namespace Network {
+struct ProxyPacket;
+
class SocketBase {
public:
#ifdef YUZU_UNIX
@@ -59,10 +60,9 @@ public:
virtual Errno Shutdown(ShutdownHow how) = 0;
- virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0;
+ virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0;
- virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message,
- SockAddrIn* addr) = 0;
+ virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0;
virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0;
@@ -87,6 +87,8 @@ public:
virtual Errno SetNonBlock(bool enable) = 0;
+ virtual std::pair<Errno, Errno> GetPendingError() = 0;
+
virtual bool IsOpened() const = 0;
virtual void HandleProxyPacket(const ProxyPacket& packet) = 0;
@@ -126,9 +128,9 @@ public:
Errno Shutdown(ShutdownHow how) override;
- std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override;
+ std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
- std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override;
+ std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
@@ -156,6 +158,11 @@ public:
template <typename T>
Errno SetSockOpt(SOCKET fd, int option, T value);
+ std::pair<Errno, Errno> GetPendingError() override;
+
+ template <typename T>
+ std::pair<T, Errno> GetSockOpt(SOCKET fd, int option);
+
bool IsOpened() const override;
void HandleProxyPacket(const ProxyPacket& packet) override;
diff --git a/src/core/loader/deconstructed_rom_directory.cpp b/src/core/loader/deconstructed_rom_directory.cpp
index 3be9b71cf..e04ad19db 100644
--- a/src/core/loader/deconstructed_rom_directory.cpp
+++ b/src/core/loader/deconstructed_rom_directory.cpp
@@ -153,7 +153,7 @@ AppLoader_DeconstructedRomDirectory::LoadResult AppLoader_DeconstructedRomDirect
// Load NSO modules
modules.clear();
- const VAddr base_address{GetInteger(process.PageTable().GetCodeRegionStart())};
+ const VAddr base_address{GetInteger(process.GetPageTable().GetCodeRegionStart())};
VAddr next_load_addr{base_address};
const FileSys::PatchManager pm{metadata.GetTitleID(), system.GetFileSystemController(),
system.GetContentProvider()};
diff --git a/src/core/loader/kip.cpp b/src/core/loader/kip.cpp
index 709e2564f..ffe976b94 100644
--- a/src/core/loader/kip.cpp
+++ b/src/core/loader/kip.cpp
@@ -96,7 +96,7 @@ AppLoader::LoadResult AppLoader_KIP::Load(Kernel::KProcess& process,
}
codeset.memory = std::move(program_image);
- const VAddr base_address = GetInteger(process.PageTable().GetCodeRegionStart());
+ const VAddr base_address = GetInteger(process.GetPageTable().GetCodeRegionStart());
process.LoadModule(std::move(codeset), base_address);
LOG_DEBUG(Loader, "loaded module {} @ 0x{:X}", kip->GetName(), base_address);
diff --git a/src/core/loader/nro.cpp b/src/core/loader/nro.cpp
index 7be6cf5f3..506808b5d 100644
--- a/src/core/loader/nro.cpp
+++ b/src/core/loader/nro.cpp
@@ -203,7 +203,7 @@ static bool LoadNroImpl(Kernel::KProcess& process, const std::vector<u8>& data)
// Load codeset for current process
codeset.memory = std::move(program_image);
- process.LoadModule(std::move(codeset), process.PageTable().GetCodeRegionStart());
+ process.LoadModule(std::move(codeset), process.GetPageTable().GetCodeRegionStart());
return true;
}
diff --git a/src/core/loader/nso.cpp b/src/core/loader/nso.cpp
index 79639f5e4..74cc9579f 100644
--- a/src/core/loader/nso.cpp
+++ b/src/core/loader/nso.cpp
@@ -167,7 +167,7 @@ AppLoader_NSO::LoadResult AppLoader_NSO::Load(Kernel::KProcess& process, Core::S
modules.clear();
// Load module
- const VAddr base_address = GetInteger(process.PageTable().GetCodeRegionStart());
+ const VAddr base_address = GetInteger(process.GetPageTable().GetCodeRegionStart());
if (!LoadModule(process, system, *file, base_address, true, true)) {
return {ResultStatus::ErrorLoadingNSO, {}};
}
diff --git a/src/core/memory.cpp b/src/core/memory.cpp
index 7538c1d23..513bc4edb 100644
--- a/src/core/memory.cpp
+++ b/src/core/memory.cpp
@@ -31,10 +31,10 @@ struct Memory::Impl {
explicit Impl(Core::System& system_) : system{system_} {}
void SetCurrentPageTable(Kernel::KProcess& process, u32 core_id) {
- current_page_table = &process.PageTable().PageTableImpl();
+ current_page_table = &process.GetPageTable().PageTableImpl();
current_page_table->fastmem_arena = system.DeviceMemory().buffer.VirtualBasePointer();
- const std::size_t address_space_width = process.PageTable().GetAddressSpaceWidth();
+ const std::size_t address_space_width = process.GetPageTable().GetAddressSpaceWidth();
system.ArmInterface(core_id).PageTableChanged(*current_page_table, address_space_width);
}
@@ -183,13 +183,13 @@ struct Memory::Impl {
return string;
}
- void WalkBlock(const Kernel::KProcess& process, const Common::ProcessAddress addr,
- const std::size_t size, auto on_unmapped, auto on_memory, auto on_rasterizer,
- auto increment) {
- const auto& page_table = process.PageTable().PageTableImpl();
+ bool WalkBlock(const Common::ProcessAddress addr, const std::size_t size, auto on_unmapped,
+ auto on_memory, auto on_rasterizer, auto increment) {
+ const auto& page_table = system.ApplicationProcess()->GetPageTable().PageTableImpl();
std::size_t remaining_size = size;
std::size_t page_index = addr >> YUZU_PAGEBITS;
std::size_t page_offset = addr & YUZU_PAGEMASK;
+ bool user_accessible = true;
while (remaining_size) {
const std::size_t copy_amount =
@@ -200,6 +200,7 @@ struct Memory::Impl {
const auto [pointer, type] = page_table.pointers[page_index].PointerType();
switch (type) {
case Common::PageType::Unmapped: {
+ user_accessible = false;
on_unmapped(copy_amount, current_vaddr);
break;
}
@@ -228,13 +229,15 @@ struct Memory::Impl {
increment(copy_amount);
remaining_size -= copy_amount;
}
+
+ return user_accessible;
}
template <bool UNSAFE>
- void ReadBlockImpl(const Kernel::KProcess& process, const Common::ProcessAddress src_addr,
- void* dest_buffer, const std::size_t size) {
- WalkBlock(
- process, src_addr, size,
+ bool ReadBlockImpl(const Common::ProcessAddress src_addr, void* dest_buffer,
+ const std::size_t size) {
+ return WalkBlock(
+ src_addr, size,
[src_addr, size, &dest_buffer](const std::size_t copy_amount,
const Common::ProcessAddress current_vaddr) {
LOG_ERROR(HW_Memory,
@@ -257,14 +260,14 @@ struct Memory::Impl {
});
}
- void ReadBlock(const Common::ProcessAddress src_addr, void* dest_buffer,
+ bool ReadBlock(const Common::ProcessAddress src_addr, void* dest_buffer,
const std::size_t size) {
- ReadBlockImpl<false>(*system.ApplicationProcess(), src_addr, dest_buffer, size);
+ return ReadBlockImpl<false>(src_addr, dest_buffer, size);
}
- void ReadBlockUnsafe(const Common::ProcessAddress src_addr, void* dest_buffer,
+ bool ReadBlockUnsafe(const Common::ProcessAddress src_addr, void* dest_buffer,
const std::size_t size) {
- ReadBlockImpl<true>(*system.ApplicationProcess(), src_addr, dest_buffer, size);
+ return ReadBlockImpl<true>(src_addr, dest_buffer, size);
}
const u8* GetSpan(const VAddr src_addr, const std::size_t size) const {
@@ -284,10 +287,10 @@ struct Memory::Impl {
}
template <bool UNSAFE>
- void WriteBlockImpl(const Kernel::KProcess& process, const Common::ProcessAddress dest_addr,
- const void* src_buffer, const std::size_t size) {
- WalkBlock(
- process, dest_addr, size,
+ bool WriteBlockImpl(const Common::ProcessAddress dest_addr, const void* src_buffer,
+ const std::size_t size) {
+ return WalkBlock(
+ dest_addr, size,
[dest_addr, size](const std::size_t copy_amount,
const Common::ProcessAddress current_vaddr) {
LOG_ERROR(HW_Memory,
@@ -309,20 +312,19 @@ struct Memory::Impl {
});
}
- void WriteBlock(const Common::ProcessAddress dest_addr, const void* src_buffer,
+ bool WriteBlock(const Common::ProcessAddress dest_addr, const void* src_buffer,
const std::size_t size) {
- WriteBlockImpl<false>(*system.ApplicationProcess(), dest_addr, src_buffer, size);
+ return WriteBlockImpl<false>(dest_addr, src_buffer, size);
}
- void WriteBlockUnsafe(const Common::ProcessAddress dest_addr, const void* src_buffer,
+ bool WriteBlockUnsafe(const Common::ProcessAddress dest_addr, const void* src_buffer,
const std::size_t size) {
- WriteBlockImpl<true>(*system.ApplicationProcess(), dest_addr, src_buffer, size);
+ return WriteBlockImpl<true>(dest_addr, src_buffer, size);
}
- void ZeroBlock(const Kernel::KProcess& process, const Common::ProcessAddress dest_addr,
- const std::size_t size) {
- WalkBlock(
- process, dest_addr, size,
+ bool ZeroBlock(const Common::ProcessAddress dest_addr, const std::size_t size) {
+ return WalkBlock(
+ dest_addr, size,
[dest_addr, size](const std::size_t copy_amount,
const Common::ProcessAddress current_vaddr) {
LOG_ERROR(HW_Memory,
@@ -340,23 +342,23 @@ struct Memory::Impl {
[](const std::size_t copy_amount) {});
}
- void CopyBlock(const Kernel::KProcess& process, Common::ProcessAddress dest_addr,
- Common::ProcessAddress src_addr, const std::size_t size) {
- WalkBlock(
- process, dest_addr, size,
+ bool CopyBlock(Common::ProcessAddress dest_addr, Common::ProcessAddress src_addr,
+ const std::size_t size) {
+ return WalkBlock(
+ dest_addr, size,
[&](const std::size_t copy_amount, const Common::ProcessAddress current_vaddr) {
LOG_ERROR(HW_Memory,
"Unmapped CopyBlock @ 0x{:016X} (start address = 0x{:016X}, size = {})",
GetInteger(current_vaddr), GetInteger(src_addr), size);
- ZeroBlock(process, dest_addr, copy_amount);
+ ZeroBlock(dest_addr, copy_amount);
},
[&](const std::size_t copy_amount, const u8* const src_ptr) {
- WriteBlockImpl<false>(process, dest_addr, src_ptr, copy_amount);
+ WriteBlockImpl<false>(dest_addr, src_ptr, copy_amount);
},
[&](const Common::ProcessAddress current_vaddr, const std::size_t copy_amount,
u8* const host_ptr) {
HandleRasterizerDownload(GetInteger(current_vaddr), copy_amount);
- WriteBlockImpl<false>(process, dest_addr, host_ptr, copy_amount);
+ WriteBlockImpl<false>(dest_addr, host_ptr, copy_amount);
},
[&](const std::size_t copy_amount) {
dest_addr += copy_amount;
@@ -365,13 +367,13 @@ struct Memory::Impl {
}
template <typename Callback>
- Result PerformCacheOperation(const Kernel::KProcess& process, Common::ProcessAddress dest_addr,
- std::size_t size, Callback&& cb) {
+ Result PerformCacheOperation(Common::ProcessAddress dest_addr, std::size_t size,
+ Callback&& cb) {
class InvalidMemoryException : public std::exception {};
try {
WalkBlock(
- process, dest_addr, size,
+ dest_addr, size,
[&](const std::size_t block_size, const Common::ProcessAddress current_vaddr) {
LOG_ERROR(HW_Memory, "Unmapped cache maintenance @ {:#018X}",
GetInteger(current_vaddr));
@@ -388,37 +390,34 @@ struct Memory::Impl {
return ResultSuccess;
}
- Result InvalidateDataCache(const Kernel::KProcess& process, Common::ProcessAddress dest_addr,
- std::size_t size) {
+ Result InvalidateDataCache(Common::ProcessAddress dest_addr, std::size_t size) {
auto on_rasterizer = [&](const Common::ProcessAddress current_vaddr,
const std::size_t block_size) {
// dc ivac: Invalidate to point of coherency
// GPU flush -> CPU invalidate
HandleRasterizerDownload(GetInteger(current_vaddr), block_size);
};
- return PerformCacheOperation(process, dest_addr, size, on_rasterizer);
+ return PerformCacheOperation(dest_addr, size, on_rasterizer);
}
- Result StoreDataCache(const Kernel::KProcess& process, Common::ProcessAddress dest_addr,
- std::size_t size) {
+ Result StoreDataCache(Common::ProcessAddress dest_addr, std::size_t size) {
auto on_rasterizer = [&](const Common::ProcessAddress current_vaddr,
const std::size_t block_size) {
// dc cvac: Store to point of coherency
// CPU flush -> GPU invalidate
system.GPU().InvalidateRegion(GetInteger(current_vaddr), block_size);
};
- return PerformCacheOperation(process, dest_addr, size, on_rasterizer);
+ return PerformCacheOperation(dest_addr, size, on_rasterizer);
}
- Result FlushDataCache(const Kernel::KProcess& process, Common::ProcessAddress dest_addr,
- std::size_t size) {
+ Result FlushDataCache(Common::ProcessAddress dest_addr, std::size_t size) {
auto on_rasterizer = [&](const Common::ProcessAddress current_vaddr,
const std::size_t block_size) {
// dc civac: Store to point of coherency, and invalidate from cache
// CPU flush -> GPU invalidate
system.GPU().InvalidateRegion(GetInteger(current_vaddr), block_size);
};
- return PerformCacheOperation(process, dest_addr, size, on_rasterizer);
+ return PerformCacheOperation(dest_addr, size, on_rasterizer);
}
void MarkRegionDebug(u64 vaddr, u64 size, bool debug) {
@@ -812,7 +811,7 @@ void Memory::UnmapRegion(Common::PageTable& page_table, Common::ProcessAddress b
bool Memory::IsValidVirtualAddress(const Common::ProcessAddress vaddr) const {
const Kernel::KProcess& process = *system.ApplicationProcess();
- const auto& page_table = process.PageTable().PageTableImpl();
+ const auto& page_table = process.GetPageTable().PageTableImpl();
const size_t page = vaddr >> YUZU_PAGEBITS;
if (page >= page_table.pointers.size()) {
return false;
@@ -903,14 +902,14 @@ std::string Memory::ReadCString(Common::ProcessAddress vaddr, std::size_t max_le
return impl->ReadCString(vaddr, max_length);
}
-void Memory::ReadBlock(const Common::ProcessAddress src_addr, void* dest_buffer,
+bool Memory::ReadBlock(const Common::ProcessAddress src_addr, void* dest_buffer,
const std::size_t size) {
- impl->ReadBlock(src_addr, dest_buffer, size);
+ return impl->ReadBlock(src_addr, dest_buffer, size);
}
-void Memory::ReadBlockUnsafe(const Common::ProcessAddress src_addr, void* dest_buffer,
+bool Memory::ReadBlockUnsafe(const Common::ProcessAddress src_addr, void* dest_buffer,
const std::size_t size) {
- impl->ReadBlockUnsafe(src_addr, dest_buffer, size);
+ return impl->ReadBlockUnsafe(src_addr, dest_buffer, size);
}
const u8* Memory::GetSpan(const VAddr src_addr, const std::size_t size) const {
@@ -921,23 +920,23 @@ u8* Memory::GetSpan(const VAddr src_addr, const std::size_t size) {
return impl->GetSpan(src_addr, size);
}
-void Memory::WriteBlock(const Common::ProcessAddress dest_addr, const void* src_buffer,
+bool Memory::WriteBlock(const Common::ProcessAddress dest_addr, const void* src_buffer,
const std::size_t size) {
- impl->WriteBlock(dest_addr, src_buffer, size);
+ return impl->WriteBlock(dest_addr, src_buffer, size);
}
-void Memory::WriteBlockUnsafe(const Common::ProcessAddress dest_addr, const void* src_buffer,
+bool Memory::WriteBlockUnsafe(const Common::ProcessAddress dest_addr, const void* src_buffer,
const std::size_t size) {
- impl->WriteBlockUnsafe(dest_addr, src_buffer, size);
+ return impl->WriteBlockUnsafe(dest_addr, src_buffer, size);
}
-void Memory::CopyBlock(Common::ProcessAddress dest_addr, Common::ProcessAddress src_addr,
+bool Memory::CopyBlock(Common::ProcessAddress dest_addr, Common::ProcessAddress src_addr,
const std::size_t size) {
- impl->CopyBlock(*system.ApplicationProcess(), dest_addr, src_addr, size);
+ return impl->CopyBlock(dest_addr, src_addr, size);
}
-void Memory::ZeroBlock(Common::ProcessAddress dest_addr, const std::size_t size) {
- impl->ZeroBlock(*system.ApplicationProcess(), dest_addr, size);
+bool Memory::ZeroBlock(Common::ProcessAddress dest_addr, const std::size_t size) {
+ return impl->ZeroBlock(dest_addr, size);
}
void Memory::SetGPUDirtyManagers(std::span<Core::GPUDirtyMemoryManager> managers) {
@@ -945,15 +944,15 @@ void Memory::SetGPUDirtyManagers(std::span<Core::GPUDirtyMemoryManager> managers
}
Result Memory::InvalidateDataCache(Common::ProcessAddress dest_addr, const std::size_t size) {
- return impl->InvalidateDataCache(*system.ApplicationProcess(), dest_addr, size);
+ return impl->InvalidateDataCache(dest_addr, size);
}
Result Memory::StoreDataCache(Common::ProcessAddress dest_addr, const std::size_t size) {
- return impl->StoreDataCache(*system.ApplicationProcess(), dest_addr, size);
+ return impl->StoreDataCache(dest_addr, size);
}
Result Memory::FlushDataCache(Common::ProcessAddress dest_addr, const std::size_t size) {
- return impl->FlushDataCache(*system.ApplicationProcess(), dest_addr, size);
+ return impl->FlushDataCache(dest_addr, size);
}
void Memory::RasterizerMarkRegionCached(Common::ProcessAddress vaddr, u64 size, bool cached) {
diff --git a/src/core/memory.h b/src/core/memory.h
index ea33c769c..2eb61ffd3 100644
--- a/src/core/memory.h
+++ b/src/core/memory.h
@@ -24,7 +24,6 @@ class GPUDirtyMemoryManager;
} // namespace Core
namespace Kernel {
-class PhysicalMemory;
class KProcess;
} // namespace Kernel
@@ -330,7 +329,7 @@ public:
* @post The range [dest_buffer, size) contains the read bytes from the
* current process' address space.
*/
- void ReadBlock(Common::ProcessAddress src_addr, void* dest_buffer, std::size_t size);
+ bool ReadBlock(Common::ProcessAddress src_addr, void* dest_buffer, std::size_t size);
/**
* Reads a contiguous block of bytes from the current process' address space.
@@ -349,7 +348,7 @@ public:
* @post The range [dest_buffer, size) contains the read bytes from the
* current process' address space.
*/
- void ReadBlockUnsafe(Common::ProcessAddress src_addr, void* dest_buffer, std::size_t size);
+ bool ReadBlockUnsafe(Common::ProcessAddress src_addr, void* dest_buffer, std::size_t size);
const u8* GetSpan(const VAddr src_addr, const std::size_t size) const;
u8* GetSpan(const VAddr src_addr, const std::size_t size);
@@ -373,7 +372,7 @@ public:
* and will mark that region as invalidated to caches that the active
* graphics backend may be maintaining over the course of execution.
*/
- void WriteBlock(Common::ProcessAddress dest_addr, const void* src_buffer, std::size_t size);
+ bool WriteBlock(Common::ProcessAddress dest_addr, const void* src_buffer, std::size_t size);
/**
* Writes a range of bytes into the current process' address space at the specified
@@ -391,7 +390,7 @@ public:
* will be ignored and an error will be logged.
*
*/
- void WriteBlockUnsafe(Common::ProcessAddress dest_addr, const void* src_buffer,
+ bool WriteBlockUnsafe(Common::ProcessAddress dest_addr, const void* src_buffer,
std::size_t size);
/**
@@ -405,7 +404,7 @@ public:
* @post The range [dest_addr, size) within the process' address space contains the
* same data within the range [src_addr, size).
*/
- void CopyBlock(Common::ProcessAddress dest_addr, Common::ProcessAddress src_addr,
+ bool CopyBlock(Common::ProcessAddress dest_addr, Common::ProcessAddress src_addr,
std::size_t size);
/**
@@ -418,7 +417,7 @@ public:
* @post The range [dest_addr, size) within the process' address space contains the
* value 0.
*/
- void ZeroBlock(Common::ProcessAddress dest_addr, std::size_t size);
+ bool ZeroBlock(Common::ProcessAddress dest_addr, std::size_t size);
/**
* Invalidates a range of bytes within the current process' address space at the specified
diff --git a/src/core/memory/cheat_engine.cpp b/src/core/memory/cheat_engine.cpp
index 8742dd164..7b52f61a7 100644
--- a/src/core/memory/cheat_engine.cpp
+++ b/src/core/memory/cheat_engine.cpp
@@ -199,7 +199,7 @@ void CheatEngine::Initialize() {
metadata.process_id = system.ApplicationProcess()->GetProcessId();
metadata.title_id = system.GetApplicationProcessProgramID();
- const auto& page_table = system.ApplicationProcess()->PageTable();
+ const auto& page_table = system.ApplicationProcess()->GetPageTable();
metadata.heap_extents = {
.base = GetInteger(page_table.GetHeapRegionStart()),
.size = page_table.GetHeapRegionSize(),
diff --git a/src/core/reporter.cpp b/src/core/reporter.cpp
index 6c3dc7369..b5b3e7eda 100644
--- a/src/core/reporter.cpp
+++ b/src/core/reporter.cpp
@@ -117,8 +117,8 @@ json GetProcessorStateDataAuto(Core::System& system) {
arm.SaveContext(context);
return GetProcessorStateData(process->Is64BitProcess() ? "AArch64" : "AArch32",
- GetInteger(process->PageTable().GetCodeRegionStart()), context.sp,
- context.pc, context.pstate, context.cpu_registers);
+ GetInteger(process->GetPageTable().GetCodeRegionStart()),
+ context.sp, context.pc, context.pstate, context.cpu_registers);
}
json GetBacktraceData(Core::System& system) {
diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
index d0b145860..07cabca43 100644
--- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
+++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
@@ -14,12 +14,12 @@
//
#include <deque>
+#include <map>
#include <span>
+#include <unordered_map>
#include <variant>
#include <vector>
-#include <boost/container/flat_map.hpp>
-
#include "shader_recompiler/frontend/ir/basic_block.h"
#include "shader_recompiler/frontend/ir/opcodes.h"
#include "shader_recompiler/frontend/ir/pred.h"
@@ -52,7 +52,7 @@ struct IndirectBranchVariable {
using Variant = std::variant<IR::Reg, IR::Pred, ZeroFlagTag, SignFlagTag, CarryFlagTag,
OverflowFlagTag, GotoVariable, IndirectBranchVariable>;
-using ValueMap = boost::container::flat_map<IR::Block*, IR::Value>;
+using ValueMap = std::unordered_map<IR::Block*, IR::Value>;
struct DefTable {
const IR::Value& Def(IR::Block* block, IR::Reg variable) {
@@ -112,7 +112,7 @@ struct DefTable {
}
std::array<ValueMap, IR::NUM_USER_PREDS> preds;
- boost::container::flat_map<u32, ValueMap> goto_vars;
+ std::unordered_map<u32, ValueMap> goto_vars;
ValueMap indirect_branch_var;
ValueMap zero_flag;
ValueMap sign_flag;
@@ -295,8 +295,7 @@ private:
return same;
}
- boost::container::flat_map<IR::Block*, boost::container::flat_map<Variant, IR::Inst*>>
- incomplete_phis;
+ std::unordered_map<IR::Block*, std::map<Variant, IR::Inst*>> incomplete_phis;
DefTable current_def;
};
diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt
index 3b2fe01da..7f79111e0 100644
--- a/src/video_core/CMakeLists.txt
+++ b/src/video_core/CMakeLists.txt
@@ -274,6 +274,7 @@ add_library(video_core STATIC
vulkan_common/vulkan_wrapper.h
vulkan_common/nsight_aftermath_tracker.cpp
vulkan_common/nsight_aftermath_tracker.h
+ vulkan_common/vma.cpp
)
create_target_directory_groups(video_core)
@@ -291,7 +292,7 @@ target_link_options(video_core PRIVATE ${FFmpeg_LDFLAGS})
add_dependencies(video_core host_shaders)
target_include_directories(video_core PRIVATE ${HOST_SHADERS_INCLUDE})
-target_link_libraries(video_core PRIVATE sirit Vulkan::Headers vma)
+target_link_libraries(video_core PRIVATE sirit Vulkan::Headers GPUOpen::VulkanMemoryAllocator)
if (ENABLE_NSIGHT_AFTERMATH)
if (NOT DEFINED ENV{NSIGHT_AFTERMATH_SDK})
@@ -324,6 +325,9 @@ else()
# xbyak
set_source_files_properties(macro/macro_jit_x64.cpp PROPERTIES COMPILE_OPTIONS "-Wno-conversion;-Wno-shadow")
+
+ # VMA
+ set_source_files_properties(vulkan_common/vma.cpp PROPERTIES COMPILE_OPTIONS "-Wno-conversion;-Wno-unused-variable;-Wno-unused-parameter;-Wno-missing-field-initializers")
endif()
if (ARCHITECTURE_x86_64)
diff --git a/src/video_core/buffer_cache/buffer_cache.h b/src/video_core/buffer_cache/buffer_cache.h
index 6ed4b78f2..f0f450edb 100644
--- a/src/video_core/buffer_cache/buffer_cache.h
+++ b/src/video_core/buffer_cache/buffer_cache.h
@@ -442,6 +442,11 @@ void BufferCache<P>::UnbindComputeStorageBuffers() {
template <class P>
void BufferCache<P>::BindComputeStorageBuffer(size_t ssbo_index, u32 cbuf_index, u32 cbuf_offset,
bool is_written) {
+ if (ssbo_index >= channel_state->compute_storage_buffers.size()) [[unlikely]] {
+ LOG_ERROR(HW_GPU, "Storage buffer index {} exceeds maximum storage buffer count",
+ ssbo_index);
+ return;
+ }
channel_state->enabled_compute_storage_buffers |= 1U << ssbo_index;
channel_state->written_compute_storage_buffers |= (is_written ? 1U : 0U) << ssbo_index;
@@ -464,6 +469,11 @@ void BufferCache<P>::UnbindComputeTextureBuffers() {
template <class P>
void BufferCache<P>::BindComputeTextureBuffer(size_t tbo_index, GPUVAddr gpu_addr, u32 size,
PixelFormat format, bool is_written, bool is_image) {
+ if (tbo_index >= channel_state->compute_texture_buffers.size()) [[unlikely]] {
+ LOG_ERROR(HW_GPU, "Texture buffer index {} exceeds maximum texture buffer count",
+ tbo_index);
+ return;
+ }
channel_state->enabled_compute_texture_buffers |= 1U << tbo_index;
channel_state->written_compute_texture_buffers |= (is_written ? 1U : 0U) << tbo_index;
if constexpr (SEPARATE_IMAGE_BUFFERS_BINDINGS) {
diff --git a/src/video_core/buffer_cache/buffer_cache_base.h b/src/video_core/buffer_cache/buffer_cache_base.h
index 460fc7551..0b7135d49 100644
--- a/src/video_core/buffer_cache/buffer_cache_base.h
+++ b/src/video_core/buffer_cache/buffer_cache_base.h
@@ -67,7 +67,7 @@ constexpr u32 NUM_TRANSFORM_FEEDBACK_BUFFERS = 4;
constexpr u32 NUM_GRAPHICS_UNIFORM_BUFFERS = 18;
constexpr u32 NUM_COMPUTE_UNIFORM_BUFFERS = 8;
constexpr u32 NUM_STORAGE_BUFFERS = 16;
-constexpr u32 NUM_TEXTURE_BUFFERS = 16;
+constexpr u32 NUM_TEXTURE_BUFFERS = 32;
constexpr u32 NUM_STAGES = 5;
using UniformBufferSizes = std::array<std::array<u32, NUM_GRAPHICS_UNIFORM_BUFFERS>, NUM_STAGES>;
diff --git a/src/video_core/renderer_base.cpp b/src/video_core/renderer_base.cpp
index 2d3f58201..4002fa72b 100644
--- a/src/video_core/renderer_base.cpp
+++ b/src/video_core/renderer_base.cpp
@@ -38,8 +38,8 @@ void RendererBase::RequestScreenshot(void* data, std::function<void(bool)> callb
LOG_ERROR(Render, "A screenshot is already requested or in progress, ignoring the request");
return;
}
- auto async_callback{[callback = std::move(callback)](bool invert_y) {
- std::thread t{callback, invert_y};
+ auto async_callback{[callback_ = std::move(callback)](bool invert_y) {
+ std::thread t{callback_, invert_y};
t.detach();
}};
renderer_settings.screenshot_bits = data;
diff --git a/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp b/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp
index 23a48c6fe..71f720c63 100644
--- a/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp
+++ b/src/video_core/renderer_opengl/gl_graphics_pipeline.cpp
@@ -231,24 +231,25 @@ GraphicsPipeline::GraphicsPipeline(const Device& device, TextureCache& texture_c
}
const bool in_parallel = thread_worker != nullptr;
const auto backend = device.GetShaderBackend();
- auto func{[this, sources = std::move(sources), sources_spirv = std::move(sources_spirv),
+ auto func{[this, sources_ = std::move(sources), sources_spirv_ = std::move(sources_spirv),
shader_notify, backend, in_parallel,
force_context_flush](ShaderContext::Context*) mutable {
for (size_t stage = 0; stage < 5; ++stage) {
switch (backend) {
case Settings::ShaderBackend::GLSL:
- if (!sources[stage].empty()) {
- source_programs[stage] = CreateProgram(sources[stage], Stage(stage));
+ if (!sources_[stage].empty()) {
+ source_programs[stage] = CreateProgram(sources_[stage], Stage(stage));
}
break;
case Settings::ShaderBackend::GLASM:
- if (!sources[stage].empty()) {
- assembly_programs[stage] = CompileProgram(sources[stage], AssemblyStage(stage));
+ if (!sources_[stage].empty()) {
+ assembly_programs[stage] =
+ CompileProgram(sources_[stage], AssemblyStage(stage));
}
break;
case Settings::ShaderBackend::SPIRV:
- if (!sources_spirv[stage].empty()) {
- source_programs[stage] = CreateProgram(sources_spirv[stage], Stage(stage));
+ if (!sources_spirv_[stage].empty()) {
+ source_programs[stage] = CreateProgram(sources_spirv_[stage], Stage(stage));
}
break;
}
diff --git a/src/video_core/renderer_opengl/gl_shader_cache.cpp b/src/video_core/renderer_opengl/gl_shader_cache.cpp
index 0329ed820..7e1d7f92e 100644
--- a/src/video_core/renderer_opengl/gl_shader_cache.cpp
+++ b/src/video_core/renderer_opengl/gl_shader_cache.cpp
@@ -288,9 +288,9 @@ void ShaderCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading,
const auto load_compute{[&](std::ifstream& file, FileEnvironment env) {
ComputePipelineKey key;
file.read(reinterpret_cast<char*>(&key), sizeof(key));
- queue_work([this, key, env = std::move(env), &state, &callback](Context* ctx) mutable {
+ queue_work([this, key, env_ = std::move(env), &state, &callback](Context* ctx) mutable {
ctx->pools.ReleaseContents();
- auto pipeline{CreateComputePipeline(ctx->pools, key, env, true)};
+ auto pipeline{CreateComputePipeline(ctx->pools, key, env_, true)};
std::scoped_lock lock{state.mutex};
if (pipeline) {
compute_cache.emplace(key, std::move(pipeline));
@@ -305,9 +305,9 @@ void ShaderCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading,
const auto load_graphics{[&](std::ifstream& file, std::vector<FileEnvironment> envs) {
GraphicsPipelineKey key;
file.read(reinterpret_cast<char*>(&key), sizeof(key));
- queue_work([this, key, envs = std::move(envs), &state, &callback](Context* ctx) mutable {
+ queue_work([this, key, envs_ = std::move(envs), &state, &callback](Context* ctx) mutable {
boost::container::static_vector<Shader::Environment*, 5> env_ptrs;
- for (auto& env : envs) {
+ for (auto& env : envs_) {
env_ptrs.push_back(&env);
}
ctx->pools.ReleaseContents();
diff --git a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp
index 51df18ec3..f8cd2a5d8 100644
--- a/src/video_core/renderer_vulkan/vk_buffer_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_buffer_cache.cpp
@@ -206,8 +206,8 @@ public:
const size_t sub_first_offset = static_cast<size_t>(first % 4) * GetQuadsNum(num_indices);
const size_t offset =
(sub_first_offset + GetQuadsNum(first)) * 6ULL * BytesPerIndex(index_type);
- scheduler.Record([buffer = *buffer, index_type_, offset](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindIndexBuffer(buffer, offset, index_type_);
+ scheduler.Record([buffer_ = *buffer, index_type_, offset](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindIndexBuffer(buffer_, offset, index_type_);
});
}
@@ -528,17 +528,18 @@ void BufferCacheRuntime::BindVertexBuffers(VideoCommon::HostBindings<Buffer>& bi
buffer_handles.push_back(handle);
}
if (device.IsExtExtendedDynamicStateSupported()) {
- scheduler.Record([bindings = std::move(bindings),
- buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindVertexBuffers2EXT(
- bindings.min_index, bindings.max_index - bindings.min_index, buffer_handles.data(),
- bindings.offsets.data(), bindings.sizes.data(), bindings.strides.data());
+ scheduler.Record([bindings_ = std::move(bindings),
+ buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindVertexBuffers2EXT(bindings_.min_index,
+ bindings_.max_index - bindings_.min_index,
+ buffer_handles_.data(), bindings_.offsets.data(),
+ bindings_.sizes.data(), bindings_.strides.data());
});
} else {
- scheduler.Record([bindings = std::move(bindings),
- buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindVertexBuffers(bindings.min_index, bindings.max_index - bindings.min_index,
- buffer_handles.data(), bindings.offsets.data());
+ scheduler.Record([bindings_ = std::move(bindings),
+ buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindVertexBuffers(bindings_.min_index, bindings_.max_index - bindings_.min_index,
+ buffer_handles_.data(), bindings_.offsets.data());
});
}
}
@@ -573,11 +574,11 @@ void BufferCacheRuntime::BindTransformFeedbackBuffers(VideoCommon::HostBindings<
for (u32 index = 0; index < bindings.buffers.size(); ++index) {
buffer_handles.push_back(bindings.buffers[index]->Handle());
}
- scheduler.Record([bindings = std::move(bindings),
- buffer_handles = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
- cmdbuf.BindTransformFeedbackBuffersEXT(0, static_cast<u32>(buffer_handles.size()),
- buffer_handles.data(), bindings.offsets.data(),
- bindings.sizes.data());
+ scheduler.Record([bindings_ = std::move(bindings),
+ buffer_handles_ = std::move(buffer_handles)](vk::CommandBuffer cmdbuf) {
+ cmdbuf.BindTransformFeedbackBuffersEXT(0, static_cast<u32>(buffer_handles_.size()),
+ buffer_handles_.data(), bindings_.offsets.data(),
+ bindings_.sizes.data());
});
}
diff --git a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
index d600c4e61..4f84d8497 100644
--- a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
@@ -469,9 +469,9 @@ void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading
ComputePipelineCacheKey key;
file.read(reinterpret_cast<char*>(&key), sizeof(key));
- workers.QueueWork([this, key, env = std::move(env), &state, &callback]() mutable {
+ workers.QueueWork([this, key, env_ = std::move(env), &state, &callback]() mutable {
ShaderPools pools;
- auto pipeline{CreateComputePipeline(pools, key, env, state.statistics.get(), false)};
+ auto pipeline{CreateComputePipeline(pools, key, env_, state.statistics.get(), false)};
std::scoped_lock lock{state.mutex};
if (pipeline) {
compute_cache.emplace(key, std::move(pipeline));
@@ -500,10 +500,10 @@ void PipelineCache::LoadDiskResources(u64 title_id, std::stop_token stop_loading
(key.state.dynamic_vertex_input != 0) != dynamic_features.has_dynamic_vertex_input) {
return;
}
- workers.QueueWork([this, key, envs = std::move(envs), &state, &callback]() mutable {
+ workers.QueueWork([this, key, envs_ = std::move(envs), &state, &callback]() mutable {
ShaderPools pools;
boost::container::static_vector<Shader::Environment*, 5> env_ptrs;
- for (auto& env : envs) {
+ for (auto& env : envs_) {
env_ptrs.push_back(&env);
}
auto pipeline{CreateGraphicsPipeline(pools, key, MakeSpan(env_ptrs),
@@ -702,8 +702,8 @@ std::unique_ptr<ComputePipeline> PipelineCache::CreateComputePipeline(
if (!pipeline || pipeline_cache_filename.empty()) {
return pipeline;
}
- serialization_thread.QueueWork([this, key, env = std::move(env)] {
- SerializePipeline(key, std::array<const GenericEnvironment*, 1>{&env},
+ serialization_thread.QueueWork([this, key, env_ = std::move(env)] {
+ SerializePipeline(key, std::array<const GenericEnvironment*, 1>{&env_},
pipeline_cache_filename, CACHE_VERSION);
});
return pipeline;
diff --git a/src/video_core/renderer_vulkan/vk_query_cache.cpp b/src/video_core/renderer_vulkan/vk_query_cache.cpp
index d67490449..29e0b797b 100644
--- a/src/video_core/renderer_vulkan/vk_query_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_query_cache.cpp
@@ -98,10 +98,10 @@ HostCounter::HostCounter(QueryCache& cache_, std::shared_ptr<HostCounter> depend
: HostCounterBase{std::move(dependency_)}, cache{cache_}, type{type_},
query{cache_.AllocateQuery(type_)}, tick{cache_.GetScheduler().CurrentTick()} {
const vk::Device* logical = &cache.GetDevice().GetLogical();
- cache.GetScheduler().Record([logical, query = query](vk::CommandBuffer cmdbuf) {
+ cache.GetScheduler().Record([logical, query_ = query](vk::CommandBuffer cmdbuf) {
const bool use_precise = Settings::IsGPULevelHigh();
- logical->ResetQueryPool(query.first, query.second, 1);
- cmdbuf.BeginQuery(query.first, query.second,
+ logical->ResetQueryPool(query_.first, query_.second, 1);
+ cmdbuf.BeginQuery(query_.first, query_.second,
use_precise ? VK_QUERY_CONTROL_PRECISE_BIT : 0);
});
}
@@ -111,8 +111,9 @@ HostCounter::~HostCounter() {
}
void HostCounter::EndQuery() {
- cache.GetScheduler().Record(
- [query = query](vk::CommandBuffer cmdbuf) { cmdbuf.EndQuery(query.first, query.second); });
+ cache.GetScheduler().Record([query_ = query](vk::CommandBuffer cmdbuf) {
+ cmdbuf.EndQuery(query_.first, query_.second);
+ });
}
u64 HostCounter::BlockingQuery(bool async) const {
diff --git a/src/video_core/renderer_vulkan/vk_texture_cache.cpp b/src/video_core/renderer_vulkan/vk_texture_cache.cpp
index 3aac3cfab..bf6ad6c79 100644
--- a/src/video_core/renderer_vulkan/vk_texture_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_texture_cache.cpp
@@ -1412,7 +1412,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS
}
scheduler->RequestOutsideRenderPassOperationContext();
scheduler->Record([buffers = std::move(buffers_vector), image = *original_image,
- aspect_mask = aspect_mask, vk_copies](vk::CommandBuffer cmdbuf) {
+ aspect_mask_ = aspect_mask, vk_copies](vk::CommandBuffer cmdbuf) {
const VkImageMemoryBarrier read_barrier{
.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER,
.pNext = nullptr,
@@ -1424,7 +1424,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS
.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
.image = image,
.subresourceRange{
- .aspectMask = aspect_mask,
+ .aspectMask = aspect_mask_,
.baseMipLevel = 0,
.levelCount = VK_REMAINING_MIP_LEVELS,
.baseArrayLayer = 0,
@@ -1456,7 +1456,7 @@ void Image::DownloadMemory(std::span<VkBuffer> buffers_span, std::span<VkDeviceS
.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
.image = image,
.subresourceRange{
- .aspectMask = aspect_mask,
+ .aspectMask = aspect_mask_,
.baseMipLevel = 0,
.levelCount = VK_REMAINING_MIP_LEVELS,
.baseArrayLayer = 0,
diff --git a/src/video_core/vulkan_common/vma.cpp b/src/video_core/vulkan_common/vma.cpp
new file mode 100644
index 000000000..1fe2cf52b
--- /dev/null
+++ b/src/video_core/vulkan_common/vma.cpp
@@ -0,0 +1,8 @@
+// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#define VMA_IMPLEMENTATION
+#define VMA_STATIC_VULKAN_FUNCTIONS 0
+#define VMA_DYNAMIC_VULKAN_FUNCTIONS 1
+
+#include <vk_mem_alloc.h> \ No newline at end of file
diff --git a/src/web_service/announce_room_json.cpp b/src/web_service/announce_room_json.cpp
index 4c3195efd..f1020a5b8 100644
--- a/src/web_service/announce_room_json.cpp
+++ b/src/web_service/announce_room_json.cpp
@@ -135,11 +135,11 @@ void RoomJson::Delete() {
LOG_ERROR(WebService, "Room must be registered to be deleted");
return;
}
- Common::DetachedTasks::AddTask(
- [host{this->host}, username{this->username}, token{this->token}, room_id{this->room_id}]() {
- // create a new client here because the this->client might be destroyed.
- Client{host, username, token}.DeleteJson(fmt::format("/lobby/{}", room_id), "", false);
- });
+ Common::DetachedTasks::AddTask([host_{this->host}, username_{this->username},
+ token_{this->token}, room_id_{this->room_id}]() {
+ // create a new client here because the this->client might be destroyed.
+ Client{host_, username_, token_}.DeleteJson(fmt::format("/lobby/{}", room_id_), "", false);
+ });
}
} // namespace WebService
diff --git a/src/yuzu/game_list_worker.cpp b/src/yuzu/game_list_worker.cpp
index 63326968b..9404365b4 100644
--- a/src/yuzu/game_list_worker.cpp
+++ b/src/yuzu/game_list_worker.cpp
@@ -235,7 +235,7 @@ GameListWorker::~GameListWorker() = default;
void GameListWorker::AddTitlesToGameList(GameListDir* parent_dir) {
using namespace FileSys;
- const auto& cache = dynamic_cast<ContentProviderUnion&>(system.GetContentProvider());
+ const auto& cache = system.GetContentProviderUnion();
auto installed_games = cache.ListEntriesFilterOrigin(std::nullopt, TitleType::Application,
ContentRecordType::Program);
@@ -265,7 +265,11 @@ void GameListWorker::AddTitlesToGameList(GameListDir* parent_dir) {
std::vector<u8> icon;
std::string name;
u64 program_id = 0;
- loader->ReadProgramId(program_id);
+ const auto result = loader->ReadProgramId(program_id);
+
+ if (result != Loader::ResultStatus::Success) {
+ continue;
+ }
const PatchManager patch{program_id, system.GetFileSystemController(),
system.GetContentProvider()};