summaryrefslogblamecommitdiffstats
path: root/src/core/hle/kernel/address_arbiter.cpp
blob: 048acd30e90d2e991be4779c7181a3f17ad051ad (plain) (tree)
1
2
3
4
5
6
7
8
9



                                            


                    
                          
                                
                                       
                      
                                            
                                   

                                         
                                      
                                   
                                         
                            

                        
                  
 
                                                    

                                                                                             

                                                                                          
                                              
                          
                                                                     
     

                                  
                                            
                                                                               
                                         
                                                      



                                             
                                                                        

                                            














                                                                                          
                                        

                                                                



                                              

                                                                                       
                                        

                                   
                                               
                                                 


                                         



                                                               
                                                                       
 
                                                       



                                                                              
 
                                                     

 

                                                                                                  
                                        

                                   
                                               
                                                 



                                          

                                                                
 

                                                               
                      
        
                                                                       


                                     
         






                                                                       
                






                                                                                 
         
                                                                              
 



                                              













                                                                                         

                                                                                          
                                   

                                                                          
 



                                                                                  




                                                     





                                                     




                                                                     
 

                                                                           




                                                                   
                                                                                             













                                                                                                 



                                                         
     
 


                                                       

     





                                                        

     
                                                

 
                                                                                         
                                   

                                                                          
 



                                                                                  




                                                     





                                                     
















                                                                           
     
 


                                                       

     






                                                        
 
                                                

 


                                                                            









                                                                                             
     
 
 


                                                                            



                                                                                             


                                     
 
 




                                                                                
     


                                                                         

                     
// Copyright 2018 yuzu emulator team
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.

#include <algorithm>
#include <vector>

#include "common/assert.h"
#include "common/common_types.h"
#include "core/arm/exclusive_monitor.h"
#include "core/core.h"
#include "core/hle/kernel/address_arbiter.h"
#include "core/hle/kernel/errors.h"
#include "core/hle/kernel/handle_table.h"
#include "core/hle/kernel/kernel.h"
#include "core/hle/kernel/scheduler.h"
#include "core/hle/kernel/thread.h"
#include "core/hle/kernel/time_manager.h"
#include "core/hle/result.h"
#include "core/memory.h"

namespace Kernel {

// Wake up num_to_wake (or all) threads in a vector.
void AddressArbiter::WakeThreads(const std::vector<std::shared_ptr<Thread>>& waiting_threads,
                                 s32 num_to_wake) {
    // Only process up to 'target' threads, unless 'target' is <= 0, in which case process
    // them all.
    std::size_t last = waiting_threads.size();
    if (num_to_wake > 0) {
        last = std::min(last, static_cast<std::size_t>(num_to_wake));
    }

    // Signal the waiting threads.
    for (std::size_t i = 0; i < last; i++) {
        waiting_threads[i]->SetSynchronizationResults(nullptr, RESULT_SUCCESS);
        RemoveThread(waiting_threads[i]);
        waiting_threads[i]->WaitForArbitration(false);
        waiting_threads[i]->ResumeFromWait();
    }
}

// Constructs an arbiter, initializing its `system` member from the given Core::System.
AddressArbiter::AddressArbiter(Core::System& system) : system{system} {}
// No explicit cleanup is performed here; the compiler-generated destructor is used.
AddressArbiter::~AddressArbiter() = default;

// Dispatches a signal request to the handler that implements the requested SignalType.
// Returns the handler's result, or ERR_INVALID_ENUM_VALUE when type matches none of the
// handled values.
ResultCode AddressArbiter::SignalToAddress(VAddr address, SignalType type, s32 value,
                                           s32 num_to_wake) {
    if (type == SignalType::Signal) {
        return SignalToAddressOnly(address, num_to_wake);
    }
    if (type == SignalType::IncrementAndSignalIfEqual) {
        return IncrementAndSignalToAddressIfEqual(address, value, num_to_wake);
    }
    if (type == SignalType::ModifyByWaitingCountAndSignalIfEqual) {
        return ModifyByWaitingCountAndSignalToAddressIfEqual(address, value, num_to_wake);
    }
    return ERR_INVALID_ENUM_VALUE;
}

// Wakes up to num_to_wake threads waiting on address (all of them when num_to_wake <= 0)
// while holding a SchedulerLock. Always returns RESULT_SUCCESS.
ResultCode AddressArbiter::SignalToAddressOnly(VAddr address, s32 num_to_wake) {
    SchedulerLock guard(system.Kernel());
    // The snapshot of waiters lives until WakeThreads returns.
    WakeThreads(GetThreadsWaitingOnAddress(address), num_to_wake);
    return RESULT_SUCCESS;
}

// If the 32-bit value stored at address equals value, increments it through the exclusive
// monitor and then signals the threads waiting on address via SignalToAddressOnly.
// Returns ERR_INVALID_ADDRESS_STATE when address is not a valid virtual address and
// ERR_INVALID_STATE when the stored value does not match value.
ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 value,
                                                              s32 num_to_wake) {
    SchedulerLock guard(system.Kernel());

    // Make sure the address can be written before touching it.
    if (!system.Memory().IsValidVirtualAddress(address)) {
        return ERR_INVALID_ADDRESS_STATE;
    }

    const std::size_t core_index = system.CurrentCoreIndex();
    auto& exclusive_monitor = system.Monitor();
    const u32 expected = static_cast<u32>(value);

    // Repeat the exclusive read / compare / write sequence until the write is accepted.
    for (;;) {
        const u32 observed = exclusive_monitor.ExclusiveRead32(core_index, address);
        if (observed != expected) {
            return ERR_INVALID_STATE;
        }
        if (exclusive_monitor.ExclusiveWrite32(core_index, address, observed + 1)) {
            break;
        }
    }

    return SignalToAddressOnly(address, num_to_wake);
}

