diff --git a/.pipelines/build-stage.yml b/.pipelines/build-stage.yml index b213352b6..ff55eb489 100644 --- a/.pipelines/build-stage.yml +++ b/.pipelines/build-stage.yml @@ -27,8 +27,8 @@ parameters: - name: targets type: object default: - - target: "wsl;libwsl;wslg;wslservice;wslhost;wslrelay;wslinstaller;wslinstall;initramfs;wslserviceproxystub;wslsettings;wslinstallerproxystub;testplugin" - pattern: "wsl.exe,libwsl.dll,wslg.exe,wslservice.exe,wslhost.exe,wslrelay.exe,wslinstaller.exe,wslinstall.dll,wslserviceproxystub.dll,wslsettings/wslsettings.dll,wslsettings/wslsettings.exe,wslinstallerproxystub.dll,WSLDVCPlugin.dll,testplugin.dll,wsldeps.dll" + - target: "wsl;libwsl;wslg;wslservice;wslhost;wslrelay;wslpluginhost;wslinstaller;wslinstall;initramfs;wslserviceproxystub;wslsettings;wslinstallerproxystub;testplugin" + pattern: "wsl.exe,libwsl.dll,wslg.exe,wslservice.exe,wslhost.exe,wslrelay.exe,wslpluginhost.exe,wslinstaller.exe,wslinstall.dll,wslserviceproxystub.dll,wslsettings/wslsettings.dll,wslsettings/wslsettings.exe,wslinstallerproxystub.dll,WSLDVCPlugin.dll,testplugin.dll,wsldeps.dll" - target: "msixgluepackage" pattern: "gluepackage.msix" - target: "msipackage" diff --git a/CMakeLists.txt b/CMakeLists.txt index 53abd10b2..feda1fcd3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -486,6 +486,7 @@ add_subdirectory(src/windows/wsl) add_subdirectory(src/windows/wslg) add_subdirectory(src/windows/wslhost) add_subdirectory(src/windows/wslrelay) +add_subdirectory(src/windows/wslpluginhost) add_subdirectory(src/windows/wslinstall) if (WSL_BUILD_WSL_SETTINGS) diff --git a/msipackage/CMakeLists.txt b/msipackage/CMakeLists.txt index a7c7c8c04..11c946418 100644 --- a/msipackage/CMakeLists.txt +++ b/msipackage/CMakeLists.txt @@ -12,7 +12,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi) set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in) set(PACKAGE_WIX ${BIN}/package.wix) set(CAB_CACHE ${BIN}/cab) -set(WINDOWS_BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;wslinstall.dll) +set(WINDOWS_BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslpluginhost.exe;wslservice.exe;wslserviceproxystub.dll;wslinstall.dll) if (WSL_BUILD_WSL_SETTINGS) list(APPEND WINDOWS_BINARIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll") endif() @@ -52,7 +52,7 @@ add_custom_command( add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE}) set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN}) -add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage) +add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslpluginhost wslserviceproxystub init initramfs wslinstall msixgluepackage) if (WSL_BUILD_WSL_SETTINGS) add_dependencies(msipackage wslsettings libwsl) diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index a0ec1b700..2207f4444 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -29,6 +29,7 @@ + @@ -159,6 +160,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/windows/common/precomp.h b/src/windows/common/precomp.h index 08a8140c1..b8c5c7d01 100644 --- a/src/windows/common/precomp.h +++ b/src/windows/common/precomp.h @@ -83,6 +83,7 @@ Module Name: #include #include #include +#include #include #include #include diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt index cfe5b8649..c48d49e75 100644 --- a/src/windows/service/exe/CMakeLists.txt +++ b/src/windows/service/exe/CMakeLists.txt @@ -50,7 +50,7 @@ set(HEADERS WslCoreVm.h) add_executable(wslservice ${SOURCES} ${HEADERS}) -add_dependencies(wslservice wslserviceidl wslservicemc) +add_dependencies(wslservice wslserviceidl wslservicemc wslpluginhostidl) add_compile_definitions(__WRL_CLASSIC_COM__) add_compile_definitions(__WRL_DISABLE_STATIC_INITIALIZE__) add_compile_definitions(USE_COM_CONTEXT_DEF=1) diff --git a/src/windows/service/exe/LxssUserSession.cpp b/src/windows/service/exe/LxssUserSession.cpp index 2ad6469c9..f4610fc2e 100644 --- a/src/windows/service/exe/LxssUserSession.cpp +++ b/src/windows/service/exe/LxssUserSession.cpp @@ -2604,13 +2604,18 @@ std::shared_ptr LxssUserSessionImpl::_CreateInstance(_In_op registration.Write(Property::OsVersion, distributionInfo->Version); } - // This needs to be done before plugins are notifed because they might try to run a command inside the distribution. - m_runningInstances[registration.Id()] = instance; + // This needs to be done before plugins are notified because they might try to run a command inside the distribution. + { + std::unique_lock callbackLock(m_callbackLock); + m_runningInstances[registration.Id()] = instance; + } if (version == LXSS_WSL_VERSION_2) { - auto cleanupOnFailure = - wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { m_runningInstances.erase(registration.Id()); }); + auto cleanupOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + std::unique_lock callbackLock(m_callbackLock); + m_runningInstances.erase(registration.Id()); + }); m_pluginManager.OnDistributionStarted(&m_session, instance->DistributionInformation()); cleanupOnFailure.release(); } @@ -3577,17 +3582,26 @@ bool LxssUserSessionImpl::_TerminateInstanceInternal(_In_ LPCGUID DistroGuid, _I m_pluginManager.OnDistributionStopping(&m_session, wslcoreInstance->DistributionInformation()); } - instance->second->Stop(); + m_lifetimeManager.RemoveCallback(clientKey); - const auto clientId = instance->second->GetClientId(); + // Stop the instance and remove it from m_runningInstances atomically + // under m_callbackLock. This prevents plugin callbacks (which hold + // m_callbackLock shared) from finding a stopped-but-still-listed + // instance between Stop() and erase. + ULONG clientId; { - auto lock = m_terminatedInstanceLock.lock_exclusive(); - m_terminatedInstances.push_back(std::move(instance->second)); - } + std::unique_lock callbackLock(m_callbackLock); - m_lifetimeManager.RemoveCallback(clientKey); + instance->second->Stop(); + clientId = instance->second->GetClientId(); + + { + auto lock = m_terminatedInstanceLock.lock_exclusive(); + m_terminatedInstances.push_back(std::move(instance->second)); + } - m_runningInstances.erase(instance); + m_runningInstances.erase(instance); + } // If the instance that was terminated was a WSL2 instance, // check if the VM is now idle. @@ -3615,7 +3629,10 @@ void LxssUserSessionImpl::_UpdateInit(_In_ const LXSS_DISTRO_CONFIGURATION& Conf HRESULT LxssUserSessionImpl::MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name) { - std::lock_guard lock(m_instanceLock); + // Shared lock prevents _VmTerminate from destroying the VM while we use it. + // Do NOT acquire m_instanceLock — callbacks arrive on a different COM thread + // from the notification thread that holds m_instanceLock. + std::shared_lock lock(m_callbackLock); RETURN_HR_IF(E_NOT_VALID_STATE, !m_utilityVm); m_utilityVm->MountRootNamespaceFolder(HostPath, GuestPath, ReadOnly, Name); @@ -3624,7 +3641,9 @@ HRESULT LxssUserSessionImpl::MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In HRESULT LxssUserSessionImpl::CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* Socket) { - std::lock_guard lock(m_instanceLock); + // Shared lock prevents _VmTerminate from destroying the VM or instances + // while we use them. See MountRootNamespaceFolder for rationale. + std::shared_lock lock(m_callbackLock); RETURN_HR_IF(E_NOT_VALID_STATE, !m_utilityVm); if (Distro == nullptr) @@ -3633,9 +3652,16 @@ HRESULT LxssUserSessionImpl::CreateLinuxProcess(_In_opt_ const GUID* Distro, _In } else { - const auto distro = _RunningInstance(Distro); - THROW_HR_IF(WSL_E_VM_MODE_INVALID_STATE, !distro); - + // Look up the running instance directly instead of calling _RunningInstance, + // which accesses m_lockedDistributions (guarded only by m_instanceLock). + // m_runningInstances is safe to read under m_callbackLock (shared). + // The _EnsureNotLocked check is unnecessary here: _ConversionBegin removes + // a distribution from m_runningInstances before adding it to m_lockedDistributions, + // so a locked distribution will never be found in this lookup. + const auto instance = m_runningInstances.find(*Distro); + THROW_HR_IF(WSL_E_VM_MODE_INVALID_STATE, instance == m_runningInstances.end()); + + const auto distro = instance->second; const auto wsl2Distro = dynamic_cast(distro.get()); THROW_HR_IF(WSL_E_WSL2_NEEDED, !wsl2Distro); @@ -3871,7 +3897,12 @@ void LxssUserSessionImpl::_VmTerminate() m_telemetryThread.join(); } - m_utilityVm.reset(); + // Acquire exclusive callback lock to wait for any in-flight plugin callbacks + // (MountRootNamespaceFolder, CreateLinuxProcess) to complete before destroying the VM. + { + std::unique_lock callbackLock(m_callbackLock); + m_utilityVm.reset(); + } m_vmId.store(GUID_NULL); // Reset the user's token since its lifetime is tied to the VM. diff --git a/src/windows/service/exe/LxssUserSession.h b/src/windows/service/exe/LxssUserSession.h index 6e2d41687..c938f223f 100644 --- a/src/windows/service/exe/LxssUserSession.h +++ b/src/windows/service/exe/LxssUserSession.h @@ -310,6 +310,10 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce7e") LxssUserSession /// class LxssUserSessionImpl { + // Plugin callbacks arrive on a different COM RPC thread and use m_callbackLock + // (shared) instead of m_instanceLock to access m_utilityVm and m_runningInstances. + friend class wsl::windows::service::PluginHostCallbackImpl; + public: LxssUserSessionImpl(_In_ PSID userSid, _In_ DWORD sessionId, _Inout_ wsl::windows::service::PluginManager& pluginManager); virtual ~LxssUserSessionImpl(); @@ -363,11 +367,6 @@ class LxssUserSessionImpl /// void ClearDiskStateInRegistry(_In_opt_ LPCWSTR Disk); - /// - /// Start a process in the root namespace or in a user distribution. - /// - HRESULT CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* socket); - /// /// Enumerates registered distributions, optionally including ones that are /// currently being registered, unregistered, or converted. @@ -443,8 +442,6 @@ class LxssUserSessionImpl HRESULT MoveDistribution(_In_ LPCGUID DistroGuid, _In_ LPCWSTR Location); - HRESULT MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name); - /// /// Registers a distribution. /// @@ -533,6 +530,18 @@ class LxssUserSessionImpl static CreateLxProcessContext s_GetCreateProcessContext(_In_ const GUID& DistroGuid, _In_ bool SystemDistro); private: + /// + /// Plugin callback methods — called from PluginHostCallbackImpl on a COM RPC + /// thread during plugin notifications. These acquire m_callbackLock (shared) + /// instead of m_instanceLock, preventing _VmTerminate from destroying the VM + /// while a callback is in-flight. Access is restricted via friend declaration. + /// + _Requires_lock_not_held_(m_instanceLock) + HRESULT MountRootNamespaceFolder(_In_ LPCWSTR HostPath, _In_ LPCWSTR GuestPath, _In_ bool ReadOnly, _In_ LPCWSTR Name); + + _Requires_lock_not_held_(m_instanceLock) + HRESULT CreateLinuxProcess(_In_opt_ const GUID* Distro, _In_ LPCSTR Path, _In_ LPCSTR* Arguments, _Out_ SOCKET* socket); + /// /// Adds a distro to the list of converting distros. /// @@ -794,7 +803,9 @@ class LxssUserSessionImpl std::recursive_timed_mutex m_instanceLock; /// - /// Contains the currently running utility VM's. + /// Contains the currently running instances. + /// Reads guarded by m_instanceLock OR m_callbackLock (shared). + /// Mutations require BOTH m_instanceLock AND m_callbackLock (exclusive). /// _Guarded_by_(m_instanceLock) std::map, wsl::windows::common::helpers::GuidLess> m_runningInstances; @@ -811,9 +822,24 @@ class LxssUserSessionImpl /// /// The running utility vm for WSL2 distributions. - /// + /// Reads guarded by m_instanceLock OR m_callbackLock (shared). + /// Mutations require BOTH m_instanceLock AND m_callbackLock (exclusive). + /// _Guarded_by_(m_instanceLock) std::unique_ptr m_utilityVm; + /// + /// Reader-writer lock protecting m_utilityVm and m_runningInstances for + /// plugin callbacks. Callbacks take a shared (read) lock; _VmTerminate and + /// instance mutations take an exclusive (write) lock. + /// + /// Mutations of m_runningInstances/m_utilityVm require BOTH m_instanceLock + /// AND m_callbackLock (exclusive). Reads are safe under either lock alone. + /// + /// Lock ordering: m_instanceLock → m_callbackLock (never reverse). + /// Callbacks must NEVER acquire m_instanceLock (deadlock with notification thread). + /// + std::shared_mutex m_callbackLock; + std::atomic m_vmId{GUID_NULL}; /// diff --git a/src/windows/service/exe/PluginManager.cpp b/src/windows/service/exe/PluginManager.cpp index e4d23f226..d529ef036 100644 --- a/src/windows/service/exe/PluginManager.cpp +++ b/src/windows/service/exe/PluginManager.cpp @@ -9,6 +9,8 @@ Module Name: Abstract: This file contains the PluginManager helper class implementation. + Plugins are loaded in isolated wslpluginhost.exe processes via COM, + so a crashing plugin cannot take down the WSL service. --*/ @@ -16,90 +18,134 @@ Module Name: #include "install.h" #include "PluginManager.h" #include "WslPluginApi.h" +#include "WslPluginHost.h" #include "LxssUserSessionFactory.h" using wsl::windows::common::Context; using wsl::windows::common::ExecutionContext; +using wsl::windows::service::PluginHostCallbackImpl; using wsl::windows::service::PluginManager; constexpr auto c_pluginPath = L"SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Lxss\\Plugins"; -constexpr WSLVersion Version = {wsl::shared::VersionMajor, wsl::shared::VersionMinor, wsl::shared::VersionRevision}; +// --- IWslPluginHostCallback implementation (service-side) --- +// These methods handle API calls from the plugin host process. -thread_local std::optional g_pluginErrorMessage; - -extern "C" { -HRESULT MountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name) +STDMETHODIMP PluginHostCallbackImpl::MountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCWSTR LinuxPath, _In_ BOOL ReadOnly, _In_ LPCWSTR Name) try { - const auto session = FindSessionByCookie(Session); + WSL_LOG( + "PluginCallbackMountFolderBegin", + TraceLoggingValue(WindowsPath, "WindowsPath"), + TraceLoggingValue(SessionId, "SessionId")); + const auto session = FindSessionByCookie(SessionId); RETURN_HR_IF(RPC_E_DISCONNECTED, !session); auto result = session->MountRootNamespaceFolder(WindowsPath, LinuxPath, ReadOnly, Name); - WSL_LOG( - "PluginMountFolderCall", - TraceLoggingValue(WindowsPath, "WindowsPath"), - TraceLoggingValue(LinuxPath, "LinuxPath"), - TraceLoggingValue(ReadOnly, "ReadOnly"), - TraceLoggingValue(Name, "Name"), - TraceLoggingValue(result, "Result")); + WSL_LOG("PluginCallbackMountFolderEnd", TraceLoggingValue(WindowsPath, "WindowsPath"), TraceLoggingValue(result, "Result")); return result; } CATCH_RETURN(); -HRESULT ExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) +STDMETHODIMP PluginHostCallbackImpl::ExecuteBinary( + _In_ DWORD SessionId, _In_ LPCSTR Path, _In_ DWORD ArgumentCount, _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _Out_ HANDLE* Socket) try { + RETURN_HR_IF(E_POINTER, Socket == nullptr); + *Socket = nullptr; - const auto session = FindSessionByCookie(Session); + WSL_LOG("PluginCallbackExecuteBinaryBegin", TraceLoggingValue(Path, "Path"), TraceLoggingValue(SessionId, "SessionId")); + const auto session = FindSessionByCookie(SessionId); + WSL_LOG( + "PluginCallbackExecuteBinaryFoundSession", + TraceLoggingValue(Path, "Path"), + TraceLoggingValue(session != nullptr, "Found")); RETURN_HR_IF(RPC_E_DISCONNECTED, !session); - auto result = session->CreateLinuxProcess(nullptr, Path, Arguments, Socket); - - WSL_LOG("PluginExecuteBinaryCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); - return result; -} -CATCH_RETURN(); - -HRESULT PluginError(LPCWSTR UserMessage) -try -{ - const auto* context = ExecutionContext::Current(); - THROW_HR_IF(E_INVALIDARG, UserMessage == nullptr); - THROW_HR_IF_MSG( - E_ILLEGAL_METHOD_CALL, context == nullptr || WI_IsFlagClear(context->CurrentContext(), Context::Plugin), "Message: %ls", UserMessage); + // Build NULL-terminated argument array expected by CreateLinuxProcess. + std::vector args; + if (Arguments != nullptr) + { + args.assign(Arguments, Arguments + ArgumentCount); + } + args.push_back(nullptr); - // Logs when a WSL plugin hits an error and what that error message is - WSL_LOG_TELEMETRY("PluginError", PDT_ProductAndServicePerformance, TraceLoggingValue(UserMessage, "Message")); + WSL_LOG("PluginCallbackExecuteBinaryCallingCreateProcess", TraceLoggingValue(Path, "Path")); + wil::unique_socket sock; + auto result = session->CreateLinuxProcess(nullptr, Path, args.data(), &sock); - THROW_HR_IF(E_ILLEGAL_STATE_CHANGE, g_pluginErrorMessage.has_value()); + WSL_LOG("PluginCallbackExecuteBinaryEnd", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); - g_pluginErrorMessage.emplace(UserMessage); + if (SUCCEEDED(result)) + { + // Return socket as HANDLE — COM's system_handle marshaling will + // duplicate it into the host process automatically. + *Socket = reinterpret_cast(sock.release()); + } - return S_OK; + return result; } CATCH_RETURN(); -HRESULT ExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) +STDMETHODIMP PluginHostCallbackImpl::ExecuteBinaryInDistribution( + _In_ DWORD SessionId, + _In_ const GUID* DistributionId, + _In_ LPCSTR Path, + _In_ DWORD ArgumentCount, + _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, + _Out_ HANDLE* Socket) try { - THROW_HR_IF(E_INVALIDARG, Distro == nullptr); + RETURN_HR_IF(E_POINTER, Socket == nullptr); + *Socket = nullptr; + RETURN_HR_IF(E_INVALIDARG, DistributionId == nullptr); - const auto session = FindSessionByCookie(Session); + const auto session = FindSessionByCookie(SessionId); RETURN_HR_IF(RPC_E_DISCONNECTED, !session); - auto result = session->CreateLinuxProcess(Distro, Path, Arguments, Socket); + std::vector args; + if (Arguments != nullptr) + { + args.assign(Arguments, Arguments + ArgumentCount); + } + args.push_back(nullptr); + + wil::unique_socket sock; + auto result = session->CreateLinuxProcess(DistributionId, Path, args.data(), &sock); WSL_LOG("PluginExecuteBinaryInDistributionCall", TraceLoggingValue(Path, "Path"), TraceLoggingValue(result, "Result")); + if (SUCCEEDED(result)) + { + *Socket = reinterpret_cast(sock.release()); + } + return result; } CATCH_RETURN(); + +STDMETHODIMP PluginHostCallbackImpl::PluginError(_In_ LPCWSTR UserMessage) +try +{ + // PluginError is now handled locally in the host process. + // The host captures the message and returns it alongside the hook HRESULT. + // This callback exists only for completeness — it should not be called + // directly over COM since the host handles it locally. + RETURN_HR_IF(E_INVALIDARG, UserMessage == nullptr); + WSL_LOG_TELEMETRY("PluginError", PDT_ProductAndServicePerformance, TraceLoggingValue(UserMessage, "Message")); + return S_OK; } +CATCH_RETURN(); -static constexpr WSLPluginAPIV1 ApiV1 = {Version, &MountFolder, &ExecuteBinary, &PluginError, &ExecuteBinaryInDistribution}; +// --- PluginManager implementation --- + +PluginManager::~PluginManager() +{ + // Release all COM proxies, which will cause the host processes to exit. + m_plugins.clear(); +} void PluginManager::LoadPlugins() { @@ -125,188 +171,374 @@ void PluginManager::LoadPlugins() continue; } - auto loadResult = wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&]() { LoadPlugin(e.first.c_str(), path.c_str()); }); + // Record the plugin for deferred activation. The actual COM host process + // is created in EnsureInitialized(), which runs after the service's COM + // initialization is complete (CoInitializeSecurity must happen first). + OutOfProcPlugin plugin{}; + plugin.name = e.first; + plugin.path = path; + m_plugins.emplace_back(std::move(plugin)); - // Logs when a WSL plugin is loaded, used for evaluating plugin populations WSL_LOG_TELEMETRY( "PluginLoad", PDT_ProductAndServiceUsage, TraceLoggingValue(e.first.c_str(), "Name"), TraceLoggingValue(path.c_str(), "Path"), - TraceLoggingValue(loadResult, "Result")); + TraceLoggingValue(S_OK, "Result")); + } +} + +void PluginManager::EnsureInitialized() +{ + std::call_once(m_initOnce, [this]() { + m_callback = Microsoft::WRL::Make(); + THROW_IF_NULL_ALLOC(m_callback); - if (FAILED(loadResult)) + for (auto& e : m_plugins) { - // If this plugin reported an error, record it to display it to the user - m_pluginError.emplace(PluginError{e.first, loadResult}); + auto loadResult = wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&]() { LoadPlugin(e); }); + + WSL_LOG_TELEMETRY( + "PluginHostActivation", + PDT_ProductAndServiceUsage, + TraceLoggingValue(e.name.c_str(), "Name"), + TraceLoggingValue(e.path.c_str(), "Path"), + TraceLoggingValue(loadResult, "Result")); + + if (FAILED(loadResult)) + { + // Only treat plugin-reported errors (from entry point) as fatal. + // COM infrastructure errors (activation, connectivity) are non-fatal + // — the plugin is simply unavailable. + if (IsHostCrash(loadResult) || loadResult == CO_E_SERVER_EXEC_FAILURE) + { + LOG_HR_MSG(loadResult, "Plugin host activation failed for: '%ls', skipping", e.name.c_str()); + } + else + { + m_pluginError.emplace(PluginError{e.name, loadResult}); + } + } } - } + }); } -void PluginManager::LoadPlugin(LPCWSTR Name, LPCWSTR ModulePath) +void PluginManager::LoadPlugin(OutOfProcPlugin& plugin) { - // Validate the plugin signature before loading it. - // The handle to the module is kept open after validating the signature so the file can't be written to - // after the signature check. - wil::unique_hfile pluginHandle; - if constexpr (wsl::shared::OfficialBuild) + // Activate the plugin host via COM. The LocalServer32 registration causes COM + // to spawn wslpluginhost.exe automatically. + Microsoft::WRL::ComPtr host; + HRESULT activationHr = CoCreateInstance(CLSID_WslPluginHost, nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&host)); + WSL_LOG( + "PluginHostActivation", + TraceLoggingValue(plugin.name.c_str(), "Plugin"), + TraceLoggingValue(plugin.path.c_str(), "Path"), + TraceLoggingValue(activationHr, "CoCreateInstanceResult")); + THROW_IF_FAILED_MSG(activationHr, "Failed to create plugin host for: '%ls'", plugin.path.c_str()); + + THROW_IF_FAILED_MSG( + host->Initialize(m_callback.Get(), plugin.path.c_str(), plugin.name.c_str()), + "Plugin host failed to initialize: '%ls'", + plugin.path.c_str()); + + // Add the plugin host process to our job object so it is automatically + // terminated if wslservice exits or crashes. + EnsureJobObjectCreated(); + wil::unique_handle process; + if (SUCCEEDED(host->GetProcessHandle(&process))) { - pluginHandle = wsl::windows::common::install::ValidateFileSignature(ModulePath); - WI_ASSERT(pluginHandle.is_valid()); + LOG_IF_WIN32_BOOL_FALSE(AssignProcessToJobObject(m_jobObject.get(), process.get())); } - LoadedPlugin plugin{}; - plugin.name = Name; - - plugin.module.reset(LoadLibrary(ModulePath)); - THROW_LAST_ERROR_IF_NULL(plugin.module); - - const WSLPluginAPI_EntryPointV1 entryPoint = - reinterpret_cast(GetProcAddress(plugin.module.get(), GSL_STRINGIFY(WSLPLUGINAPI_ENTRYPOINTV1))); + plugin.host = std::move(host); +} - THROW_LAST_ERROR_IF_NULL(entryPoint); - THROW_IF_FAILED_MSG(entryPoint(&ApiV1, &plugin.hooks), "Error returned by plugin: '%ls'", ModulePath); +void PluginManager::EnsureJobObjectCreated() +{ + std::call_once(m_jobObjectOnce, [this]() { + m_jobObject.reset(CreateJobObjectW(nullptr, nullptr)); + THROW_LAST_ERROR_IF(!m_jobObject); + + JOBOBJECT_EXTENDED_LIMIT_INFORMATION jobInfo{}; + jobInfo.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + THROW_IF_WIN32_BOOL_FALSE(SetInformationJobObject(m_jobObject.get(), JobObjectExtendedLimitInformation, &jobInfo, sizeof(jobInfo))); + }); +} - m_plugins.emplace_back(std::move(plugin)); +std::vector PluginManager::SerializeSid(PSID Sid) +{ + const DWORD sidLength = GetLengthSid(Sid); + std::vector buffer(sidLength); + CopySid(sidLength, buffer.data(), Sid); + return buffer; } void PluginManager::OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSettings* Settings) { ExecutionContext context(Context::Plugin); + EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); for (const auto& e : m_plugins) { - if (e.hooks.OnVMStarted != nullptr) + if (!e.host) { - WSL_LOG( - "PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); - - ThrowIfPluginError(e.hooks.OnVMStarted(Session, Settings), Session->SessionId, e.name.c_str()); + continue; } + WSL_LOG("PluginOnVmStartedCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); + + wil::unique_cotaskmem_string errorMessage; + WSL_LOG("PluginOnVmStartedBeginRpc", TraceLoggingValue(e.name.c_str(), "Plugin")); + HRESULT hr = e.host->OnVMStarted( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + static_cast(Settings->CustomConfigurationFlags), + &errorMessage); + WSL_LOG("PluginOnVmStartedEndRpc", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(hr, "Result")); + + ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); } } -void PluginManager::OnVmStopping(const WSLSessionInformation* Session) const +void PluginManager::OnVmStopping(const WSLSessionInformation* Session) { ExecutionContext context(Context::Plugin); + EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); for (const auto& e : m_plugins) { - if (e.hooks.OnVMStopping != nullptr) + if (!e.host) { - WSL_LOG( - "PluginOnVmStoppingCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); + continue; + } + WSL_LOG("PluginOnVmStoppingCall", TraceLoggingValue(e.name.c_str(), "Plugin"), TraceLoggingValue(Session->UserSid, "Sid")); + + const auto result = + e.host->OnVMStopping(Session->SessionId, Session->UserToken, static_cast(sidData.size()), sidData.data()); - const auto result = e.hooks.OnVMStopping(Session); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + if (IsHostCrash(result)) + { + LOG_HR_MSG(result, "Plugin host crashed, skipping OnVmStopping for: '%ls'", e.name.c_str()); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } void PluginManager::OnDistributionStarted(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); for (const auto& e : m_plugins) { - if (e.hooks.OnDistributionStarted != nullptr) + if (!e.host) { - WSL_LOG( - "PluginOnDistroStartedCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - ThrowIfPluginError(e.hooks.OnDistributionStarted(Session, Distribution), Session->SessionId, e.name.c_str()); + continue; } + WSL_LOG( + "PluginOnDistroStartedCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + wil::unique_cotaskmem_string errorMessage; + HRESULT hr = e.host->OnDistributionStarted( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PidNamespace, + Distribution->PackageFamilyName, + Distribution->InitPid, + Distribution->Flavor, + Distribution->Version, + &errorMessage); + + ThrowIfPluginError(hr, errorMessage.get(), Session->SessionId, e.name.c_str()); } } -void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) const +void PluginManager::OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); for (const auto& e : m_plugins) { - if (e.hooks.OnDistributionStopping != nullptr) + if (!e.host) { - WSL_LOG( - "PluginOnDistroStoppingCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - const auto result = e.hooks.OnDistributionStopping(Session, Distribution); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; + } + WSL_LOG( + "PluginOnDistroStoppingCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + const auto result = e.host->OnDistributionStopping( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PidNamespace, + Distribution->PackageFamilyName, + Distribution->InitPid, + Distribution->Flavor, + Distribution->Version); + + if (IsHostCrash(result)) + { + LOG_HR_MSG(result, "Plugin host crashed, skipping OnDistributionStopping for: '%ls'", e.name.c_str()); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) const +void PluginManager::OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); for (const auto& e : m_plugins) { - if (e.hooks.OnDistributionRegistered != nullptr) + if (!e.host) + { + continue; + } + WSL_LOG( + "PluginOnDistributionRegisteredCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + const auto result = e.host->OnDistributionRegistered( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PackageFamilyName, + Distribution->Flavor, + Distribution->Version); + + if (IsHostCrash(result)) { - WSL_LOG( - "PluginOnDistributionRegisteredCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - const auto result = e.hooks.OnDistributionRegistered(Session, Distribution); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + LOG_HR_MSG(result, "Plugin host crashed, skipping OnDistributionRegistered for: '%ls'", e.name.c_str()); + continue; } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) const +void PluginManager::OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* Distribution) { ExecutionContext context(Context::Plugin); + EnsureInitialized(); + + auto sidData = SerializeSid(Session->UserSid); for (const auto& e : m_plugins) { - if (e.hooks.OnDistributionUnregistered != nullptr) + if (!e.host) { - WSL_LOG( - "PluginOnDistributionUnregisteredCall", - TraceLoggingValue(e.name.c_str(), "Plugin"), - TraceLoggingValue(Session->UserSid, "Sid"), - TraceLoggingValue(Distribution->Id, "DistributionId")); - - const auto result = e.hooks.OnDistributionUnregistered(Session, Distribution); - LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); + continue; } + WSL_LOG( + "PluginOnDistributionUnregisteredCall", + TraceLoggingValue(e.name.c_str(), "Plugin"), + TraceLoggingValue(Session->UserSid, "Sid"), + TraceLoggingValue(Distribution->Id, "DistributionId")); + + const auto result = e.host->OnDistributionUnregistered( + Session->SessionId, + Session->UserToken, + static_cast(sidData.size()), + sidData.data(), + &Distribution->Id, + Distribution->Name, + Distribution->PackageFamilyName, + Distribution->Flavor, + Distribution->Version); + + if (IsHostCrash(result)) + { + LOG_HR_MSG(result, "Plugin host crashed, skipping OnDistributionUnregistered for: '%ls'", e.name.c_str()); + continue; + } + + LOG_IF_FAILED_MSG(result, "Error thrown from plugin: '%ls'", e.name.c_str()); } } -void PluginManager::ThrowIfPluginError(HRESULT Result, WSLSessionId Session, LPCWSTR Plugin) +void PluginManager::ThrowIfPluginError(HRESULT Result, LPWSTR ErrorMessage, WSLSessionId Session, LPCWSTR Plugin) { - const auto message = std::move(g_pluginErrorMessage); - g_pluginErrorMessage.reset(); // std::move() doesn't clear the previous std::optional + // If the host process crashed, don't propagate as a fatal plugin error — + // log it and let the caller decide. The plugin is already dead. + if (IsHostCrash(Result)) + { + LOG_HR_MSG(Result, "Plugin host process crashed for plugin: '%ls'", Plugin); + return; + } if (FAILED(Result)) { - if (message.has_value()) + if (ErrorMessage != nullptr && ErrorMessage[0] != L'\0') { - THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginErrorWithMessage(Plugin, message->c_str())); + THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginErrorWithMessage(Plugin, ErrorMessage)); } else { THROW_HR_WITH_USER_ERROR(Result, wsl::shared::Localization::MessageFatalPluginError(Plugin)); } } - else if (message.has_value()) + else if (ErrorMessage != nullptr && ErrorMessage[0] != L'\0') { THROW_HR_MSG(E_ILLEGAL_STATE_CHANGE, "Plugin '%ls' emitted an error message but returned success", Plugin); } } -void PluginManager::ThrowIfFatalPluginError() const +bool PluginManager::IsHostCrash(HRESULT hr) +{ + switch (hr) + { + case RPC_E_DISCONNECTED: + case RPC_E_SERVER_DIED: + case RPC_E_SERVER_DIED_DNE: + case CO_E_OBJNOTCONNECTED: + case RPC_S_SERVER_UNAVAILABLE: + case HRESULT_FROM_WIN32(RPC_S_SERVER_UNAVAILABLE): + case RPC_E_CALL_REJECTED: + return true; + default: + return false; + } +} + +void PluginManager::ThrowIfFatalPluginError() { ExecutionContext context(Context::Plugin); + EnsureInitialized(); if (!m_pluginError.has_value()) { diff --git a/src/windows/service/exe/PluginManager.h b/src/windows/service/exe/PluginManager.h index 746a11fa4..aa66fa3cd 100644 --- a/src/windows/service/exe/PluginManager.h +++ b/src/windows/service/exe/PluginManager.h @@ -9,17 +9,53 @@ Module Name: Abstract: This file contains the PluginManager class definition. + Plugins are loaded out-of-process in wslpluginhost.exe via COM + to isolate the service from plugin crashes. --*/ #pragma once #include +#include #include #include #include "WslPluginApi.h" +#include "WslPluginHost.h" namespace wsl::windows::service { + +// +// IWslPluginHostCallback implementation — lives in the service process and +// handles API calls coming from the plugin host (MountFolder, ExecuteBinary, etc.) +// +class PluginHostCallbackImpl + : public Microsoft::WRL::RuntimeClass, IWslPluginHostCallback> +{ +public: + STDMETHODIMP MountFolder(_In_ DWORD SessionId, _In_ LPCWSTR WindowsPath, _In_ LPCWSTR LinuxPath, _In_ BOOL ReadOnly, _In_ LPCWSTR Name) override; + + STDMETHODIMP ExecuteBinary( + _In_ DWORD SessionId, _In_ LPCSTR Path, _In_ DWORD ArgumentCount, _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, _Out_ HANDLE* Socket) override; + + STDMETHODIMP ExecuteBinaryInDistribution( + _In_ DWORD SessionId, + _In_ const GUID* DistributionId, + _In_ LPCSTR Path, + _In_ DWORD ArgumentCount, + _In_reads_opt_(ArgumentCount) LPCSTR* Arguments, + _Out_ HANDLE* Socket) override; + + STDMETHODIMP PluginError(_In_ LPCWSTR UserMessage) override; +}; + +/// +/// Manages out-of-process plugin hosts (wslpluginhost.exe) via COM activation. +/// Each plugin DLL is loaded in a separate process to isolate the service from +/// plugin crashes. Communication uses IWslPluginHost (service → host) for lifecycle +/// notifications and IWslPluginHostCallback (host → service) for plugin API calls. +/// A job object ensures all hosts are terminated if the service exits unexpectedly. +/// class PluginManager { public: @@ -30,6 +66,7 @@ class PluginManager }; PluginManager() = default; + ~PluginManager(); PluginManager(const PluginManager&) = delete; PluginManager& operator=(const PluginManager&) = delete; @@ -38,26 +75,37 @@ class PluginManager void LoadPlugins(); void OnVmStarted(const WSLSessionInformation* Session, const WSLVmCreationSettings* Settings); - void OnVmStopping(const WSLSessionInformation* Session) const; + void OnVmStopping(const WSLSessionInformation* Session); void OnDistributionStarted(const WSLSessionInformation* Session, const WSLDistributionInformation* distro); - void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro) const; - void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const; - void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro) const; - void ThrowIfFatalPluginError() const; + void OnDistributionStopping(const WSLSessionInformation* Session, const WSLDistributionInformation* distro); + void OnDistributionRegistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro); + void OnDistributionUnregistered(const WSLSessionInformation* Session, const WslOfflineDistributionInformation* distro); + void ThrowIfFatalPluginError(); private: - void LoadPlugin(LPCWSTR Name, LPCWSTR Path); - static void ThrowIfPluginError(HRESULT Result, WSLSessionId session, LPCWSTR Plugin); - - struct LoadedPlugin + struct OutOfProcPlugin { - wil::unique_hmodule module; + Microsoft::WRL::ComPtr host; std::wstring name; - WSLPluginHooksV1 hooks{}; + std::wstring path; }; - std::vector m_plugins; + void LoadPlugin(OutOfProcPlugin& plugin); + void EnsureInitialized(); + void EnsureJobObjectCreated(); + static void ThrowIfPluginError(HRESULT Result, LPWSTR ErrorMessage, WSLSessionId session, LPCWSTR Plugin); + static std::vector SerializeSid(PSID Sid); + static bool IsHostCrash(HRESULT hr); + + std::once_flag m_initOnce; + std::vector m_plugins; + Microsoft::WRL::ComPtr m_callback; std::optional m_pluginError; + + // Job object that automatically terminates all plugin host processes + // when wslservice exits or crashes (JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE). + std::once_flag m_jobObjectOnce; + wil::unique_handle m_jobObject; }; } // namespace wsl::windows::service \ No newline at end of file diff --git a/src/windows/service/inc/CMakeLists.txt b/src/windows/service/inc/CMakeLists.txt index f742b7d71..c9ad923bb 100644 --- a/src/windows/service/inc/CMakeLists.txt +++ b/src/windows/service/inc/CMakeLists.txt @@ -1,2 +1,4 @@ add_idl(wslserviceidl "wslservice.idl" "windowsdefs.idl") -set_target_properties(wslserviceidl PROPERTIES FOLDER windows) \ No newline at end of file +add_idl(wslpluginhostidl "WslPluginHost.idl" "") +set_target_properties(wslserviceidl PROPERTIES FOLDER windows) +set_target_properties(wslpluginhostidl PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/service/inc/WslPluginHost.idl b/src/windows/service/inc/WslPluginHost.idl new file mode 100644 index 000000000..9a72261a1 --- /dev/null +++ b/src/windows/service/inc/WslPluginHost.idl @@ -0,0 +1,162 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WslPluginHost.idl + +Abstract: + + This file contains the COM interface definitions for out-of-process + plugin hosting. IWslPluginHost is implemented by the plugin host process + and called by the service. IWslPluginHostCallback is implemented by the + service and called by the plugin host when a plugin invokes API functions. + +--*/ + +import "unknwn.idl"; +import "wtypes.idl"; + +cpp_quote("const GUID CLSID_WslPluginHost = {0x7a1d2c3e, 0x4b5f, 0x6a7d, {0x8e, 0x9f, 0x0a, 0x1b, 0x2c, 0x3d, 0x4e, 0x5f}};") +cpp_quote("#ifdef __cplusplus") +cpp_quote("class DECLSPEC_UUID(\"7a1d2c3e-4b5f-6a7d-8e9f-0a1b2c3d4e5f\") WslPluginHost;") +cpp_quote("#endif") + +// +// IWslPluginHostCallback - implemented by the service, called by the plugin host +// when a plugin invokes WSLPluginAPIV1 functions (MountFolder, ExecuteBinary, etc.) +// + +[ + uuid(A2B3C4D5-E6F7-4890-AB12-CD34EF56A789), + pointer_default(unique), + object +] +interface IWslPluginHostCallback : IUnknown +{ + HRESULT MountFolder( + [in] DWORD SessionId, + [in, string] LPCWSTR WindowsPath, + [in, string] LPCWSTR LinuxPath, + [in] BOOL ReadOnly, + [in, string] LPCWSTR Name); + + HRESULT ExecuteBinary( + [in] DWORD SessionId, + [in, string] LPCSTR Path, + [in] DWORD ArgumentCount, + [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments, + [out, system_handle(sh_socket)] HANDLE* Socket); + + HRESULT ExecuteBinaryInDistribution( + [in] DWORD SessionId, + [in] const GUID* DistributionId, + [in, string] LPCSTR Path, + [in] DWORD ArgumentCount, + [in, unique, size_is(ArgumentCount), string] LPCSTR* Arguments, + [out, system_handle(sh_socket)] HANDLE* Socket); + + HRESULT PluginError( + [in, string] LPCWSTR UserMessage); +}; + +// +// IWslPluginHost - implemented by the plugin host process, called by the service +// to deliver lifecycle notifications to the plugin. +// + +[ + uuid(B3C4D5E6-F7A8-4901-BC23-DE45FA67B890), + pointer_default(unique), + object +] +interface IWslPluginHost : IUnknown +{ + // + // Initialize the plugin host: load the plugin DLL and call its entry point. + // The Callback interface is used by the plugin to call back into the service. + // + + HRESULT Initialize( + [in] IWslPluginHostCallback* Callback, + [in, string] LPCWSTR PluginPath, + [in, string] LPCWSTR PluginName); + + // + // Returns a handle to this COM server process. Used by the service to add + // the plugin host to a job object for automatic cleanup on service exit. + // + + HRESULT GetProcessHandle( + [out, system_handle(sh_process)] HANDLE* ProcessHandle); + + // + // Lifecycle hook dispatchers - mirror WSLPluginHooksV1. + // UserToken is duplicated into the host process by the service before calling. + // UserSid is serialized as a byte array. + // + + HRESULT OnVMStarted( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] DWORD CustomConfigurationFlags, + [out, string] LPWSTR* ErrorMessage); + + HRESULT OnVMStopping( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData); + + HRESULT OnDistributionStarted( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in] ULONGLONG PidNamespace, + [in, unique, string] LPCWSTR PackageFamilyName, + [in] DWORD InitPid, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version, + [out, string] LPWSTR* ErrorMessage); + + HRESULT OnDistributionStopping( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in] ULONGLONG PidNamespace, + [in, unique, string] LPCWSTR PackageFamilyName, + [in] DWORD InitPid, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version); + + HRESULT OnDistributionRegistered( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in, unique, string] LPCWSTR PackageFamilyName, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version); + + HRESULT OnDistributionUnregistered( + [in] DWORD SessionId, + [in, system_handle(sh_token)] HANDLE UserToken, + [in] DWORD SidSize, + [in, size_is(SidSize)] BYTE* SidData, + [in] const GUID* DistributionId, + [in, string] LPCWSTR DistributionName, + [in, unique, string] LPCWSTR PackageFamilyName, + [in, unique, string] LPCWSTR Flavor, + [in, unique, string] LPCWSTR Version); +}; diff --git a/src/windows/service/stub/CMakeLists.txt b/src/windows/service/stub/CMakeLists.txt index 3c754cc59..11558637f 100644 --- a/src/windows/service/stub/CMakeLists.txt +++ b/src/windows/service/stub/CMakeLists.txt @@ -1,6 +1,8 @@ set(SOURCES ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslservice_i_${TARGET_PLATFORM}.c ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslservice_p_${TARGET_PLATFORM}.c + ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/WslPluginHost_i_${TARGET_PLATFORM}.c + ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/WslPluginHost_p_${TARGET_PLATFORM}.c ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/dlldata_${TARGET_PLATFORM}.c ${CMAKE_CURRENT_LIST_DIR}/WslServiceProxyStub.def ${CMAKE_CURRENT_LIST_DIR}/WslServiceProxyStub.rc) @@ -8,6 +10,6 @@ set(SOURCES set_source_files_properties(${SOURCES} PROPERTIES GENERATED TRUE) add_library(wslserviceproxystub SHARED ${SOURCES}) -add_dependencies(wslserviceproxystub wslserviceidl) +add_dependencies(wslserviceproxystub wslserviceidl wslpluginhostidl) target_link_libraries(wslserviceproxystub ${COMMON_LINK_LIBRARIES}) set_target_properties(wslserviceproxystub PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/wslinstall/DllMain.cpp b/src/windows/wslinstall/DllMain.cpp index c9ab7d800..3bbf725ae 100644 --- a/src/windows/wslinstall/DllMain.cpp +++ b/src/windows/wslinstall/DllMain.cpp @@ -827,7 +827,7 @@ void RegisterLspCategoriesImpl(DWORD flags) const auto installRoot = wsl::windows::common::wslutil::GetMsiPackagePath(); THROW_HR_IF(E_INVALIDARG, !installRoot.has_value()); - for (const auto& e : {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslg.exe", L"wslservice.exe"}) + for (const auto& e : {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslpluginhost.exe", L"wslg.exe", L"wslservice.exe"}) { auto executable = installRoot.value() + e; INT error{}; diff --git a/src/windows/wslpluginhost/CMakeLists.txt b/src/windows/wslpluginhost/CMakeLists.txt new file mode 100644 index 000000000..ce8dc8f77 --- /dev/null +++ b/src/windows/wslpluginhost/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(exe) diff --git a/src/windows/wslpluginhost/exe/CMakeLists.txt b/src/windows/wslpluginhost/exe/CMakeLists.txt new file mode 100644 index 000000000..fc16ec3bc --- /dev/null +++ b/src/windows/wslpluginhost/exe/CMakeLists.txt @@ -0,0 +1,25 @@ +set(SOURCES + main.cpp + main.rc + PluginHost.cpp) + +set(HEADERS + PluginHost.h + resource.h) + +add_executable(wslpluginhost WIN32 ${SOURCES} ${HEADERS}) +add_dependencies(wslpluginhost + wslpluginhostidl + common) + +target_include_directories(wslpluginhost PRIVATE + ${CMAKE_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) + +target_link_libraries(wslpluginhost + ${COMMON_LINK_LIBRARIES} + ${MSI_LINK_LIBRARIES} + common + ole32.lib) + +target_precompile_headers(wslpluginhost REUSE_FROM common) +set_target_properties(wslpluginhost PROPERTIES FOLDER windows) diff --git a/src/windows/wslpluginhost/exe/PluginHost.cpp b/src/windows/wslpluginhost/exe/PluginHost.cpp new file mode 100644 index 000000000..e3271ed92 --- /dev/null +++ b/src/windows/wslpluginhost/exe/PluginHost.cpp @@ -0,0 +1,459 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + PluginHost.cpp + +Abstract: + + This file contains the IWslPluginHost COM class implementation. + It loads a plugin DLL in this (host) process and forwards lifecycle + notifications from the service to the plugin, while routing plugin API + callbacks back to the service via IWslPluginHostCallback. + +--*/ + +#include "precomp.h" +#include "PluginHost.h" +#include "install.h" + +using namespace wsl::windows::pluginhost; + +// Defined in main.cpp — part of the COM local server lifecycle. +extern void ReleaseComRef(); + +PluginHost* wsl::windows::pluginhost::g_pluginHost = nullptr; + +// Thread ID of the thread currently dispatching a plugin hook. +// Only that thread may call PluginError. Using thread ID instead of +// thread_local to avoid TLS initialization issues across DLL/EXE boundaries. +static std::atomic g_hookThreadId{0}; + +PluginHost::~PluginHost() +{ + // Clear globally reachable state so late plugin API calls fail with + // E_UNEXPECTED instead of dereferencing freed memory. + if (g_pluginHost == this) + { + g_pluginHost = nullptr; + } + + // Module unloads automatically via wil::unique_hmodule destructor. + + // Decrement the COM server reference count. When it reaches zero, + // the process will exit. Matches AddComRef() in PluginHostFactory::CreateInstance. + ReleaseComRef(); +} + +// --- IWslPluginHost implementation --- + +STDMETHODIMP PluginHost::Initialize(_In_ IWslPluginHostCallback* Callback, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName) +try +{ + RETURN_HR_IF(E_INVALIDARG, Callback == nullptr || PluginPath == nullptr || PluginName == nullptr); + RETURN_HR_IF(E_ILLEGAL_METHOD_CALL, m_module.is_valid()); // Already initialized + + m_callback = Callback; + m_pluginName = PluginName; + + // Validate the plugin signature before loading it. + // Keep the file handle open to prevent TOCTOU (swap between validation and load). + wil::unique_hfile signatureHandle; + if constexpr (wsl::shared::OfficialBuild) + { + signatureHandle = wsl::windows::common::install::ValidateFileSignature(PluginPath); + } + + m_module.reset(LoadLibrary(PluginPath)); + THROW_LAST_ERROR_IF_NULL(m_module); + signatureHandle.reset(); // Safe to release after LoadLibrary has mapped the DLL + + const auto entryPoint = + reinterpret_cast(GetProcAddress(m_module.get(), GSL_STRINGIFY(WSLPLUGINAPI_ENTRYPOINTV1))); + THROW_LAST_ERROR_IF_NULL(entryPoint); + + // Build the API vtable that the plugin will use to call back into the service. + // The function pointers are static methods on this class that route through g_pluginHost. + static const WSLPluginAPIV1 api = { + {wsl::shared::VersionMajor, wsl::shared::VersionMinor, wsl::shared::VersionRevision}, + &LocalMountFolder, + &LocalExecuteBinary, + &LocalPluginError, + &LocalExecuteBinaryInDistribution}; + + g_pluginHost = this; + HRESULT hr = entryPoint(&api, &m_hooks); + + if (FAILED(hr)) + { + g_pluginHost = nullptr; + RETURN_HR_MSG(hr, "Plugin entry point failed: '%ls'", PluginPath); + } + + return S_OK; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::GetProcessHandle(_Out_ HANDLE* ProcessHandle) +try +{ + RETURN_HR_IF(E_POINTER, ProcessHandle == nullptr); + *ProcessHandle = nullptr; + + wil::unique_handle process(OpenProcess(PROCESS_SET_QUOTA | PROCESS_TERMINATE, FALSE, GetCurrentProcessId())); + RETURN_LAST_ERROR_IF_NULL(process); + + // COM's system_handle(sh_process) marshaling will duplicate this into the caller's process. + *ProcessHandle = process.release(); + return S_OK; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnVMStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ DWORD CustomConfigurationFlags, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) +try +{ + RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr); + *ErrorMessage = nullptr; + + if (m_hooks.OnVMStarted == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + WSLVmCreationSettings settings{}; + settings.CustomConfigurationFlags = static_cast(CustomConfigurationFlags); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnVMStarted(&ctx.info, &settings); + + // If the plugin called PluginError during the hook, return the message. + if (m_pluginErrorMessage.has_value()) + { + *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release(); + m_pluginErrorMessage.reset(); + } + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnVMStopping(_In_ DWORD SessionId, _In_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_(SidSize) BYTE* SidData) +try +{ + if (m_hooks.OnVMStopping == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnVMStopping(&ctx.info); + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) +try +{ + RETURN_HR_IF(E_POINTER, ErrorMessage == nullptr); + *ErrorMessage = nullptr; + + if (m_hooks.OnDistributionStarted == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + WSLDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PidNamespace = PidNamespace; + distro.PackageFamilyName = PackageFamilyName ? PackageFamilyName : L""; + distro.InitPid = InitPid; + distro.Flavor = Flavor ? Flavor : L""; + distro.Version = Version ? Version : L""; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + m_pluginErrorMessage.reset(); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionStarted(&ctx.info, &distro); + + if (m_pluginErrorMessage.has_value()) + { + *ErrorMessage = wil::make_cotaskmem_string(m_pluginErrorMessage->c_str()).release(); + m_pluginErrorMessage.reset(); + } + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionStopping( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) +try +{ + if (m_hooks.OnDistributionStopping == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + WSLDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PidNamespace = PidNamespace; + distro.PackageFamilyName = PackageFamilyName ? PackageFamilyName : L""; + distro.InitPid = InitPid; + distro.Flavor = Flavor ? Flavor : L""; + distro.Version = Version ? Version : L""; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionStopping(&ctx.info, &distro); + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionRegistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) +try +{ + if (m_hooks.OnDistributionRegistered == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + WslOfflineDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PackageFamilyName = PackageFamilyName ? PackageFamilyName : L""; + distro.Flavor = Flavor ? Flavor : L""; + distro.Version = Version ? Version : L""; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionRegistered(&ctx.info, &distro); + + return hr; +} +CATCH_RETURN(); + +STDMETHODIMP PluginHost::OnDistributionUnregistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) +try +{ + if (m_hooks.OnDistributionUnregistered == nullptr) + { + return S_OK; + } + + auto ctx = BuildSessionContext(SessionId, UserToken, SidSize, SidData); + + WslOfflineDistributionInformation distro{}; + distro.Id = *DistributionId; + distro.Name = DistributionName; + distro.PackageFamilyName = PackageFamilyName ? PackageFamilyName : L""; + distro.Flavor = Flavor ? Flavor : L""; + distro.Version = Version ? Version : L""; + + std::lock_guard hookLock(m_hookLock); + g_hookThreadId.store(GetCurrentThreadId()); + auto cleanup = wil::scope_exit([&] { g_hookThreadId.store(0); }); + + HRESULT hr = m_hooks.OnDistributionUnregistered(&ctx.info, &distro); + + return hr; +} +CATCH_RETURN(); + +// --- Helpers --- + +PluginHost::SessionContext PluginHost::BuildSessionContext(DWORD SessionId, HANDLE UserToken, DWORD SidSize, BYTE* SidData) +{ + SessionContext ctx{}; + ctx.info.SessionId = SessionId; + + // COM's system_handle marshaling automatically duplicated the token into our process. + // Take ownership of the duplicated handle. + ctx.tokenHandle.reset(UserToken); + ctx.info.UserToken = ctx.tokenHandle.get(); + + // Reconstruct the SID from the serialized bytes. + ctx.sidBuffer.assign(SidData, SidData + SidSize); + ctx.info.UserSid = reinterpret_cast(ctx.sidBuffer.data()); + + return ctx; +} + +// --- Static API stubs --- + +HRESULT CALLBACK PluginHost::LocalMountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name) +{ + if (g_pluginHost == nullptr || g_pluginHost->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + auto hr = g_pluginHost->m_callback->MountFolder(Session, WindowsPath, LinuxPath, ReadOnly, Name); + return hr; +} + +HRESULT CALLBACK PluginHost::LocalExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) +{ + if (g_pluginHost == nullptr || g_pluginHost->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_POINTER, Socket == nullptr); + + // Count arguments (NULL-terminated array) + DWORD count = 0; + if (Arguments != nullptr) + { + for (const LPCSTR* p = Arguments; *p != nullptr; ++p) + { + ++count; + } + } + + HANDLE socketResult = nullptr; + HRESULT hr = g_pluginHost->m_callback->ExecuteBinary(Session, Path, count, Arguments, &socketResult); + + if (SUCCEEDED(hr)) + { + // COM's system_handle marshaling duplicated the socket into our process. + *Socket = reinterpret_cast(socketResult); + } + else if (socketResult != nullptr) + { + if (closesocket(reinterpret_cast(socketResult)) == SOCKET_ERROR) + { + LOG_WIN32(WSAGetLastError()); + } + } + + return hr; +} + +HRESULT CALLBACK PluginHost::LocalPluginError(LPCWSTR UserMessage) +{ + if (g_pluginHost == nullptr) + { + // Not on a hook thread — PluginError must only be called + // synchronously from within OnVMStarted/OnDistributionStarted. + return E_ILLEGAL_METHOD_CALL; + } + + RETURN_HR_IF(E_INVALIDARG, UserMessage == nullptr); + RETURN_HR_IF(E_ILLEGAL_METHOD_CALL, GetCurrentThreadId() != g_hookThreadId.load()); + RETURN_HR_IF(E_ILLEGAL_STATE_CHANGE, g_pluginHost->m_pluginErrorMessage.has_value()); + + // Store locally — returned to service alongside the hook HRESULT. + g_pluginHost->m_pluginErrorMessage.emplace(UserMessage); + return S_OK; +} + +HRESULT CALLBACK PluginHost::LocalExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket) +{ + if (g_pluginHost == nullptr || g_pluginHost->m_callback == nullptr) + { + return E_UNEXPECTED; + } + + RETURN_HR_IF(E_INVALIDARG, Distro == nullptr); + RETURN_HR_IF(E_POINTER, Socket == nullptr); + + DWORD count = 0; + if (Arguments != nullptr) + { + for (const LPCSTR* p = Arguments; *p != nullptr; ++p) + { + ++count; + } + } + + HANDLE socketResult = nullptr; + HRESULT hr = g_pluginHost->m_callback->ExecuteBinaryInDistribution(Session, Distro, Path, count, Arguments, &socketResult); + + if (SUCCEEDED(hr)) + { + *Socket = reinterpret_cast(socketResult); + } + else if (socketResult != nullptr) + { + if (closesocket(reinterpret_cast(socketResult)) == SOCKET_ERROR) + { + LOG_WIN32(WSAGetLastError()); + } + } + + return hr; +} diff --git a/src/windows/wslpluginhost/exe/PluginHost.h b/src/windows/wslpluginhost/exe/PluginHost.h new file mode 100644 index 000000000..266ae0d29 --- /dev/null +++ b/src/windows/wslpluginhost/exe/PluginHost.h @@ -0,0 +1,134 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + PluginHost.h + +Abstract: + + This file contains the COM class that implements IWslPluginHost. + It loads a plugin DLL and dispatches lifecycle notifications to it, + forwarding plugin API callbacks to the service via IWslPluginHostCallback. + +--*/ + +#pragma once + +#include "WslPluginApi.h" +#include "WslPluginHost.h" + +namespace wsl::windows::pluginhost { + +class PluginHost : public Microsoft::WRL::RuntimeClass, IWslPluginHost> +{ +public: + PluginHost() = default; + ~PluginHost(); + + PluginHost(const PluginHost&) = delete; + PluginHost& operator=(const PluginHost&) = delete; + + // IWslPluginHost + STDMETHODIMP Initialize(_In_ IWslPluginHostCallback* Callback, _In_ LPCWSTR PluginPath, _In_ LPCWSTR PluginName) override; + STDMETHODIMP GetProcessHandle(_Out_ HANDLE* ProcessHandle) override; + + STDMETHODIMP OnVMStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ DWORD CustomConfigurationFlags, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override; + + STDMETHODIMP OnVMStopping(_In_ DWORD SessionId, _In_ HANDLE UserToken, _In_ DWORD SidSize, _In_reads_(SidSize) BYTE* SidData) override; + + STDMETHODIMP OnDistributionStarted( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version, + _Outptr_result_maybenull_ LPWSTR* ErrorMessage) override; + + STDMETHODIMP OnDistributionStopping( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_ ULONGLONG PidNamespace, + _In_opt_ LPCWSTR PackageFamilyName, + _In_ DWORD InitPid, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) override; + + STDMETHODIMP OnDistributionRegistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) override; + + STDMETHODIMP OnDistributionUnregistered( + _In_ DWORD SessionId, + _In_ HANDLE UserToken, + _In_ DWORD SidSize, + _In_reads_(SidSize) BYTE* SidData, + _In_ const GUID* DistributionId, + _In_ LPCWSTR DistributionName, + _In_opt_ LPCWSTR PackageFamilyName, + _In_opt_ LPCWSTR Flavor, + _In_opt_ LPCWSTR Version) override; + +private: + // Build a WSLSessionInformation struct from the marshaled parameters. + // The returned struct and its SID allocation are valid for the lifetime of the wil::unique_handle. + struct SessionContext + { + WSLSessionInformation info{}; + wil::unique_handle tokenHandle; + std::vector sidBuffer; + }; + + SessionContext BuildSessionContext(DWORD SessionId, HANDLE UserToken, DWORD SidSize, BYTE* SidData); + + // Local stubs for the WSLPluginAPIV1 function pointers. + // These forward calls to the service via m_callback. + static HRESULT CALLBACK LocalMountFolder(WSLSessionId Session, LPCWSTR WindowsPath, LPCWSTR LinuxPath, BOOL ReadOnly, LPCWSTR Name); + static HRESULT CALLBACK LocalExecuteBinary(WSLSessionId Session, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket); + static HRESULT CALLBACK LocalPluginError(LPCWSTR UserMessage); + static HRESULT CALLBACK LocalExecuteBinaryInDistribution(WSLSessionId Session, const GUID* Distro, LPCSTR Path, LPCSTR* Arguments, SOCKET* Socket); + + wil::unique_hmodule m_module; + std::wstring m_pluginName; + WSLPluginHooksV1 m_hooks{}; + Microsoft::WRL::ComPtr m_callback; + + // Serializes hook dispatch so m_pluginErrorMessage and g_hookThreadId + // are not raced when multiple sessions call hooks concurrently (MTA). + std::mutex m_hookLock; + + // Error message captured by LocalPluginError during hook execution + std::optional m_pluginErrorMessage; +}; + +// Process-wide pointer to the single PluginHost instance. Safe because +// REGCLS_SINGLEUSE guarantees one PluginHost per wslpluginhost.exe process. +// This allows plugin DLLs to call API functions from any thread, not just +// the thread dispatching the current hook. +extern PluginHost* g_pluginHost; + +} // namespace wsl::windows::pluginhost diff --git a/src/windows/wslpluginhost/exe/main.cpp b/src/windows/wslpluginhost/exe/main.cpp new file mode 100644 index 000000000..7ba20c37a --- /dev/null +++ b/src/windows/wslpluginhost/exe/main.cpp @@ -0,0 +1,111 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + main.cpp + +Abstract: + + This file contains the entry point for wslpluginhost.exe. + This process acts as a COM local server that loads a single WSL plugin DLL + in an isolated process, preventing a buggy or malicious plugin from crashing + the main WSL service. + + The host is activated through COM local-server activation. It registers its + COM class factory, serves activation requests, and remains alive until all + COM server-process references are released, at which point it exits. + +--*/ + +#include "precomp.h" +#include "PluginHost.h" +#include "WslPluginHost.h" + +using namespace Microsoft::WRL; + +static wil::unique_event g_exitEvent(wil::EventOptions::ManualReset); + +void AddComRef() +{ + CoAddRefServerProcess(); +} + +void ReleaseComRef() +{ + if (CoReleaseServerProcess() == 0) + { + g_exitEvent.SetEvent(); + } +} + +class PluginHostFactory : public RuntimeClass, IClassFactory> +{ +public: + STDMETHODIMP CreateInstance(_In_opt_ IUnknown* pUnkOuter, _In_ REFIID riid, _Outptr_ void** ppCreated) override + try + { + RETURN_HR_IF_NULL(E_POINTER, ppCreated); + *ppCreated = nullptr; + RETURN_HR_IF(CLASS_E_NOAGGREGATION, pUnkOuter != nullptr); + + auto host = Make(); + RETURN_IF_NULL_ALLOC(host); + + AddComRef(); + auto releaseOnFailure = wil::scope_exit([] { ReleaseComRef(); }); + RETURN_IF_FAILED(host.CopyTo(riid, ppCreated)); + releaseOnFailure.release(); + return S_OK; + } + CATCH_RETURN(); + + STDMETHODIMP LockServer(BOOL lock) noexcept override + { + if (lock) + { + AddComRef(); + } + else + { + ReleaseComRef(); + } + return S_OK; + } +}; + +int WINAPI wWinMain(_In_ HINSTANCE, _In_opt_ HINSTANCE, _In_ LPWSTR, _In_ int) +try +{ + wsl::windows::common::wslutil::ConfigureCrt(); + wsl::windows::common::wslutil::InitializeWil(); + + // Initialize logging. + WslTraceLoggingInitialize(WslServiceTelemetryProvider, !wsl::shared::OfficialBuild); + auto cleanupTracing = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [] { WslTraceLoggingUninitialize(); }); + + auto coInit = wil::CoInitializeEx(COINIT_MULTITHREADED); + wsl::windows::common::wslutil::CoInitializeSecurity(); + + // Initialize Winsock — plugins receive sockets from ExecuteBinary and need + // Winsock to be initialized for recv/send/closesocket to work. + WSADATA wsaData{}; + THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &wsaData)); + auto cleanupWinsock = wil::scope_exit([] { WSACleanup(); }); + + // Register the class factory so the service can CoCreateInstance on us. + DWORD cookie = 0; + auto factory = Make(); + THROW_IF_NULL_ALLOC(factory); + + THROW_IF_FAILED(::CoRegisterClassObject(CLSID_WslPluginHost, factory.Get(), CLSCTX_LOCAL_SERVER, REGCLS_SINGLEUSE, &cookie)); + + auto revokeOnExit = wil::scope_exit([&]() { ::CoRevokeClassObject(cookie); }); + + // Wait until the COM reference count drops to zero. + g_exitEvent.wait(); + + return 0; +} +CATCH_RETURN(); diff --git a/src/windows/wslpluginhost/exe/main.rc b/src/windows/wslpluginhost/exe/main.rc new file mode 100644 index 000000000..84d6b8b8c --- /dev/null +++ b/src/windows/wslpluginhost/exe/main.rc @@ -0,0 +1,25 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + main.rc + +Abstract: + + This file contains resources for wslpluginhost. + +--*/ + +#include +#include "resource.h" +#include "wslversioninfo.h" + +#define VER_INTERNALNAME_STR "wslpluginhost.exe" +#define VER_ORIGINALFILENAME_STR "wslpluginhost.exe" + +#define VER_FILEDESCRIPTION_STR "Windows Subsystem for Linux" +ID_ICON ICON PRELOAD DISCARDABLE "..\..\..\..\Images\wsl.ico" + +#include diff --git a/src/windows/wslpluginhost/exe/resource.h b/src/windows/wslpluginhost/exe/resource.h new file mode 100644 index 000000000..355437a44 --- /dev/null +++ b/src/windows/wslpluginhost/exe/resource.h @@ -0,0 +1,15 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + resource.h + +Abstract: + + This file contains resource declarations for wslpluginhost.exe + +--*/ + +#define ID_ICON 1 diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp index 99cf3cf45..f1b238f52 100644 --- a/test/windows/Common.cpp +++ b/test/windows/Common.cpp @@ -823,7 +823,14 @@ void CreateProcessCrashReport(DWORD Pid, LPCWSTR ImageName, LPCWSTR EventName) void CreateWerReports() { static const std::set WslProcesses{ - L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslservice.exe", L"wslg.exe", L"vmcompute.exe", L"vmwp.exe"}; + L"wsl.exe", + L"wslhost.exe", + L"wslrelay.exe", + L"wslpluginhost.exe", + L"wslservice.exe", + L"wslg.exe", + L"vmcompute.exe", + L"vmwp.exe"}; auto PrivilegeState = wsl::windows::common::security::AcquirePrivilege(SE_DEBUG_NAME); const std::wstring EventName = L"WslTestHang-" + g_pipelineBuildId; diff --git a/test/windows/Common.h b/test/windows/Common.h index 1c65d2790..2872b47b3 100644 --- a/test/windows/Common.h +++ b/test/windows/Common.h @@ -112,6 +112,7 @@ Module Name: TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"WslServiceProxyStub.dll") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslhost.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslrelay.exe") \ + TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslpluginhost.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslconfig.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wsl.exe") \ TEST_CLASS_PROPERTY(L"BinaryUnderTest", L"wslg.exe") \ diff --git a/test/windows/InstallerTests.cpp b/test/windows/InstallerTests.cpp index 423b05815..b78a4f776 100644 --- a/test/windows/InstallerTests.cpp +++ b/test/windows/InstallerTests.cpp @@ -800,7 +800,7 @@ class InstallerTests return flags; }; - const std::vector executables = {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslg.exe"}; + const std::vector executables = {L"wsl.exe", L"wslhost.exe", L"wslrelay.exe", L"wslpluginhost.exe", L"wslg.exe"}; for (const auto& e : executables) { auto fullPath = installPath.value() + e; diff --git a/test/windows/PluginTests.cpp b/test/windows/PluginTests.cpp index 8b42f941e..e626f0c27 100644 --- a/test/windows/PluginTests.cpp +++ b/test/windows/PluginTests.cpp @@ -334,11 +334,10 @@ class PluginTests WSL1_TEST_METHOD(SuccessWSL1) { - constexpr auto ExpectedOutput = LR"(Plugin loaded. TestMode=1)"; - + // Plugins are not loaded for WSL1-only sessions (no VM, no plugin hooks). + // Verify that WSL1 works without plugins. ConfigurePlugin(PluginTestType::Success); StartWsl(0); - ValidateLogFile(ExpectedOutput); } WSL2_TEST_METHOD(LoadFailureFatalWSL2) @@ -357,13 +356,10 @@ class PluginTests WSL1_TEST_METHOD(LoadFailureNonFatalWSL1) { - constexpr auto ExpectedOutput = - LR"(Plugin loaded. TestMode=2 - OnLoad: E_UNEXPECTED)"; - + // Plugins are not loaded for WSL1-only sessions, so a plugin that + // would fail to load on WSL2 has no effect on WSL1. ConfigurePlugin(PluginTestType::FailToLoad); StartWsl(0); - ValidateLogFile(ExpectedOutput); } WSL2_TEST_METHOD(VmStartFailure)