From 23ec6b3d8f33e2e0d76e646049f1fc1d36e70cf3 Mon Sep 17 00:00:00 2001
From: Yuri Kunde Schlesner <yuriks@yuriks.net>
Date: Mon, 5 Jun 2017 23:31:59 -0700
Subject: [PATCH] Service: Make service registration part of the sm
 implementation

Also enhances the GetServiceHandle implementation to be more accurate.
---
 src/core/CMakeLists.txt          |  2 ++
 src/core/hle/service/service.cpp | 15 ++++-----
 src/core/hle/service/service.h   |  2 --
 src/core/hle/service/sm/sm.cpp   | 58 ++++++++++++++++++++++++++++++++
 src/core/hle/service/sm/sm.h     | 49 +++++++++++++++++++++++++++
 src/core/hle/service/sm/srv.cpp  | 53 +++++++++++++++++++----------
 6 files changed, 151 insertions(+), 28 deletions(-)
 create mode 100644 src/core/hle/service/sm/sm.cpp
 create mode 100644 src/core/hle/service/sm/sm.h

diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index 0e2aacde7..6e602b0c5 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -156,6 +156,7 @@ set(SRCS
             hle/service/qtm/qtm_sp.cpp
             hle/service/qtm/qtm_u.cpp
             hle/service/service.cpp
+            hle/service/sm/sm.cpp
             hle/service/sm/srv.cpp
             hle/service/soc_u.cpp
             hle/service/ssl_c.cpp
@@ -352,6 +353,7 @@ set(HEADERS
             hle/service/qtm/qtm_sp.h
             hle/service/qtm/qtm_u.h
             hle/service/service.h
+            hle/service/sm/sm.h
             hle/service/sm/srv.h
             hle/service/soc_u.h
             hle/service/ssl_c.h
diff --git a/src/core/hle/service/service.cpp b/src/core/hle/service/service.cpp
index 3a821871f..1b64ee77d 100644
--- a/src/core/hle/service/service.cpp
+++ b/src/core/hle/service/service.cpp
@@ -38,6 +38,7 @@
 #include "core/hle/service/ptm/ptm.h"
 #include "core/hle/service/qtm/qtm.h"
 #include "core/hle/service/service.h"
+#include "core/hle/service/sm/sm.h"
 #include "core/hle/service/sm/srv.h"
 #include "core/hle/service/soc_u.h"
 #include "core/hle/service/ssl_c.h"
@@ -46,7 +47,6 @@
 namespace Service {
 
 std::unordered_map<std::string, Kernel::SharedPtr<Kernel::ClientPort>> g_kernel_named_ports;
-std::unordered_map<std::string, Kernel::SharedPtr<Kernel::ClientPort>> g_srv_services;
 
 /**
  * Creates a function string for logging, complete with the name (or header code, depending
@@ -115,17 +115,16 @@ static void AddNamedPort(Interface* interface_) {
 }
 
 void AddService(Interface* interface_) {
-    Kernel::SharedPtr<Kernel::ServerPort> server_port;
-    Kernel::SharedPtr<Kernel::ClientPort> client_port;
-    std::tie(server_port, client_port) =
-        Kernel::ServerPort::CreatePortPair(interface_->GetMaxSessions(), interface_->GetPortName());
-
+    auto server_port =
+        SM::g_service_manager
+            ->RegisterService(interface_->GetPortName(), interface_->GetMaxSessions())
+            .MoveFrom();
     server_port->SetHleHandler(std::shared_ptr<Interface>(interface_));
-    g_srv_services.emplace(interface_->GetPortName(), std::move(client_port));
 }
 
 /// Initialize ServiceManager
 void Init() {
+    SM::g_service_manager = std::make_unique<SM::ServiceManager>();
     AddNamedPort(new SM::SRV);
     AddNamedPort(new ERR::ERR_F);
 
@@ -187,7 +186,7 @@ void Shutdown() {
     AC::Shutdown();
     FS::ArchiveShutdown();
 
-    g_srv_services.clear();
+    SM::g_service_manager = nullptr;
     g_kernel_named_ports.clear();
     LOG_DEBUG(Service, "shutdown OK");
 }
diff --git a/src/core/hle/service/service.h b/src/core/hle/service/service.h
index a5fe843f6..7010b116b 100644
--- a/src/core/hle/service/service.h
+++ b/src/core/hle/service/service.h
@@ -107,8 +107,6 @@ void Shutdown();
 
 /// Map of named ports managed by the kernel, which can be retrieved using the ConnectToPort SVC.
 extern std::unordered_map<std::string, Kernel::SharedPtr<Kernel::ClientPort>> g_kernel_named_ports;
-/// Map of services registered with the "srv:" service, retrieved using GetServiceHandle.
-extern std::unordered_map<std::string, Kernel::SharedPtr<Kernel::ClientPort>> g_srv_services;
 
 /// Adds a service to the services table
 void AddService(Interface* interface_);
diff --git a/src/core/hle/service/sm/sm.cpp b/src/core/hle/service/sm/sm.cpp
new file mode 100644
index 000000000..40df0f0dd
--- /dev/null
+++ b/src/core/hle/service/sm/sm.cpp
@@ -0,0 +1,58 @@
+// Copyright 2017 Citra Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#include <tuple>
+#include "core/hle/kernel/client_session.h"
+#include "core/hle/kernel/server_port.h"
+#include "core/hle/result.h"
+#include "core/hle/service/sm/sm.h"
+
+namespace Service {
+namespace SM {
+
+static ResultCode ValidateServiceName(const std::string& name) {
+    if (name.size() <= 0 || name.size() > 8) {
+        return ERR_INVALID_NAME_SIZE;
+    }
+    if (name.find('\0') != std::string::npos) {
+        return ERR_NAME_CONTAINS_NUL;
+    }
+    return RESULT_SUCCESS;
+}
+
+ResultVal<Kernel::SharedPtr<Kernel::ServerPort>> ServiceManager::RegisterService(
+    std::string name, unsigned int max_sessions) {
+
+    CASCADE_CODE(ValidateServiceName(name));
+    Kernel::SharedPtr<Kernel::ServerPort> server_port;
+    Kernel::SharedPtr<Kernel::ClientPort> client_port;
+    std::tie(server_port, client_port) = Kernel::ServerPort::CreatePortPair(max_sessions, name);
+
+    registered_services.emplace(name, std::move(client_port));
+    return MakeResult<Kernel::SharedPtr<Kernel::ServerPort>>(std::move(server_port));
+}
+
+ResultVal<Kernel::SharedPtr<Kernel::ClientPort>> ServiceManager::GetServicePort(
+    const std::string& name) {
+
+    CASCADE_CODE(ValidateServiceName(name));
+    auto it = registered_services.find(name);
+    if (it == registered_services.end()) {
+        return ERR_SERVICE_NOT_REGISTERED;
+    }
+
+    return MakeResult<Kernel::SharedPtr<Kernel::ClientPort>>(it->second);
+}
+
+ResultVal<Kernel::SharedPtr<Kernel::ClientSession>> ServiceManager::ConnectToService(
+    const std::string& name) {
+
+    CASCADE_RESULT(auto client_port, GetServicePort(name));
+    return client_port->Connect();
+}
+
+std::unique_ptr<ServiceManager> g_service_manager;
+
+} // namespace SM
+} // namespace Service
diff --git a/src/core/hle/service/sm/sm.h b/src/core/hle/service/sm/sm.h
new file mode 100644
index 000000000..5fac5455c
--- /dev/null
+++ b/src/core/hle/service/sm/sm.h
@@ -0,0 +1,49 @@
+// Copyright 2017 Citra Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+#include "core/hle/kernel/kernel.h"
+#include "core/hle/result.h"
+#include "core/hle/service/service.h"
+
+namespace Kernel {
+class ClientPort;
+class ClientSession;
+class ServerPort;
+class SessionRequestHandler;
+} // namespace Kernel
+
+namespace Service {
+namespace SM {
+
+constexpr ResultCode ERR_SERVICE_NOT_REGISTERED(1, ErrorModule::SRV, ErrorSummary::WouldBlock,
+                                                ErrorLevel::Temporary); // 0xD0406401
+constexpr ResultCode ERR_MAX_CONNECTIONS_REACHED(2, ErrorModule::SRV, ErrorSummary::WouldBlock,
+                                                 ErrorLevel::Temporary); // 0xD0406402
+constexpr ResultCode ERR_INVALID_NAME_SIZE(5, ErrorModule::SRV, ErrorSummary::WrongArgument,
+                                           ErrorLevel::Permanent); // 0xD9006405
+constexpr ResultCode ERR_ACCESS_DENIED(6, ErrorModule::SRV, ErrorSummary::InvalidArgument,
+                                       ErrorLevel::Permanent); // 0xD8E06406
+constexpr ResultCode ERR_NAME_CONTAINS_NUL(7, ErrorModule::SRV, ErrorSummary::WrongArgument,
+                                           ErrorLevel::Permanent); // 0xD9006407
+
+class ServiceManager {
+public:
+    ResultVal<Kernel::SharedPtr<Kernel::ServerPort>> RegisterService(std::string name,
+                                                                     unsigned int max_sessions);
+    ResultVal<Kernel::SharedPtr<Kernel::ClientPort>> GetServicePort(const std::string& name);
+    ResultVal<Kernel::SharedPtr<Kernel::ClientSession>> ConnectToService(const std::string& name);
+
+private:
+    /// Map of services registered with the "srv:" service, retrieved using GetServiceHandle.
+    std::unordered_map<std::string, Kernel::SharedPtr<Kernel::ClientPort>> registered_services;
+};
+
+extern std::unique_ptr<ServiceManager> g_service_manager;
+
+} // namespace SM
+} // namespace Service
diff --git a/src/core/hle/service/sm/srv.cpp b/src/core/hle/service/sm/srv.cpp
index d6946c734..34166289c 100644
--- a/src/core/hle/service/sm/srv.cpp
+++ b/src/core/hle/service/sm/srv.cpp
@@ -9,6 +9,7 @@
 #include "core/hle/kernel/client_session.h"
 #include "core/hle/kernel/semaphore.h"
 #include "core/hle/kernel/server_session.h"
+#include "core/hle/service/sm/sm.h"
 #include "core/hle/service/sm/srv.h"
 
 namespace Service {
@@ -78,25 +79,41 @@ static void GetServiceHandle(Interface* self) {
     ResultCode res = RESULT_SUCCESS;
     u32* cmd_buff = Kernel::GetCommandBuffer();
 
-    std::string port_name = std::string((const char*)&cmd_buff[1], 0, Service::kMaxPortSize);
-    auto it = Service::g_srv_services.find(port_name);
-
-    if (it != Service::g_srv_services.end()) {
-        auto client_port = it->second;
-
-        auto client_session = client_port->Connect();
-        res = client_session.Code();
-
-        if (client_session.Succeeded()) {
-            // Return the client session
-            cmd_buff[3] = Kernel::g_handle_table.Create(*client_session).MoveFrom();
-        }
-        LOG_TRACE(Service_SRV, "called port=%s, handle=0x%08X", port_name.c_str(), cmd_buff[3]);
-    } else {
-        LOG_ERROR(Service_SRV, "(UNIMPLEMENTED) called port=%s", port_name.c_str());
-        res = UnimplementedFunction(ErrorModule::SRV);
+    size_t name_len = cmd_buff[3];
+    if (name_len > Service::kMaxPortSize) {
+        cmd_buff[1] = ERR_INVALID_NAME_SIZE.raw;
+        LOG_ERROR(Service_SRV, "called name_len=0x%X, failed with code=0x%08X", name_len,
+                  cmd_buff[1]);
+        return;
+    }
+    std::string name(reinterpret_cast<const char*>(&cmd_buff[1]), name_len);
+    bool return_port_on_failure = (cmd_buff[4] & 1) == 0;
+
+    // TODO(yuriks): Permission checks go here
+
+    auto client_port = g_service_manager->GetServicePort(name);
+    if (client_port.Failed()) {
+        cmd_buff[1] = client_port.Code().raw;
+        LOG_ERROR(Service_SRV, "called service=%s, failed with code=0x%08X", name.c_str(),
+                  cmd_buff[1]);
+        return;
+    }
+
+    auto session = client_port.Unwrap()->Connect();
+    cmd_buff[1] = session.Code().raw;
+    if (session.Succeeded()) {
+        cmd_buff[3] = Kernel::g_handle_table.Create(session.MoveFrom()).MoveFrom();
+        LOG_DEBUG(Service_SRV, "called service=%s, session handle=0x%08X", name.c_str(),
+                  cmd_buff[3]);
+    } else if (session.Code() == Kernel::ERR_MAX_CONNECTIONS_REACHED && return_port_on_failure) {
+        cmd_buff[1] = ERR_MAX_CONNECTIONS_REACHED.raw;
+        cmd_buff[3] = Kernel::g_handle_table.Create(client_port.MoveFrom()).MoveFrom();
+        LOG_WARNING(Service_SRV, "called service=%s, *port* handle=0x%08X", name.c_str(),
+                    cmd_buff[3]);
+    } else {
+        LOG_ERROR(Service_SRV, "called service=%s, failed with code=0x%08X", name.c_str(),
+                  cmd_buff[1]);
     }
-    cmd_buff[1] = res.raw;
 }
 
 /**