// If the 32-bit value stored at address equals value, replaces it (through the exclusive
// monitor) with a value chosen from the number of threads waiting on address, then wakes up to
// num_to_wake of those threads (all of them when num_to_wake <= 0):
//   - nobody is waiting on address                               -> value + 1
//   - num_to_wake <= 0, or the waiter count is <= num_to_wake    -> value - 1
//   - there are more waiters than num_to_wake                    -> value (unchanged)
// Returns ERR_INVALID_ADDRESS_STATE when address is not a valid virtual address,
// ERR_INVALID_STATE when the stored value does not match value, RESULT_SUCCESS otherwise.
ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value,
                                                                         s32 num_to_wake) {
    SchedulerLock lock(system.Kernel());
    auto& memory = system.Memory();

    // Ensure that we can write to the address.
    if (!memory.IsValidVirtualAddress(address)) {
        return ERR_INVALID_ADDRESS_STATE;
    }

    // Get threads waiting on the address.
    const std::vector<std::shared_ptr<Thread>> waiting_threads =
        GetThreadsWaitingOnAddress(address);

    // Work on the unsigned representation: value is caller-supplied, and value + 1 / value - 1
    // on s32 would be undefined behaviour at the ends of the range. Unsigned arithmetic wraps,
    // which is also what IncrementAndSignalToAddressIfEqual does.
    const u32 expected_value = static_cast<u32>(value);

    // The replacement depends only on the arguments and the waiter snapshot, so compute it once
    // instead of on every retry. The num_to_wake <= 0 test short-circuits before the cast.
    u32 updated_value;
    if (waiting_threads.empty()) {
        updated_value = expected_value + 1;
    } else if (num_to_wake <= 0 || waiting_threads.size() <= static_cast<u32>(num_to_wake)) {
        updated_value = expected_value - 1;
    } else {
        updated_value = expected_value;
    }

    const std::size_t current_core = system.CurrentCoreIndex();
    auto& monitor = system.Monitor();
    do {
        const u32 current_value =
            static_cast<u32>(monitor.ExclusiveRead32(current_core, address));
        if (current_value != expected_value) {
            return ERR_INVALID_STATE;
        }
        // Retry until the exclusive write is accepted.
    } while (!monitor.ExclusiveWrite32(current_core, address, updated_value));

    WakeThreads(waiting_threads, num_to_wake);
    return RESULT_SUCCESS;
}

