mirror of
				https://git.suyu.dev/suyu/suyu
				synced 2025-11-04 08:59:03 -06:00 
			
		
		
		
	Kernel: Object ShouldWait and Acquire calls now take a thread as a parameter.
This will be useful when implementing mutex priority inheritance.
This commit is contained in:
		@@ -30,12 +30,12 @@ SharedPtr<Event> Event::Create(ResetType reset_type, std::string name) {
 | 
			
		||||
    return evt;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Event::ShouldWait() {
 | 
			
		||||
bool Event::ShouldWait(Thread* thread) const {
 | 
			
		||||
    return !signaled;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Event::Acquire() {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(), "object unavailable!");
 | 
			
		||||
void Event::Acquire(Thread* thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 | 
			
		||||
 | 
			
		||||
    // Release the event if it's not sticky...
 | 
			
		||||
    if (reset_type != ResetType::Sticky)
 | 
			
		||||
 
 | 
			
		||||
@@ -35,8 +35,8 @@ public:
 | 
			
		||||
    bool signaled;    ///< Whether the event has already been signaled
 | 
			
		||||
    std::string name; ///< Name of event (optional)
 | 
			
		||||
 | 
			
		||||
    bool ShouldWait() override;
 | 
			
		||||
    void Acquire() override;
 | 
			
		||||
    bool ShouldWait(Thread* thread) const override;
 | 
			
		||||
    void Acquire(Thread* thread) override;
 | 
			
		||||
 | 
			
		||||
    void Signal();
 | 
			
		||||
    void Clear();
 | 
			
		||||
 
 | 
			
		||||
@@ -39,11 +39,6 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
 | 
			
		||||
               thread->status == THREADSTATUS_DEAD;
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    // TODO(Subv): This call should be performed inside the loop below to check if an object can be
 | 
			
		||||
    // acquired by a particular thread. This is useful for things like recursive locking of Mutexes.
 | 
			
		||||
    if (ShouldWait())
 | 
			
		||||
        return nullptr;
 | 
			
		||||
 | 
			
		||||
    Thread* candidate = nullptr;
 | 
			
		||||
    s32 candidate_priority = THREADPRIO_LOWEST + 1;
 | 
			
		||||
 | 
			
		||||
@@ -51,9 +46,12 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
 | 
			
		||||
        if (thread->current_priority >= candidate_priority)
 | 
			
		||||
            continue;
 | 
			
		||||
 | 
			
		||||
        if (ShouldWait(thread.get()))
 | 
			
		||||
            continue;
 | 
			
		||||
 | 
			
		||||
        bool ready_to_run =
 | 
			
		||||
            std::none_of(thread->wait_objects.begin(), thread->wait_objects.end(),
 | 
			
		||||
                         [](const SharedPtr<WaitObject>& object) { return object->ShouldWait(); });
 | 
			
		||||
                         [&thread](const SharedPtr<WaitObject>& object) { return object->ShouldWait(thread.get()); });
 | 
			
		||||
        if (ready_to_run) {
 | 
			
		||||
            candidate = thread.get();
 | 
			
		||||
            candidate_priority = thread->current_priority;
 | 
			
		||||
@@ -66,7 +64,7 @@ SharedPtr<Thread> WaitObject::GetHighestPriorityReadyThread() {
 | 
			
		||||
void WaitObject::WakeupAllWaitingThreads() {
 | 
			
		||||
    while (auto thread = GetHighestPriorityReadyThread()) {
 | 
			
		||||
        if (!thread->IsSleepingOnWaitAll()) {
 | 
			
		||||
            Acquire();
 | 
			
		||||
            Acquire(thread.get());
 | 
			
		||||
            // Set the output index of the WaitSynchronizationN call to the index of this object.
 | 
			
		||||
            if (thread->wait_set_output) {
 | 
			
		||||
                thread->SetWaitSynchronizationOutput(thread->GetWaitObjectIndex(this));
 | 
			
		||||
@@ -74,7 +72,7 @@ void WaitObject::WakeupAllWaitingThreads() {
 | 
			
		||||
            }
 | 
			
		||||
        } else {
 | 
			
		||||
            for (auto& object : thread->wait_objects) {
 | 
			
		||||
                object->Acquire();
 | 
			
		||||
                object->Acquire(thread.get());
 | 
			
		||||
                object->RemoveWaitingThread(thread.get());
 | 
			
		||||
            }
 | 
			
		||||
            // Note: This case doesn't update the output index of WaitSynchronizationN.
 | 
			
		||||
 
 | 
			
		||||
@@ -132,13 +132,14 @@ using SharedPtr = boost::intrusive_ptr<T>;
 | 
			
		||||
class WaitObject : public Object {
 | 
			
		||||
public:
 | 
			
		||||
    /**
 | 
			
		||||
     * Check if the current thread should wait until the object is available
 | 
			
		||||
     * Check if the specified thread should wait until the object is available
 | 
			
		||||
     * @param thread The thread about which we're deciding.
 | 
			
		||||
     * @return True if the current thread should wait due to this object being unavailable
 | 
			
		||||
     */
 | 
			
		||||
    virtual bool ShouldWait() = 0;
 | 
			
		||||
    virtual bool ShouldWait(Thread* thread) const = 0;
 | 
			
		||||
 | 
			
		||||
    /// Acquire/lock the object if it is available
 | 
			
		||||
    virtual void Acquire() = 0;
 | 
			
		||||
    /// Acquire/lock the object for the specified thread if it is available
 | 
			
		||||
    virtual void Acquire(Thread* thread) = 0;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Add a thread to wait on this object
 | 
			
		||||
 
 | 
			
		||||
@@ -40,31 +40,19 @@ SharedPtr<Mutex> Mutex::Create(bool initial_locked, std::string name) {
 | 
			
		||||
    mutex->name = std::move(name);
 | 
			
		||||
    mutex->holding_thread = nullptr;
 | 
			
		||||
 | 
			
		||||
    // Acquire mutex with current thread if initialized as locked...
 | 
			
		||||
    // Acquire mutex with current thread if initialized as locked
 | 
			
		||||
    if (initial_locked)
 | 
			
		||||
        mutex->Acquire();
 | 
			
		||||
        mutex->Acquire(GetCurrentThread());
 | 
			
		||||
 | 
			
		||||
    return mutex;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Mutex::ShouldWait() {
 | 
			
		||||
    auto thread = GetCurrentThread();
 | 
			
		||||
    bool wait = lock_count > 0 && holding_thread != thread;
 | 
			
		||||
 | 
			
		||||
    // If the holding thread of the mutex is lower priority than this thread, that thread should
 | 
			
		||||
    // temporarily inherit this thread's priority
 | 
			
		||||
    if (wait && thread->current_priority < holding_thread->current_priority)
 | 
			
		||||
        holding_thread->BoostPriority(thread->current_priority);
 | 
			
		||||
 | 
			
		||||
    return wait;
 | 
			
		||||
bool Mutex::ShouldWait(Thread* thread) const {
 | 
			
		||||
    return lock_count > 0 && thread != holding_thread;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Mutex::Acquire() {
 | 
			
		||||
    Acquire(GetCurrentThread());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Mutex::Acquire(SharedPtr<Thread> thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(), "object unavailable!");
 | 
			
		||||
void Mutex::Acquire(Thread* thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 | 
			
		||||
 | 
			
		||||
    // Actually "acquire" the mutex only if we don't already have it...
 | 
			
		||||
    if (lock_count == 0) {
 | 
			
		||||
 
 | 
			
		||||
@@ -38,8 +38,9 @@ public:
 | 
			
		||||
    std::string name;                 ///< Name of mutex (optional)
 | 
			
		||||
    SharedPtr<Thread> holding_thread; ///< Thread that has acquired the mutex
 | 
			
		||||
 | 
			
		||||
    bool ShouldWait() override;
 | 
			
		||||
    void Acquire() override;
 | 
			
		||||
    bool ShouldWait(Thread* thread) const override;
 | 
			
		||||
    void Acquire(Thread* thread) override;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Acquires the specified mutex for the specified thread
 | 
			
		||||
 
 | 
			
		||||
@@ -30,12 +30,12 @@ ResultVal<SharedPtr<Semaphore>> Semaphore::Create(s32 initial_count, s32 max_cou
 | 
			
		||||
    return MakeResult<SharedPtr<Semaphore>>(std::move(semaphore));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Semaphore::ShouldWait() {
 | 
			
		||||
bool Semaphore::ShouldWait(Thread* thread) const {
 | 
			
		||||
    return available_count <= 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Semaphore::Acquire() {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(), "object unavailable!");
 | 
			
		||||
void Semaphore::Acquire(Thread* thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 | 
			
		||||
    --available_count;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -39,8 +39,8 @@ public:
 | 
			
		||||
    s32 available_count; ///< Number of free slots left in the semaphore
 | 
			
		||||
    std::string name;    ///< Name of semaphore (optional)
 | 
			
		||||
 | 
			
		||||
    bool ShouldWait() override;
 | 
			
		||||
    void Acquire() override;
 | 
			
		||||
    bool ShouldWait(Thread* thread) const override;
 | 
			
		||||
    void Acquire(Thread* thread) override;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Releases a certain number of slots from a semaphore.
 | 
			
		||||
 
 | 
			
		||||
@@ -14,13 +14,13 @@ namespace Kernel {
 | 
			
		||||
ServerPort::ServerPort() {}
 | 
			
		||||
ServerPort::~ServerPort() {}
 | 
			
		||||
 | 
			
		||||
bool ServerPort::ShouldWait() {
 | 
			
		||||
bool ServerPort::ShouldWait(Thread* thread) const {
 | 
			
		||||
    // If there are no pending sessions, we wait until a new one is added.
 | 
			
		||||
    return pending_sessions.size() == 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ServerPort::Acquire() {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(), "object unavailable!");
 | 
			
		||||
void ServerPort::Acquire(Thread* thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::tuple<SharedPtr<ServerPort>, SharedPtr<ClientPort>> ServerPort::CreatePortPair(
 | 
			
		||||
 
 | 
			
		||||
@@ -53,8 +53,8 @@ public:
 | 
			
		||||
    /// ServerSessions created from this port inherit a reference to this handler.
 | 
			
		||||
    std::shared_ptr<Service::SessionRequestHandler> hle_handler;
 | 
			
		||||
 | 
			
		||||
    bool ShouldWait() override;
 | 
			
		||||
    void Acquire() override;
 | 
			
		||||
    bool ShouldWait(Thread* thread) const override;
 | 
			
		||||
    void Acquire(Thread* thread) override;
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    ServerPort();
 | 
			
		||||
 
 | 
			
		||||
@@ -29,12 +29,12 @@ ResultVal<SharedPtr<ServerSession>> ServerSession::Create(
 | 
			
		||||
    return MakeResult<SharedPtr<ServerSession>>(std::move(server_session));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool ServerSession::ShouldWait() {
 | 
			
		||||
bool ServerSession::ShouldWait(Thread* thread) const {
 | 
			
		||||
    return !signaled;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ServerSession::Acquire() {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(), "object unavailable!");
 | 
			
		||||
void ServerSession::Acquire(Thread* thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 | 
			
		||||
    signaled = false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -57,9 +57,9 @@ public:
 | 
			
		||||
     */
 | 
			
		||||
    ResultCode HandleSyncRequest();
 | 
			
		||||
 | 
			
		||||
    bool ShouldWait() override;
 | 
			
		||||
    bool ShouldWait(Thread* thread) const override;
 | 
			
		||||
 | 
			
		||||
    void Acquire() override;
 | 
			
		||||
    void Acquire(Thread* thread) override;
 | 
			
		||||
 | 
			
		||||
    std::string name; ///< The name of this session (optional)
 | 
			
		||||
    bool signaled;    ///< Whether there's new data available to this ServerSession
 | 
			
		||||
 
 | 
			
		||||
@@ -27,12 +27,12 @@ namespace Kernel {
 | 
			
		||||
/// Event type for the thread wake up event
 | 
			
		||||
static int ThreadWakeupEventType;
 | 
			
		||||
 | 
			
		||||
bool Thread::ShouldWait() {
 | 
			
		||||
bool Thread::ShouldWait(Thread* thread) const {
 | 
			
		||||
    return status != THREADSTATUS_DEAD;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Thread::Acquire() {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(), "object unavailable!");
 | 
			
		||||
void Thread::Acquire(Thread* thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TODO(yuriks): This can be removed if Thread objects are explicitly pooled in the future, allowing
 | 
			
		||||
 
 | 
			
		||||
@@ -72,8 +72,8 @@ public:
 | 
			
		||||
        return HANDLE_TYPE;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool ShouldWait() override;
 | 
			
		||||
    void Acquire() override;
 | 
			
		||||
    bool ShouldWait(Thread* thread) const override;
 | 
			
		||||
    void Acquire(Thread* thread) override;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Gets the thread's current priority
 | 
			
		||||
 
 | 
			
		||||
@@ -39,12 +39,12 @@ SharedPtr<Timer> Timer::Create(ResetType reset_type, std::string name) {
 | 
			
		||||
    return timer;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Timer::ShouldWait() {
 | 
			
		||||
bool Timer::ShouldWait(Thread* thread) const {
 | 
			
		||||
    return !signaled;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Timer::Acquire() {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(), "object unavailable!");
 | 
			
		||||
void Timer::Acquire(Thread* thread) {
 | 
			
		||||
    ASSERT_MSG(!ShouldWait(thread), "object unavailable!");
 | 
			
		||||
 | 
			
		||||
    if (reset_type == ResetType::OneShot)
 | 
			
		||||
        signaled = false;
 | 
			
		||||
 
 | 
			
		||||
@@ -39,8 +39,8 @@ public:
 | 
			
		||||
    u64 initial_delay;  ///< The delay until the timer fires for the first time
 | 
			
		||||
    u64 interval_delay; ///< The delay until the timer fires after the first time
 | 
			
		||||
 | 
			
		||||
    bool ShouldWait() override;
 | 
			
		||||
    void Acquire() override;
 | 
			
		||||
    bool ShouldWait(Thread* thread) const override;
 | 
			
		||||
    void Acquire(Thread* thread) override;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Starts the timer, with the specified initial delay and interval.
 | 
			
		||||
 
 | 
			
		||||
@@ -272,7 +272,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds)
 | 
			
		||||
    LOG_TRACE(Kernel_SVC, "called handle=0x%08X(%s:%s), nanoseconds=%lld", handle,
 | 
			
		||||
              object->GetTypeName().c_str(), object->GetName().c_str(), nano_seconds);
 | 
			
		||||
 | 
			
		||||
    if (object->ShouldWait()) {
 | 
			
		||||
    if (object->ShouldWait(thread)) {
 | 
			
		||||
 | 
			
		||||
        if (nano_seconds == 0)
 | 
			
		||||
            return ERR_SYNC_TIMEOUT;
 | 
			
		||||
@@ -294,7 +294,7 @@ static ResultCode WaitSynchronization1(Kernel::Handle handle, s64 nano_seconds)
 | 
			
		||||
        return ERR_SYNC_TIMEOUT;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    object->Acquire();
 | 
			
		||||
    object->Acquire(thread);
 | 
			
		||||
 | 
			
		||||
    return RESULT_SUCCESS;
 | 
			
		||||
}
 | 
			
		||||
@@ -336,11 +336,11 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha
 | 
			
		||||
    if (wait_all) {
 | 
			
		||||
        bool all_available =
 | 
			
		||||
            std::all_of(objects.begin(), objects.end(),
 | 
			
		||||
                        [](const ObjectPtr& object) { return !object->ShouldWait(); });
 | 
			
		||||
                        [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); });
 | 
			
		||||
        if (all_available) {
 | 
			
		||||
            // We can acquire all objects right now, do so.
 | 
			
		||||
            for (auto& object : objects)
 | 
			
		||||
                object->Acquire();
 | 
			
		||||
                object->Acquire(thread);
 | 
			
		||||
            // Note: In this case, the `out` parameter is not set,
 | 
			
		||||
            // and retains whatever value it had before.
 | 
			
		||||
            return RESULT_SUCCESS;
 | 
			
		||||
@@ -380,12 +380,12 @@ static ResultCode WaitSynchronizationN(s32* out, Kernel::Handle* handles, s32 ha
 | 
			
		||||
    } else {
 | 
			
		||||
        // Find the first object that is acquirable in the provided list of objects
 | 
			
		||||
        auto itr = std::find_if(objects.begin(), objects.end(),
 | 
			
		||||
                                [](const ObjectPtr& object) { return !object->ShouldWait(); });
 | 
			
		||||
                                [thread](const ObjectPtr& object) { return !object->ShouldWait(thread); });
 | 
			
		||||
 | 
			
		||||
        if (itr != objects.end()) {
 | 
			
		||||
            // We found a ready object, acquire it and set the result value
 | 
			
		||||
            Kernel::WaitObject* object = itr->get();
 | 
			
		||||
            object->Acquire();
 | 
			
		||||
            object->Acquire(thread);
 | 
			
		||||
            *out = std::distance(objects.begin(), itr);
 | 
			
		||||
            return RESULT_SUCCESS;
 | 
			
		||||
        }
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user