// Dispatches a wait request to the handler that implements the requested ArbitrationType.
// Returns the handler's result, or ERR_INVALID_ENUM_VALUE when type matches none of the
// handled values.
ResultCode AddressArbiter::WaitForAddress(VAddr address, ArbitrationType type, s32 value,
                                          s64 timeout_ns) {
    switch (type) {
    case ArbitrationType::WaitIfLessThan:
    case ArbitrationType::DecrementAndWaitIfLessThan: {
        // Both less-than variants share one handler; only the decrement flag differs.
        const bool should_decrement = type == ArbitrationType::DecrementAndWaitIfLessThan;
        return WaitForAddressIfLessThan(address, value, timeout_ns, should_decrement);
    }
    case ArbitrationType::WaitIfEqual:
        return WaitForAddressIfEqual(address, value, timeout_ns);
    default:
        return ERR_INVALID_ENUM_VALUE;
    }
}

// Puts the current thread to sleep on address if the signed 32-bit value stored there is less
// than value, optionally decrementing the stored value first.
//
// Returns:
//   ERR_THREAD_TERMINATING     - the current thread is pending termination.
//   ERR_INVALID_ADDRESS_STATE  - address is not a valid virtual address.
//   ERR_INVALID_STATE          - the stored value is >= value.
//   RESULT_TIMEOUT             - timeout is zero (returned right after the memory update).
//   otherwise                  - current_thread->GetSignalingResult() once the wait is over.
ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout,
                                                    bool should_decrement) {
    auto& memory = system.Memory();
    auto& kernel = system.Kernel();
    Thread* current_thread = system.CurrentScheduler().GetCurrentThread();

    // Handed to SchedulerLockAndSleep and checked against InvalidHandle afterwards; presumably
    // the lock stores the handle of the timeout event it schedules here (confirm against
    // SchedulerLockAndSleep).
    Handle event_handle = InvalidHandle;
    {
        SchedulerLockAndSleep lock(kernel, event_handle, current_thread, timeout);

        if (current_thread->IsPendingTermination()) {
            lock.CancelSleep();
            return ERR_THREAD_TERMINATING;
        }

        // Ensure that we can read the address.
        if (!memory.IsValidVirtualAddress(address)) {
            lock.CancelSleep();
            return ERR_INVALID_ADDRESS_STATE;
        }

        const s32 current_value = static_cast<s32>(memory.Read32(address));
        if (current_value >= value) {
            lock.CancelSleep();
            return ERR_INVALID_STATE;
        }

        // Default result for the wait; WakeThreads replaces it with RESULT_SUCCESS.
        current_thread->SetSynchronizationResults(nullptr, RESULT_TIMEOUT);

        const std::size_t current_core = system.CurrentCoreIndex();
        auto& monitor = system.Monitor();
        u32 updated_value;
        do {
            const u32 observed_value =
                static_cast<u32>(monitor.ExclusiveRead32(current_core, address));
            // Decrement on the unsigned representation: the stored word may be INT32_MIN, for
            // which a signed "- 1" would be undefined behaviour. Unsigned arithmetic wraps.
            // NOTE(review): when should_decrement is false the same value is written back.
            updated_value = should_decrement ? observed_value - 1 : observed_value;
        } while (!monitor.ExclusiveWrite32(current_core, address, updated_value));

        // Short-circuit without rescheduling, if timeout is zero.
        if (timeout == 0) {
            lock.CancelSleep();
            return RESULT_TIMEOUT;
        }

        // Register the thread as a waiter on this address and mark it as waiting.
        current_thread->SetArbiterWaitAddress(address);
        InsertThread(SharedFrom(current_thread));
        current_thread->SetStatus(ThreadStatus::WaitArb);
        current_thread->WaitForArbitration(true);
    }

    // Cancel the timeout event if one was set up for this wait.
    if (event_handle != InvalidHandle) {
        auto& time_manager = kernel.TimeManager();
        time_manager.UnscheduleTimeEvent(event_handle);
    }

    {
        // If the thread is still flagged as waiting (it was not woken through WakeThreads),
        // remove it from the wait list ourselves.
        SchedulerLock lock(kernel);
        if (current_thread->IsWaitingForArbitration()) {
            RemoveThread(SharedFrom(current_thread));
            current_thread->WaitForArbitration(false);
        }
    }

    return current_thread->GetSignalingResult();
}

// Puts the current thread to sleep on address if the signed 32-bit value stored there equals
// value.
//
// Returns:
//   ERR_THREAD_TERMINATING     - the current thread is pending termination.
//   ERR_INVALID_ADDRESS_STATE  - address is not a valid virtual address.
//   ERR_INVALID_STATE          - the stored value differs from value.
//   RESULT_TIMEOUT             - timeout is zero.
//   otherwise                  - the thread's GetSignalingResult() once the wait is over.
ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) {
    auto& memory = system.Memory();
    auto& kernel = system.Kernel();
    Thread* waiter = system.CurrentScheduler().GetCurrentThread();

    Handle timeout_event = InvalidHandle;
    {
        SchedulerLockAndSleep sleep_guard(kernel, timeout_event, waiter, timeout);

        // Cancels the pending sleep and hands back the result code to return.
        const auto cancel_with = [&sleep_guard](ResultCode result) {
            sleep_guard.CancelSleep();
            return result;
        };

        if (waiter->IsPendingTermination()) {
            return cancel_with(ERR_THREAD_TERMINATING);
        }

        // The address has to be readable.
        if (!memory.IsValidVirtualAddress(address)) {
            return cancel_with(ERR_INVALID_ADDRESS_STATE);
        }

        // Only wait while the stored value still matches the expected one.
        if (static_cast<s32>(memory.Read32(address)) != value) {
            return cancel_with(ERR_INVALID_STATE);
        }

        // A zero timeout never sleeps, so return without rescheduling.
        if (timeout == 0) {
            return cancel_with(RESULT_TIMEOUT);
        }

        // Default the wait result to a timeout, then register the thread as a waiter.
        waiter->SetSynchronizationResults(nullptr, RESULT_TIMEOUT);
        waiter->SetArbiterWaitAddress(address);
        InsertThread(SharedFrom(waiter));
        waiter->SetStatus(ThreadStatus::WaitArb);
        waiter->WaitForArbitration(true);
    }

    // Cancel the timeout event if one was set up for this wait.
    if (timeout_event != InvalidHandle) {
        kernel.TimeManager().UnscheduleTimeEvent(timeout_event);
    }

    {
        // A thread still flagged as waiting was not taken off the list by WakeThreads.
        SchedulerLock guard(kernel);
        if (waiter->IsWaitingForArbitration()) {
            RemoveThread(SharedFrom(waiter));
            waiter->WaitForArbitration(false);
        }
    }

    return waiter->GetSignalingResult();
}

// Adds thread to the wait list of the address it is arbitrating on. The thread is placed in
// front of the first waiter whose GetPriority() is greater than or equal to its own, or at the
// back of the list when no such waiter exists.
void AddressArbiter::InsertThread(std::shared_ptr<Thread> thread) {
    // operator[] creates the list on the first wait for this address.
    auto& waiters = arb_threads[thread->GetArbiterWaitAddress()];

    // Skip every waiter whose priority value is strictly smaller than the new thread's.
    auto position = waiters.cbegin();
    while (position != waiters.cend() && (*position)->GetPriority() < thread->GetPriority()) {
        ++position;
    }

    // Inserting at cend() appends, so one call covers both cases.
    waiters.insert(position, std::move(thread));
}

void AddressArbiter::RemoveThread(std::shared_ptr<Thread> thread) {
    const VAddr arb_addr = thread->GetArbiterWaitAddress();
    std::list<std::shared_ptr<Thread>>& thread_list = arb_threads[arb_addr];

    const auto iter = std::find_if(thread_list.cbegin(), thread_list.cend(),
                                   [&thread](const auto& entry) { return thread == entry; });

    if (iter != thread_list.cend()) {
        thread_list.erase(iter);
    }
}

// Returns a copy, in list order, of the threads currently waiting on address. The result is
// empty when arb_threads has no entry for the address.
std::vector<std::shared_ptr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress(
    VAddr address) const {
    std::vector<std::shared_ptr<Thread>> waiters;

    const auto entry = arb_threads.find(address);
    if (entry != arb_threads.cend()) {
        // Copy the shared_ptrs so callers can alter the wait list while iterating the result.
        waiters.assign(entry->second.cbegin(), entry->second.cend());
    }

    return waiters;
}
} // namespace Kernel