diff --git a/localization/strings/en-US/Resources.resw b/localization/strings/en-US/Resources.resw index 42030d93b..9387ab995 100644 --- a/localization/strings/en-US/Resources.resw +++ b/localization/strings/en-US/Resources.resw @@ -1935,6 +1935,15 @@ Usage: Session termination failed: '{}' {FixedPlaceholder="{}"}Command line arguments, file names and string inserts should not be translated + + Default session not found + + + Failed to open default session + + + Default session termination failed + {} exited with: {} {FixedPlaceholder="{}"}{FixedPlaceholder="{}"}Command line arguments, file names and string inserts should not be translated diff --git a/src/windows/common/CMakeLists.txt b/src/windows/common/CMakeLists.txt index d9b31b135..a8ce02201 100644 --- a/src/windows/common/CMakeLists.txt +++ b/src/windows/common/CMakeLists.txt @@ -49,6 +49,7 @@ set(SOURCES WslTelemetry.cpp wslutil.cpp install.cpp + WSLCUserSettings.cpp ) set(HEADERS @@ -130,11 +131,19 @@ set(HEADERS WslSecurity.h WslTelemetry.h wslutil.h + EnumVariantMap.h + WSLCUserSettings.h + WSLCSessionDefaults.h ) add_library(common STATIC ${SOURCES} ${HEADERS}) -add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl) +add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl yaml-cpp) target_precompile_headers(common PRIVATE precomp.h) set_target_properties(common PROPERTIES FOLDER windows) target_include_directories(common PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/../service/mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) + +# WSLCUserSettings.cpp uses yaml-cpp headers. +set_source_files_properties(WSLCUserSettings.cpp PROPERTIES + INCLUDE_DIRECTORIES "${yaml-cpp_SOURCE_DIR}/include" + COMPILE_DEFINITIONS "YAML_CPP_STATIC_DEFINE") diff --git a/src/windows/wslc/core/EnumVariantMap.h b/src/windows/common/EnumVariantMap.h similarity index 100% rename from src/windows/wslc/core/EnumVariantMap.h rename to src/windows/common/EnumVariantMap.h diff --git a/src/windows/common/WSLCSessionDefaults.h b/src/windows/common/WSLCSessionDefaults.h new file mode 100644 index 000000000..5d502c6ad --- /dev/null +++ b/src/windows/common/WSLCSessionDefaults.h @@ -0,0 +1,25 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLCSessionDefaults.h + +Abstract: + + Shared constants for WSLc session naming and storage. + +--*/ +#pragma once + +#include + +namespace wsl::windows::wslc { + +inline constexpr const wchar_t DefaultSessionName[] = L"wslc-cli"; +inline constexpr const wchar_t DefaultAdminSessionName[] = L"wslc-cli-admin"; +inline constexpr const wchar_t DefaultStorageSubPath[] = L"wslc\\sessions"; +inline constexpr uint32_t DefaultBootTimeoutMs = 30000; + +} // namespace wsl::windows::wslc diff --git a/src/windows/wslc/settings/UserSettings.cpp b/src/windows/common/WSLCUserSettings.cpp similarity index 98% rename from src/windows/wslc/settings/UserSettings.cpp rename to src/windows/common/WSLCUserSettings.cpp index 0e4ca688d..977f03635 100644 --- a/src/windows/wslc/settings/UserSettings.cpp +++ b/src/windows/common/WSLCUserSettings.cpp @@ -4,18 +4,23 @@ Copyright (c) Microsoft. All rights reserved. Module Name: - UserSettings.cpp + WSLCUserSettings.cpp Abstract: Implementation of UserSettings — YAML loading and validation. --*/ -#include "UserSettings.h" +#include "precomp.h" +#include "WSLCUserSettings.h" #include "filesystem.hpp" #include "string.hpp" #include "wslutil.h" + +#pragma warning(push) +#pragma warning(disable : 4251 4275) #include +#pragma warning(pop) #include #include #include @@ -25,7 +30,6 @@ using namespace wsl::windows::common::string; namespace wsl::windows::wslc::settings { -// Default settings file template — written on first run. // All entries are commented out; the values shown are the built-in defaults. // TODO: localization for comments needed? static constexpr std::string_view s_DefaultSettingsTemplate = diff --git a/src/windows/wslc/settings/UserSettings.h b/src/windows/common/WSLCUserSettings.h similarity index 97% rename from src/windows/wslc/settings/UserSettings.h rename to src/windows/common/WSLCUserSettings.h index a58b93ab4..d47da6b5e 100644 --- a/src/windows/wslc/settings/UserSettings.h +++ b/src/windows/common/WSLCUserSettings.h @@ -4,7 +4,7 @@ Copyright (c) Microsoft. All rights reserved. Module Name: - UserSettings.h + WSLCUserSettings.h Abstract: @@ -156,9 +156,7 @@ class UserSettings // Overwrites the settings file with the commented-out defaults template. void Reset() const; -protected: - // Loads settings from an explicit directory. Used by the singleton (via - // the private zero-arg constructor) and by test subclasses. + // Loads settings from an explicit directory. explicit UserSettings(const std::filesystem::path& settingsDir); ~UserSettings() = default; diff --git a/src/windows/common/wslutil.cpp b/src/windows/common/wslutil.cpp index d9a71e77f..779a87042 100644 --- a/src/windows/common/wslutil.cpp +++ b/src/windows/common/wslutil.cpp @@ -153,6 +153,8 @@ static const std::map g_commonErrors{ X(WSLC_E_VOLUME_NOT_FOUND), X(WSLC_E_CONTAINER_NOT_RUNNING), X(WSLC_E_CONTAINER_IS_RUNNING), + X(WSLC_E_SESSION_RESERVED), + X(WSLC_E_INVALID_SESSION_NAME), X_WIN32(RPC_S_SERVER_UNAVAILABLE), X_WIN32(ERROR_ELEVATION_REQUIRED)}; diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt index 16d5b90b4..6198a4a6f 100644 --- a/src/windows/service/exe/CMakeLists.txt +++ b/src/windows/service/exe/CMakeLists.txt @@ -71,7 +71,8 @@ target_link_libraries(wslservice legacy_stdio_definitions VirtDisk.lib Winhttp.lib - Synchronization.lib) + Synchronization.lib + yaml-cpp) target_precompile_headers(wslservice REUSE_FROM common) set_target_properties(wslservice PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/service/exe/WSLCSessionManager.cpp b/src/windows/service/exe/WSLCSessionManager.cpp index 8a1edf5c0..ca6ff8641 100644 --- a/src/windows/service/exe/WSLCSessionManager.cpp +++ b/src/windows/service/exe/WSLCSessionManager.cpp @@ -29,12 +29,81 @@ Module Name: #include "WSLCSessionManager.h" #include "HcsVirtualMachine.h" +#include "WSLCUserSettings.h" +#include "WSLCSessionDefaults.h" #include "wslutil.h" +#include "filesystem.hpp" using wsl::windows::service::wslc::CallingProcessTokenInfo; using wsl::windows::service::wslc::HcsVirtualMachine; using wsl::windows::service::wslc::WSLCSessionManagerImpl; namespace wslutil = wsl::windows::common::wslutil; +namespace settings = wsl::windows::wslc::settings; + +namespace { + +// Session settings built server-side from the caller's settings.yaml. +struct SessionSettings +{ + std::wstring DisplayName; + std::wstring StoragePath; + WSLCSessionSettings Settings{}; + + NON_COPYABLE(SessionSettings); + NON_MOVABLE(SessionSettings); + + // Load user settings under impersonation. + static settings::UserSettings LoadUserSettings(HANDLE UserToken) + { + auto localAppData = wsl::windows::common::filesystem::GetLocalAppDataPath(UserToken); + auto runAsUser = wil::impersonate_token(UserToken); + return settings::UserSettings(localAppData / L"wslc"); + } + + // Default session: name and storage path determined from caller's token. + static std::unique_ptr Default(HANDLE UserToken, bool Elevated, const std::wstring& ResolvedName) + { + auto userSettings = LoadUserSettings(UserToken); + auto localAppData = wsl::windows::common::filesystem::GetLocalAppDataPath(UserToken); + + auto customPath = userSettings.Get(); + std::filesystem::path basePath = + customPath.empty() ? (localAppData / wsl::windows::wslc::DefaultStorageSubPath) : std::filesystem::path{customPath}; + auto storagePath = (basePath / ResolvedName).wstring(); + + return std::unique_ptr( + new SessionSettings(std::wstring(ResolvedName), std::move(storagePath), WSLCSessionStorageFlagsNone, userSettings)); + } + + // Custom session: caller provides name and storage path. + static SessionSettings Custom(HANDLE UserToken, LPCWSTR Name, LPCWSTR Path, WSLCSessionStorageFlags StorageFlags = WSLCSessionStorageFlagsNone) + { + auto userSettings = LoadUserSettings(UserToken); + return SessionSettings(Name, Path, StorageFlags, userSettings); + } + +private: + SessionSettings(std::wstring name, std::wstring path, WSLCSessionStorageFlags storageFlags, const settings::UserSettings& userSettings) : + DisplayName(std::move(name)), StoragePath(std::move(path)) + { + Settings.DisplayName = DisplayName.c_str(); + Settings.StoragePath = StoragePath.c_str(); + Settings.CpuCount = userSettings.Get(); + Settings.MemoryMb = userSettings.Get(); + Settings.MaximumStorageSizeMb = userSettings.Get(); + Settings.BootTimeoutMs = wsl::windows::wslc::DefaultBootTimeoutMs; + Settings.NetworkingMode = userSettings.Get(); + Settings.FeatureFlags = WslcFeatureFlagsNone; + WI_SetFlagIf(Settings.FeatureFlags, WslcFeatureFlagsDnsTunneling, userSettings.Get()); + WI_SetFlagIf( + Settings.FeatureFlags, + WslcFeatureFlagsVirtioFs, + userSettings.Get() == settings::HostFileShareMode::VirtioFs); + Settings.StorageFlags = storageFlags; + } +}; + +} // namespace WSLCSessionManagerImpl::~WSLCSessionManagerImpl() { @@ -51,22 +120,39 @@ WSLCSessionManagerImpl::~WSLCSessionManagerImpl() void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, IWSLCSession** WslcSession) { - // Ensure that the session display name is non-null and not too long. - THROW_HR_IF(E_INVALIDARG, Settings->DisplayName == nullptr); - THROW_HR_IF(E_INVALIDARG, wcslen(Settings->DisplayName) >= std::size(WSLCSessionInformation{}.DisplayName)); - THROW_HR_IF_MSG( - E_INVALIDARG, - WI_IsAnyFlagSet(Settings->StorageFlags, ~WSLCSessionStorageFlagsValid), - "Invalid storage flags: %i", - Settings->StorageFlags); - auto tokenInfo = GetCallingProcessTokenInfo(); + const auto callerToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + + // Resolve display name upfront (for both default and custom sessions). + std::wstring resolvedDisplayName; + if (Settings == nullptr) + { + // Default session: name determined from token, qualified with username. + resolvedDisplayName = ResolveDefaultSessionName(tokenInfo); + Flags = WSLCSessionFlagsOpenExisting | WSLCSessionFlagsPersistent; + } + else + { + THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, Settings->DisplayName == nullptr || wcslen(Settings->DisplayName) == 0); + THROW_HR_IF(E_INVALIDARG, Settings->StoragePath != nullptr && wcslen(Settings->StoragePath) == 0); + THROW_HR_IF(WSLC_E_INVALID_SESSION_NAME, wcslen(Settings->DisplayName) >= std::size(WSLCSessionInformation{}.DisplayName)); + THROW_HR_IF_MSG( + E_INVALIDARG, + WI_IsAnyFlagSet(Settings->StorageFlags, ~WSLCSessionStorageFlagsValid), + "Invalid storage flags: %i", + Settings->StorageFlags); + + // Reserved names can only be assigned server-side via null Settings. + THROW_HR_IF(WSLC_E_SESSION_RESERVED, IsReservedSessionName(Settings->DisplayName)); + + resolvedDisplayName = Settings->DisplayName; + } std::lock_guard lock(m_wslcSessionsLock); // Check for an existing session first. auto result = ForEachSession([&](auto& entry, const wil::com_ptr& session) noexcept -> std::optional { - if (!wsl::shared::string::IsEqual(entry.DisplayName.c_str(), Settings->DisplayName)) + if (!wsl::shared::string::IsEqual(entry.DisplayName.c_str(), resolvedDisplayName.c_str())) { return {}; } @@ -88,6 +174,14 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, wslutil::StopWatch stopWatch; + // Initialize settings for the default session. + std::unique_ptr defaultSettings; + if (Settings == nullptr) + { + defaultSettings = SessionSettings::Default(callerToken.get(), tokenInfo.Elevated, resolvedDisplayName); + Settings = &defaultSettings->Settings; + } + HRESULT creationResult = wil::ResultFromException([&]() { // Get caller info. const auto callerProcess = wslutil::OpenCallingProcess(PROCESS_QUERY_LIMITED_INFORMATION); @@ -103,13 +197,13 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, AddSessionProcessToJobObject(factory.get()); // Create the session via the factory. - const auto sessionSettings = CreateSessionSettings(sessionId, creatorPid, Settings); + const auto sessionSettings = CreateSessionSettings(sessionId, creatorPid, Settings, resolvedDisplayName.c_str()); wil::com_ptr session; wil::com_ptr serviceRef; THROW_IF_FAILED(factory->CreateSession(&sessionSettings, vm.Get(), &session, &serviceRef)); // Track the session via its service ref, along with metadata and security info. - m_sessions.push_back({std::move(serviceRef), sessionId, creatorPid, Settings->DisplayName, std::move(tokenInfo)}); + m_sessions.push_back({std::move(serviceRef), sessionId, creatorPid, resolvedDisplayName, std::move(tokenInfo)}); // For persistent sessions, also hold a strong reference to keep them alive. const bool persistent = WI_IsFlagSet(Flags, WSLCSessionFlagsPersistent); @@ -122,19 +216,18 @@ void WSLCSessionManagerImpl::CreateSession(const WSLCSessionSettings* Settings, }); // This telemetry event is used to keep track of session creation performance (via CreationTimeMs) and failure reasons (via Result). - WSL_LOG_TELEMETRY( "WSLCCreateSession", PDT_ProductAndServicePerformance, TraceLoggingKeyword(MICROSOFT_KEYWORD_CRITICAL_DATA), - TraceLoggingValue(Settings->DisplayName, "Name"), + TraceLoggingValue(resolvedDisplayName.c_str(), "Name"), TraceLoggingValue(stopWatch.ElapsedMilliseconds(), "CreationTimeMs"), TraceLoggingValue(creationResult, "Result"), TraceLoggingValue(tokenInfo.Elevated, "Elevated"), TraceLoggingValue(static_cast(Flags), "Flags"), TraceLoggingLevel(WINEVENT_LEVEL_INFO)); - THROW_IF_FAILED_MSG(creationResult, "Failed to create session: %ls", Settings->DisplayName); + THROW_IF_FAILED_MSG(creationResult, "Failed to create session: %ls", resolvedDisplayName.c_str()); } void WSLCSessionManagerImpl::OpenSession(ULONG Id, IWSLCSession** Session) @@ -160,6 +253,14 @@ void WSLCSessionManagerImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLCSession { auto tokenInfo = GetCallingProcessTokenInfo(); + // Null name = default session, resolved from caller's token + username. + std::wstring resolvedName; + if (DisplayName == nullptr) + { + resolvedName = ResolveDefaultSessionName(tokenInfo); + DisplayName = resolvedName.c_str(); + } + auto result = ForEachSession([&](auto& entry, const wil::com_ptr& session) noexcept -> std::optional { if (!wsl::shared::string::IsEqual(entry.DisplayName.c_str(), DisplayName)) { @@ -207,12 +308,23 @@ void WSLCSessionManagerImpl::GetVersion(_Out_ WSLCVersion* Version) Version->Revision = WSL_PACKAGE_VERSION_REVISION; } -WSLCSessionInitSettings WSLCSessionManagerImpl::CreateSessionSettings(_In_ ULONG SessionId, _In_ DWORD CreatorPid, _In_ const WSLCSessionSettings* Settings) +void WSLCSessionManagerImpl::EnterSession(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession) +{ + THROW_HR_IF(E_POINTER, DisplayName == nullptr || StoragePath == nullptr); + THROW_HR_IF(E_INVALIDARG, DisplayName[0] == L'\0' || StoragePath[0] == L'\0'); + + const auto callerToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + auto sessionSettings = SessionSettings::Custom(callerToken.get(), DisplayName, StoragePath, WSLCSessionStorageFlagsNoCreate); + CreateSession(&sessionSettings.Settings, WSLCSessionFlagsNone, WslcSession); +} + +WSLCSessionInitSettings WSLCSessionManagerImpl::CreateSessionSettings( + _In_ ULONG SessionId, _In_ DWORD CreatorPid, _In_ const WSLCSessionSettings* Settings, _In_ LPCWSTR ResolvedDisplayName) { WSLCSessionInitSettings sessionSettings{}; sessionSettings.SessionId = SessionId; sessionSettings.CreatorPid = CreatorPid; - sessionSettings.DisplayName = Settings->DisplayName; + sessionSettings.DisplayName = ResolvedDisplayName; sessionSettings.StoragePath = Settings->StoragePath; sessionSettings.MaximumStorageSizeMb = Settings->MaximumStorageSizeMb; sessionSettings.BootTimeoutMs = Settings->BootTimeoutMs; @@ -260,6 +372,40 @@ CallingProcessTokenInfo WSLCSessionManagerImpl::GetCallingProcessTokenInfo() return {std::move(tokenInfo), elevated}; } +std::wstring WSLCSessionManagerImpl::ResolveDefaultSessionName(const CallingProcessTokenInfo& TokenInfo) +{ + // Look up the username from the caller's SID so each user gets their own + // default session (e.g. "wslc-cli-alice", "wslc-cli-admin-bob"). + wchar_t username[256 + 1] = {}; + DWORD usernameLen = ARRAYSIZE(username); + wchar_t domain[MAX_PATH] = {}; + DWORD domainLen = ARRAYSIZE(domain); + SID_NAME_USE sidType; + THROW_IF_WIN32_BOOL_FALSE(LookupAccountSidW(nullptr, TokenInfo.TokenInfo->User.Sid, username, &usernameLen, domain, &domainLen, &sidType)); + + auto baseName = TokenInfo.Elevated ? wsl::windows::wslc::DefaultAdminSessionName : wsl::windows::wslc::DefaultSessionName; + return std::format(L"{}-{}", baseName, username); +} + +bool WSLCSessionManagerImpl::IsReservedSessionName(LPCWSTR Name) +{ + // Block any name that is exactly "wslc-cli" or starts with "wslc-cli-", + // which covers the admin variant and all per-user resolved names. + constexpr std::wstring_view prefix{wsl::windows::wslc::DefaultSessionName}; + std::wstring_view name{Name}; + if (name.size() < prefix.size()) + { + return false; + } + + if (!wsl::shared::string::IsEqual(name.substr(0, prefix.size()), prefix, true)) + { + return false; + } + + return name.size() == prefix.size() || name[prefix.size()] == L'-'; +} + HRESULT WSLCSessionManagerImpl::CheckTokenAccess(const SessionEntry& Entry, const CallingProcessTokenInfo& TokenInfo) { // Allow elevated tokens to access all sessions. @@ -292,6 +438,11 @@ HRESULT WSLCSessionManager::CreateSession(const WSLCSessionSettings* WslcSession return CallImpl(&WSLCSessionManagerImpl::CreateSession, WslcSessionSettings, Flags, WslcSession); } +HRESULT WSLCSessionManager::EnterSession(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession) +{ + return CallImpl(&WSLCSessionManagerImpl::EnterSession, DisplayName, StoragePath, WslcSession); +} + HRESULT WSLCSessionManager::ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount) { return CallImpl(&WSLCSessionManagerImpl::ListSessions, Sessions, SessionsCount); diff --git a/src/windows/service/exe/WSLCSessionManager.h b/src/windows/service/exe/WSLCSessionManager.h index 197e8b174..f2e5f42a0 100644 --- a/src/windows/service/exe/WSLCSessionManager.h +++ b/src/windows/service/exe/WSLCSessionManager.h @@ -73,11 +73,19 @@ class WSLCSessionManagerImpl void GetVersion(_Out_ WSLCVersion* Version); void CreateSession(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession); + void EnterSession(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession); void ListSessions(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount); void OpenSession(_In_ ULONG Id, _Out_ IWSLCSession** Session); void OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session); private: + // Resolves the default session name for a caller: appends the username + // from the token SID so different users don't collide. + static std::wstring ResolveDefaultSessionName(const CallingProcessTokenInfo& TokenInfo); + + // Returns true if the name matches a reserved default session prefix. + static bool IsReservedSessionName(LPCWSTR Name); + // Iterates over all sessions, cleaning up released sessions. // The routine receives a SessionEntry& and can return an optional to stop iteration. template @@ -138,7 +146,8 @@ class WSLCSessionManagerImpl } void AddSessionProcessToJobObject(_In_ IWSLCSessionFactory* Factory); - WSLCSessionInitSettings CreateSessionSettings(_In_ ULONG SessionId, _In_ DWORD CreatorPid, _In_ const WSLCSessionSettings* Settings); + WSLCSessionInitSettings CreateSessionSettings( + _In_ ULONG SessionId, _In_ DWORD CreatorPid, _In_ const WSLCSessionSettings* Settings, _In_ LPCWSTR ResolvedDisplayName); void EnsureJobObjectCreated(); static CallingProcessTokenInfo GetCallingProcessTokenInfo(); static HRESULT CheckTokenAccess(const SessionEntry& Entry, const CallingProcessTokenInfo& TokenInfo); @@ -173,6 +182,7 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLCSessionManager IFACEMETHOD(GetVersion)(_Out_ WSLCVersion* Version) override; IFACEMETHOD(CreateSession)(const WSLCSessionSettings* WslcSessionSettings, WSLCSessionFlags Flags, IWSLCSession** WslcSession) override; + IFACEMETHOD(EnterSession)(_In_ LPCWSTR DisplayName, _In_ LPCWSTR StoragePath, IWSLCSession** WslcSession) override; IFACEMETHOD(ListSessions)(_Out_ WSLCSessionInformation** Sessions, _Out_ ULONG* SessionsCount) override; IFACEMETHOD(OpenSession)(_In_ ULONG Id, _Out_ IWSLCSession** Session) override; IFACEMETHOD(OpenSessionByName)(_In_ LPCWSTR DisplayName, _Out_ IWSLCSession** Session) override; diff --git a/src/windows/service/inc/wslc.idl b/src/windows/service/inc/wslc.idl index 94ee186c1..0cfd36140 100644 --- a/src/windows/service/inc/wslc.idl +++ b/src/windows/service/inc/wslc.idl @@ -749,10 +749,11 @@ interface IWSLCSessionManager : IUnknown HRESULT GetVersion([out] WSLCVersion* Version); // Session management. - HRESULT CreateSession([in, ref] const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, [out] IWSLCSession** Session); + HRESULT CreateSession([in, unique] const WSLCSessionSettings* Settings, WSLCSessionFlags Flags, [out] IWSLCSession** Session); + HRESULT EnterSession([in, ref] LPCWSTR DisplayName, [in, ref] LPCWSTR StoragePath, [out] IWSLCSession** Session); HRESULT ListSessions([out, size_is(, *SessionsCount)] WSLCSessionInformation** Sessions, [out] ULONG* SessionsCount); HRESULT OpenSession([in] ULONG Id, [out] IWSLCSession** Session); - HRESULT OpenSessionByName([in, ref] LPCWSTR DisplayName, [out] IWSLCSession** Session); + HRESULT OpenSessionByName([in, unique] LPCWSTR DisplayName, [out] IWSLCSession** Session); } cpp_quote("#define WSLC_E_BASE (0x0600)") @@ -761,4 +762,6 @@ cpp_quote("#define WSLC_E_CONTAINER_PREFIX_AMBIGUOUS MAKE_HRESULT(SEVERITY_ERROR cpp_quote("#define WSLC_E_CONTAINER_NOT_FOUND MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLC_E_BASE + 3) /* 0x80040603 */") cpp_quote("#define WSLC_E_VOLUME_NOT_FOUND MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLC_E_BASE + 4) /* 0x80040604 */") cpp_quote("#define WSLC_E_CONTAINER_NOT_RUNNING MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLC_E_BASE + 5) /* 0x80040605 */") -cpp_quote("#define WSLC_E_CONTAINER_IS_RUNNING MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLC_E_BASE + 6) /* 0x80040606 */") \ No newline at end of file +cpp_quote("#define WSLC_E_CONTAINER_IS_RUNNING MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLC_E_BASE + 6) /* 0x80040606 */") +cpp_quote("#define WSLC_E_SESSION_RESERVED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLC_E_BASE + 7) /* 0x80040607 */") +cpp_quote("#define WSLC_E_INVALID_SESSION_NAME MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLC_E_BASE + 8) /* 0x80040608 */") diff --git a/src/windows/wslc/CMakeLists.txt b/src/windows/wslc/CMakeLists.txt index c160840c5..b1552aa9a 100644 --- a/src/windows/wslc/CMakeLists.txt +++ b/src/windows/wslc/CMakeLists.txt @@ -1,4 +1,4 @@ -set(WSLC_SUBDIRS arguments commands core services settings tasks) +set(WSLC_SUBDIRS arguments commands core services tasks) list(TRANSFORM WSLC_SUBDIRS PREPEND ${CMAKE_CURRENT_SOURCE_DIR}/ OUTPUT_VARIABLE WSLC_SUBDIR_PATHS) list(TRANSFORM WSLC_SUBDIR_PATHS APPEND /*.h OUTPUT_VARIABLE HEADER_PATTERNS) @@ -13,8 +13,8 @@ target_include_directories(wslclib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${WSLC_SUB target_link_libraries(wslclib ${COMMON_LINK_LIBRARIES} - common - yaml-cpp) + yaml-cpp + common) target_precompile_headers(wslclib REUSE_FROM common) set_target_properties(wslclib PROPERTIES FOLDER windows) diff --git a/src/windows/wslc/commands/SettingsCommand.cpp b/src/windows/wslc/commands/SettingsCommand.cpp index 6b754c10d..d66a7b999 100644 --- a/src/windows/wslc/commands/SettingsCommand.cpp +++ b/src/windows/wslc/commands/SettingsCommand.cpp @@ -13,12 +13,11 @@ Module Name: --*/ #include "Argument.h" #include "SettingsCommand.h" -#include "UserSettings.h" +#include "WSLCUserSettings.h" #include "wslutil.h" using namespace wsl::windows::common::wslutil; using namespace wsl::windows::wslc::execution; -using namespace wsl::windows::wslc::settings; using namespace wsl::shared; namespace wsl::windows::wslc { @@ -48,9 +47,9 @@ std::wstring SettingsCommand::LongDescription() const void SettingsCommand::ExecuteInternal(CLIExecutionContext& context) const { - settings::User().PrepareToShellExecuteFile(); - - const auto& path = settings::User().SettingsFilePath(); + const auto& userSettings = settings::User(); + userSettings.PrepareToShellExecuteFile(); + const auto path = userSettings.SettingsFilePath(); // Some versions of windows will fail if no file extension association exists, other will pop up the dialog // to make the user pick their default. @@ -83,7 +82,6 @@ std::wstring SettingsResetCommand::LongDescription() const void SettingsResetCommand::ExecuteInternal(CLIExecutionContext& context) const { - // TODO: do we need prompt support? settings::User().Reset(); PrintMessage(Localization::WSLCCLI_SettingsResetConfirm()); } diff --git a/src/windows/wslc/core/Main.cpp b/src/windows/wslc/core/Main.cpp index 6f1ab5e30..8278130b7 100644 --- a/src/windows/wslc/core/Main.cpp +++ b/src/windows/wslc/core/Main.cpp @@ -20,7 +20,6 @@ Module Name: #include "CLIExecutionContext.h" #include "Invocation.h" #include "RootCommand.h" -#include "UserSettings.h" using namespace wsl::shared; using namespace wsl::windows::common; diff --git a/src/windows/wslc/services/SessionModel.cpp b/src/windows/wslc/services/SessionModel.cpp deleted file mode 100644 index 9eb670041..000000000 --- a/src/windows/wslc/services/SessionModel.cpp +++ /dev/null @@ -1,77 +0,0 @@ -/*++ - -Copyright (c) Microsoft. All rights reserved. - -Module Name: - - SessionModel.cpp - -Abstract: - - This file contains the SessionModel implementation. - ---*/ - -#include -#include "SessionModel.h" -#include "UserSettings.h" - -namespace wsl::windows::wslc::models { - -const wchar_t* SessionOptions::GetDefaultSessionName() -{ - return IsElevated() ? s_defaultAdminSessionName : s_defaultSessionName; -} - -bool SessionOptions::IsDefaultSessionName(const std::wstring& sessionName) -{ - // Only returns true for the default session name that matches current elevation. - return wsl::shared::string::IsEqual(sessionName, GetDefaultSessionName()); -} - -SessionOptions::SessionOptions() -{ - m_sessionSettings.DisplayName = GetDefaultSessionName(); - m_sessionSettings.StoragePath = GetStoragePath().c_str(); - m_sessionSettings.CpuCount = settings::User().Get(); - m_sessionSettings.MemoryMb = settings::User().Get(); - m_sessionSettings.BootTimeoutMs = s_defaultBootTimeoutMs; - m_sessionSettings.MaximumStorageSizeMb = settings::User().Get(); - m_sessionSettings.NetworkingMode = settings::User().Get(); - if (settings::User().Get() == settings::HostFileShareMode::VirtioFs) - { - WI_SetFlag(m_sessionSettings.FeatureFlags, WslcFeatureFlagsVirtioFs); - } - - if (settings::User().Get()) - { - WI_SetFlag(m_sessionSettings.FeatureFlags, WslcFeatureFlagsDnsTunneling); - } -} - -bool SessionOptions::IsElevated() -{ - auto token = wil::open_current_access_token(TOKEN_QUERY); - - // IsTokenElevated checks if the integrity level is exactly HIGH. - // We must also check for local system because it is above HIGH. - // However, IsTokenLocalSystem() does not work correctly and fails. - // TODO: Add proper handling for system user callers. - return wsl::windows::common::security::IsTokenElevated(token.get()); -} - -const std::filesystem::path& SessionOptions::GetStoragePath() -{ - static const std::filesystem::path basePath = []() { - return settings::User().Get().empty() - ? std::filesystem::path{wsl::windows::common::filesystem::GetLocalAppDataPath(nullptr) / SessionOptions::s_defaultStorageSubPath} - : settings::User().Get().c_str(); - }(); - - static const std::filesystem::path storagePathNonAdmin = basePath / std::wstring{s_defaultSessionName}; - static const std::filesystem::path storagePathAdmin = basePath / std::wstring{s_defaultAdminSessionName}; - - return IsElevated() ? storagePathAdmin : storagePathNonAdmin; -} - -} // namespace wsl::windows::wslc::models diff --git a/src/windows/wslc/services/SessionModel.h b/src/windows/wslc/services/SessionModel.h index fe610d5fb..7424a8c04 100644 --- a/src/windows/wslc/services/SessionModel.h +++ b/src/windows/wslc/services/SessionModel.h @@ -31,38 +31,4 @@ struct Session wil::com_ptr m_session; }; -class SessionOptions -{ -public: - // These are elevation-aware static methods that will return the correct - // session name or validate against the correct session name based on the - // elevation of the process. - static const wchar_t* GetDefaultSessionName(); - static bool IsDefaultSessionName(const std::wstring& sessionName); - - SessionOptions(); - - static const std::filesystem::path& GetStoragePath(); - - const WSLCSessionSettings* Get() const - { - return &m_sessionSettings; - } - - WSLCSessionSettings* Get() - { - return &m_sessionSettings; - } - -private: - static constexpr const wchar_t s_defaultSessionName[] = L"wslc-cli"; - static constexpr const wchar_t s_defaultAdminSessionName[] = L"wslc-cli-admin"; - static constexpr const wchar_t s_defaultStorageSubPath[] = L"wslc\\sessions"; - static constexpr uint32_t s_defaultBootTimeoutMs = 30 * 1000; - - static bool IsElevated(); - - WSLCSessionSettings m_sessionSettings{}; -}; - } // namespace wsl::windows::wslc::models \ No newline at end of file diff --git a/src/windows/wslc/services/SessionService.cpp b/src/windows/wslc/services/SessionService.cpp index d58ade9b0..a38b5c5b5 100644 --- a/src/windows/wslc/services/SessionService.cpp +++ b/src/windows/wslc/services/SessionService.cpp @@ -25,25 +25,30 @@ namespace wslutil = wsl::windows::common::wslutil; int SessionService::Attach(const std::wstring& sessionName) { - THROW_HR_IF(E_INVALIDARG, sessionName.empty()); - wil::com_ptr manager; THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&manager))); wsl::windows::common::security::ConfigureForCOMImpersonation(manager.get()); wil::com_ptr session; - HRESULT hr = manager->OpenSessionByName(sessionName.c_str(), &session); + HRESULT hr = manager->OpenSessionByName(sessionName.empty() ? nullptr : sessionName.c_str(), &session); if (FAILED(hr)) { if (hr == HRESULT_FROM_WIN32(ERROR_NOT_FOUND)) { - wslutil::PrintMessage(Localization::MessageWslcSessionNotFound(sessionName.c_str()), stderr); + wslutil::PrintMessage( + sessionName.empty() ? Localization::MessageWslcDefaultSessionNotFound() + : Localization::MessageWslcSessionNotFound(sessionName.c_str()), + stderr); return 1; } auto errorString = wsl::windows::common::wslutil::ErrorCodeToString(hr); wslutil::PrintMessage( - Localization::MessageErrorCode(Localization::MessageWslcOpenSessionFailed(sessionName.c_str()), errorString), stderr); + Localization::MessageErrorCode( + sessionName.empty() ? Localization::MessageWslcOpenDefaultSessionFailed() + : Localization::MessageWslcOpenSessionFailed(sessionName.c_str()), + errorString), + stderr); return 1; } @@ -100,15 +105,15 @@ int SessionService::Attach(const std::wstring& sessionName) return static_cast(exitCode); } -Session SessionService::CreateSession(const SessionOptions& options, WSLCSessionFlags Flags) +Session SessionService::CreateDefaultSession() { - const WSLCSessionSettings* settings = options.Get(); wil::com_ptr sessionManager; THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); + // Null Settings = default session with server-determined name and settings. wil::com_ptr session; - THROW_IF_FAILED(sessionManager->CreateSession(settings, Flags, &session)); + THROW_IF_FAILED(sessionManager->CreateSession(nullptr, WSLCSessionFlagsNone, &session)); wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); return Session(std::move(session)); } @@ -118,14 +123,13 @@ int SessionService::Enter(const std::wstring& storagePath, const std::wstring& d THROW_HR_IF(E_INVALIDARG, storagePath.empty()); THROW_HR_IF(E_INVALIDARG, displayName.empty()); - // Build session settings from the user configuration, overriding storage path and display name. - SessionOptions options; - options.Get()->DisplayName = displayName.c_str(); - options.Get()->StoragePath = storagePath.c_str(); - options.Get()->StorageFlags = WSLCSessionStorageFlagsNoCreate; // Don't create storage if it doesn't exist. + wil::com_ptr sessionManager; + THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); + wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); - // Create a non-persistent session: lifetime is tied to our COM reference. - auto session = SessionService::CreateSession(options, WSLCSessionFlagsNone); + wil::com_ptr session; + THROW_IF_FAILED(sessionManager->EnterSession(displayName.c_str(), storagePath.c_str(), &session)); + wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); wsl::windows::common::wslutil::PrintMessage(Localization::MessageWslcCreatedSession(displayName), stderr); const std::string shell = "/bin/sh"; @@ -135,7 +139,7 @@ int SessionService::Enter(const std::wstring& storagePath, const std::wstring& d const auto windowSize = console.GetWindowSize(); launcher.SetTtySize(windowSize.Y, windowSize.X); - return ConsoleService::AttachToCurrentConsole(launcher.Launch(*session.Get())); + return ConsoleService::AttachToCurrentConsole(launcher.Launch(*session.get())); } std::vector SessionService::List() @@ -174,19 +178,20 @@ Session SessionService::OpenSession(const std::wstring& displayName) int SessionService::TerminateSession(const std::wstring& displayName) { - THROW_HR_IF(E_INVALIDARG, displayName.empty()); - wil::com_ptr sessionManager; THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); wil::com_ptr session; - HRESULT hr = sessionManager->OpenSessionByName(displayName.c_str(), &session); + HRESULT hr = sessionManager->OpenSessionByName(displayName.empty() ? nullptr : displayName.c_str(), &session); if (FAILED(hr)) { if (hr == HRESULT_FROM_WIN32(ERROR_NOT_FOUND)) { - wslutil::PrintMessage(Localization::MessageWslcSessionNotFound(displayName.c_str()), stderr); + wslutil::PrintMessage( + displayName.empty() ? Localization::MessageWslcDefaultSessionNotFound() + : Localization::MessageWslcSessionNotFound(displayName.c_str()), + stderr); return 1; } @@ -200,7 +205,11 @@ int SessionService::TerminateSession(const std::wstring& displayName) { auto errorString = wsl::windows::common::wslutil::ErrorCodeToString(hr); wslutil::PrintMessage( - Localization::MessageErrorCode(Localization::MessageWslcTerminateSessionFailed(displayName.c_str()), errorString), stderr); + Localization::MessageErrorCode( + displayName.empty() ? Localization::MessageWslcTerminateDefaultSessionFailed() + : Localization::MessageWslcTerminateSessionFailed(displayName.c_str()), + errorString), + stderr); return 1; } diff --git a/src/windows/wslc/services/SessionService.h b/src/windows/wslc/services/SessionService.h index 90fbcc2b9..8be8f76a3 100644 --- a/src/windows/wslc/services/SessionService.h +++ b/src/windows/wslc/services/SessionService.h @@ -27,9 +27,8 @@ struct SessionInformation struct SessionService { static int Attach(const std::wstring& name); - static wsl::windows::wslc::models::Session CreateSession( - const wsl::windows::wslc::models::SessionOptions& options, - WSLCSessionFlags Flags = WSLCSessionFlagsOpenExisting | WSLCSessionFlagsPersistent); + // Creates a default session with server-determined name and settings. + static wsl::windows::wslc::models::Session CreateDefaultSession(); static int Enter(const std::wstring& storagePath, const std::wstring& displayName); static std::vector List(); static wsl::windows::wslc::models::Session OpenSession(const std::wstring& displayName); diff --git a/src/windows/wslc/tasks/SessionTasks.cpp b/src/windows/wslc/tasks/SessionTasks.cpp index 8524504a7..7e26701d2 100644 --- a/src/windows/wslc/tasks/SessionTasks.cpp +++ b/src/windows/wslc/tasks/SessionTasks.cpp @@ -13,7 +13,6 @@ Module Name: --*/ #include "Argument.h" #include "CLIExecutionContext.h" -#include "SessionModel.h" #include "SessionService.h" #include "SessionTasks.h" #include "TableOutput.h" @@ -25,7 +24,6 @@ using namespace wsl::windows::common::string; using namespace wsl::windows::common::wslutil; using namespace wsl::windows::wslc::execution; using namespace wsl::windows::wslc::services; -using wsl::windows::wslc::models::SessionOptions; namespace wsl::windows::wslc::task { @@ -36,10 +34,6 @@ void AttachToSession(CLIExecutionContext& context) { sessionId = context.Args.Get(); } - else - { - sessionId = SessionOptions::GetDefaultSessionName(); - } context.ExitCode = SessionService::Attach(sessionId); } @@ -48,23 +42,14 @@ void CreateSession(CLIExecutionContext& context) { if (context.Args.Contains(ArgType::Session)) { - // If provided session name is not the default CLI session use open only. - // This also ensures that mixed elevation types will only attempt to open - // a session and not create it. Example: Admin process attempting to open - // a non-admin session will fail to create but succeed to open, preventing - // accidental creation of a non-admin session with admin permissions. + // User specified a session name — open only, don't create. const auto& sessionName = context.Args.Get(); - if (!SessionOptions::IsDefaultSessionName(sessionName)) - { - context.Data.Add(SessionService::OpenSession(sessionName)); - return; - } + context.Data.Add(SessionService::OpenSession(sessionName)); + return; } - // Create/open the default session. Create is only called with default session - // settings so we ensure the CLI sessions are created with correct permissions. - SessionOptions options{}; - context.Data.Add(SessionService::CreateSession(options)); + // Create/open the default session. + context.Data.Add(SessionService::CreateDefaultSession()); } void ListSessions(CLIExecutionContext& context) @@ -98,10 +83,6 @@ void TerminateSession(CLIExecutionContext& context) { sessionId = context.Args.Get(); } - else - { - sessionId = SessionOptions::GetDefaultSessionName(); - } context.ExitCode = SessionService::TerminateSession(sessionId); } diff --git a/test/windows/CMakeLists.txt b/test/windows/CMakeLists.txt index e2f8b2caf..5e12672af 100644 --- a/test/windows/CMakeLists.txt +++ b/test/windows/CMakeLists.txt @@ -32,6 +32,7 @@ target_link_libraries(wsltests ${COMMON_LINK_LIBRARIES} ${MSI_LINK_LIBRARIES} ${HCS_LINK_LIBRARIES} + yaml-cpp ${SERVICE_LINK_LIBRARIES} VirtDisk.lib Wer.lib diff --git a/test/windows/WSLCTests.cpp b/test/windows/WSLCTests.cpp index 5f5df349b..2644eca33 100644 --- a/test/windows/WSLCTests.cpp +++ b/test/windows/WSLCTests.cpp @@ -341,7 +341,7 @@ class WSLCTests { auto settings = GetDefaultSessionSettings(nullptr); wil::com_ptr session; - VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), E_INVALIDARG); + VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_INVALID_SESSION_NAME); } // Reject DisplayName at exact boundary (no room for null terminator). @@ -349,7 +349,7 @@ class WSLCTests std::wstring boundaryName(std::size(WSLCSessionInformation{}.DisplayName), L'x'); auto settings = GetDefaultSessionSettings(boundaryName.c_str()); wil::com_ptr session; - VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), E_INVALIDARG); + VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_INVALID_SESSION_NAME); } // Reject too long DisplayName. @@ -357,7 +357,7 @@ class WSLCTests std::wstring longName(std::size(WSLCSessionInformation{}.DisplayName) + 1, L'x'); auto settings = GetDefaultSessionSettings(longName.c_str()); wil::com_ptr session; - VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), E_INVALIDARG); + VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_INVALID_SESSION_NAME); } // Validate that creating a session on a non-existing storage fails if WSLCSessionStorageFlagsNoCreate is set. @@ -2195,7 +2195,18 @@ class WSLCTests std::thread thread(readDmesg); // Needs to be created before the VM starts, to avoid a pipe deadlock. + // Ensure the thread is joined even if CreateSession throws, to avoid std::terminate. + auto threadGuard = wil::scope_exit([&]() { + write.reset(); + if (thread.joinable()) + { + thread.join(); + } + }); + auto session = CreateSession(settings); + threadGuard.release(); // CreateSession succeeded, detach scope_exit below takes over. + auto detach = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { session.reset(); if (thread.joinable()) diff --git a/test/windows/wslc/CMakeLists.txt b/test/windows/wslc/CMakeLists.txt index 4e91a5411..23418d7cb 100644 --- a/test/windows/wslc/CMakeLists.txt +++ b/test/windows/wslc/CMakeLists.txt @@ -28,5 +28,4 @@ target_include_directories(wsltests PRIVATE ${CMAKE_SOURCE_DIR}/src/windows/wslc/arguments ${CMAKE_SOURCE_DIR}/src/windows/wslc/services ${CMAKE_SOURCE_DIR}/src/windows/wslc/tasks - ${CMAKE_SOURCE_DIR}/src/windows/wslc/settings ) diff --git a/test/windows/wslc/WSLCCLISettingsUnitTests.cpp b/test/windows/wslc/WSLCCLISettingsUnitTests.cpp index 7a6c02b37..e7152e89a 100644 --- a/test/windows/wslc/WSLCCLISettingsUnitTests.cpp +++ b/test/windows/wslc/WSLCCLISettingsUnitTests.cpp @@ -16,7 +16,7 @@ Module Name: #include "precomp.h" #include "windows/Common.h" #include "WSLCCLITestHelpers.h" -#include "UserSettings.h" +#include "WSLCUserSettings.h" #include #include diff --git a/test/windows/wslc/e2e/WSLCE2EGlobalTests.cpp b/test/windows/wslc/e2e/WSLCE2EGlobalTests.cpp index 043a02656..81c353471 100644 --- a/test/windows/wslc/e2e/WSLCE2EGlobalTests.cpp +++ b/test/windows/wslc/e2e/WSLCE2EGlobalTests.cpp @@ -16,6 +16,7 @@ Module Name: #include "WSLCCLITestHelpers.h" #include "WSLCExecutor.h" #include "WSLCE2EHelpers.h" +#include "WSLCSessionDefaults.h" #include "Argument.h" using namespace WEX::Logging; @@ -23,6 +24,22 @@ using namespace WEX::Logging; namespace WSLCE2ETests { using namespace wsl::shared; +namespace { + + // Returns the expected default session name for the current user (e.g. "wslc-cli-admin-benhill"). + std::wstring GetExpectedDefaultSessionName(bool elevated) + { + auto baseName = elevated ? wsl::windows::wslc::DefaultAdminSessionName : wsl::windows::wslc::DefaultSessionName; + + wchar_t username[256 + 1] = {}; + DWORD usernameLen = ARRAYSIZE(username); + THROW_IF_WIN32_BOOL_FALSE(GetUserNameW(username, &usernameLen)); + + return std::format(L"{}-{}", baseName, username); + } + +} // namespace + class WSLCE2EGlobalTests { WSLC_TEST_CLASS(WSLCE2EGlobalTests) @@ -70,7 +87,8 @@ class WSLCE2EGlobalTests result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli-admin") != std::wstring::npos); + auto adminName = GetExpectedDefaultSessionName(true); + VERIFY_IS_TRUE(result.Stdout->find(adminName) != std::wstring::npos); } WSLC_TEST_METHOD(WSLCE2E_Session_DefaultNonElevated) @@ -86,7 +104,8 @@ class WSLCE2EGlobalTests VERIFY_IS_TRUE(result.Stdout.has_value()); // The "\r\n" after session name is important to differentiate it from the admin session. - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli\r\n") != std::wstring::npos); + auto nonAdminName = GetExpectedDefaultSessionName(false); + VERIFY_IS_TRUE(result.Stdout->find(nonAdminName + L"\r\n") != std::wstring::npos); } WSLC_TEST_METHOD(WSLCE2E_Session_NonElevatedCannotAccessAdminSession) @@ -96,7 +115,8 @@ class WSLCE2EGlobalTests result.Verify({.Stderr = L"", .ExitCode = 0}); // Try to explicitly target the admin session from non-elevated process - result = RunWslc(L"container list --session wslc-cli-admin", ElevationType::NonElevated); + auto adminName = GetExpectedDefaultSessionName(true); + result = RunWslc(std::format(L"container list --session {}", adminName), ElevationType::NonElevated); // Should fail with access denied. result.Verify({.Stderr = L"The requested operation requires elevation. \r\nError code: ERROR_ELEVATION_REQUIRED\r\n", .ExitCode = 1}); @@ -109,7 +129,8 @@ class WSLCE2EGlobalTests result.Verify({.Stderr = L"", .ExitCode = 0}); // Elevated user should be able to explicitly target the non-admin session - result = RunWslc(L"container list --session wslc-cli", ElevationType::Elevated); + auto nonAdminName = GetExpectedDefaultSessionName(false); + result = RunWslc(std::format(L"container list --session {}", nonAdminName), ElevationType::Elevated); // This should work - elevated users can access non-elevated sessions result.Verify({.Stderr = L"", .ExitCode = 0}); @@ -117,20 +138,107 @@ class WSLCE2EGlobalTests WSLC_TEST_METHOD(WSLCE2E_Session_CreateMixedElevation_Fails) { - EnsureSessionIsTerminated(L"wslc-cli"); - EnsureSessionIsTerminated(L"wslc-cli-admin"); + EnsureSessionIsTerminated(GetExpectedDefaultSessionName(false)); + EnsureSessionIsTerminated(GetExpectedDefaultSessionName(true)); // Ensure elevated cannot create the non-elevated session. - auto result = RunWslc(L"container list --session wslc-cli", ElevationType::Elevated); + auto nonAdminName = GetExpectedDefaultSessionName(false); + auto adminName = GetExpectedDefaultSessionName(true); + auto result = RunWslc(std::format(L"container list --session {}", nonAdminName), ElevationType::Elevated); result.Verify({.Stderr = L"Element not found. \r\nError code: ERROR_NOT_FOUND\r\n", .ExitCode = 1}); // Ensure non-elevated cannot create the elevated session. - result = RunWslc(L"container list --session wslc-cli-admin", ElevationType::NonElevated); + result = RunWslc(std::format(L"container list --session {}", adminName), ElevationType::NonElevated); result.Verify({.Stderr = L"Element not found. \r\nError code: ERROR_NOT_FOUND\r\n", .ExitCode = 1}); } + // Regression test for session name squatting vulnerability. + // + // Validates that a process cannot create a session with the reserved default + // session names ("wslc-cli" or "wslc-cli-admin") via the COM API. These names + // are assigned server-side when the client passes null Settings to CreateSession, + // preventing a malicious process from squatting on the name and blocking + // legitimate wslc.exe clients. + WSLC_TEST_METHOD(WSLCE2E_Session_NameSquatting_ElevatedCannotBlockNonElevated) + { + // Ensure no existing sessions with default names. + EnsureSessionIsTerminated(wsl::windows::wslc::DefaultSessionName); + EnsureSessionIsTerminated(wsl::windows::wslc::DefaultAdminSessionName); + + // Attack: attempt to create a session with the reserved non-admin default + // name directly through the COM API from this elevated process. + // The service should reject this because reserved default session names + // cannot be explicitly created. + { + wil::com_ptr sessionManager; + VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); + wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); + + WSLCSessionSettings settings{}; + settings.DisplayName = wsl::windows::wslc::DefaultSessionName; + settings.StoragePath = L"C:\\dummy"; + settings.CpuCount = 4; + settings.MemoryMb = 2048; + settings.BootTimeoutMs = 30000; + settings.MaximumStorageSizeMb = 4096; + + wil::com_ptr session; + HRESULT hr = sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session); + VERIFY_ARE_EQUAL(hr, WSLC_E_SESSION_RESERVED); + } + + // Also verify that the admin reserved name is rejected. + { + wil::com_ptr sessionManager; + VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); + wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); + + WSLCSessionSettings settings{}; + settings.DisplayName = wsl::windows::wslc::DefaultAdminSessionName; + settings.StoragePath = L"C:\\dummy"; + settings.CpuCount = 4; + settings.MemoryMb = 2048; + settings.BootTimeoutMs = 30000; + settings.MaximumStorageSizeMb = 4096; + + wil::com_ptr session; + HRESULT hr = sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session); + VERIFY_ARE_EQUAL(hr, WSLC_E_SESSION_RESERVED); + } + + // Non-elevated wslc.exe should still be able to create and use its default + // session (which now passes null Settings, resolved entirely server-side). + auto result = RunWslc(L"container list", ElevationType::NonElevated); + result.Verify({.Stderr = L"", .ExitCode = S_OK}); + + // Verify that case variations of reserved names are also rejected, + // preventing bypass on case-insensitive filesystems (NTFS). + { + wil::com_ptr sessionManager; + VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLCSessionManager), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&sessionManager))); + wsl::windows::common::security::ConfigureForCOMImpersonation(sessionManager.get()); + + WSLCSessionSettings settings{}; + settings.DisplayName = L"WSLC-CLI"; + settings.StoragePath = L"C:\\dummy"; + settings.CpuCount = 4; + settings.MemoryMb = 2048; + settings.BootTimeoutMs = 30000; + settings.MaximumStorageSizeMb = 4096; + + wil::com_ptr session; + VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_SESSION_RESERVED); + + settings.DisplayName = L"Wslc-Cli-Admin"; + VERIFY_ARE_EQUAL(sessionManager->CreateSession(&settings, WSLCSessionFlagsNone, &session), WSLC_E_SESSION_RESERVED); + } + } + WSLC_TEST_METHOD(WSLCE2E_Session_Terminate_Implicit) { + auto adminName = GetExpectedDefaultSessionName(true); + auto nonAdminName = GetExpectedDefaultSessionName(false); + // Run container list to create the default session if it does not already exist auto result = RunWslc(L"container list"); result.Verify({.Stderr = L"", .ExitCode = 0}); @@ -139,7 +247,7 @@ class WSLCE2EGlobalTests result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli-admin") != std::wstring::npos); + VERIFY_IS_TRUE(result.Stdout->find(adminName) != std::wstring::npos); // Terminate the session result = RunWslc(L"session terminate"); @@ -149,7 +257,7 @@ class WSLCE2EGlobalTests result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_FALSE(result.Stdout->find(L"wslc-cli-admin") != std::wstring::npos); + VERIFY_IS_FALSE(result.Stdout->find(adminName) != std::wstring::npos); // Repeat test for non-elevated session. @@ -161,7 +269,7 @@ class WSLCE2EGlobalTests result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli\r\n") != std::wstring::npos); + VERIFY_IS_TRUE(result.Stdout->find(nonAdminName + L"\r\n") != std::wstring::npos); // Terminate the session result = RunWslc(L"session terminate", ElevationType::NonElevated); @@ -171,11 +279,14 @@ class WSLCE2EGlobalTests result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_FALSE(result.Stdout->find(L"wslc-cli\r\n") != std::wstring::npos); + VERIFY_IS_FALSE(result.Stdout->find(nonAdminName + L"\r\n") != std::wstring::npos); } WSLC_TEST_METHOD(WSLCE2E_Session_Terminate_Explicit) { + auto adminName = GetExpectedDefaultSessionName(true); + auto nonAdminName = GetExpectedDefaultSessionName(false); + // Run container list to create the default session if it does not already exist auto result = RunWslc(L"container list"); result.Verify({.Stderr = L"", .ExitCode = 0}); @@ -184,17 +295,17 @@ class WSLCE2EGlobalTests result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli-admin") != std::wstring::npos); + VERIFY_IS_TRUE(result.Stdout->find(adminName) != std::wstring::npos); // Terminate the session - result = RunWslc(L"session terminate wslc-cli-admin"); + result = RunWslc(std::format(L"session terminate {}", adminName)); result.Verify({.Stderr = L"", .ExitCode = 0}); // Verify session no longer shows up result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_FALSE(result.Stdout->find(L"wslc-cli-admin") != std::wstring::npos); + VERIFY_IS_FALSE(result.Stdout->find(adminName) != std::wstring::npos); // Repeat test for non-elevated session. @@ -206,21 +317,24 @@ class WSLCE2EGlobalTests result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli\r\n") != std::wstring::npos); + VERIFY_IS_TRUE(result.Stdout->find(nonAdminName + L"\r\n") != std::wstring::npos); // Terminate the session - result = RunWslc(L"session terminate wslc-cli", ElevationType::NonElevated); + result = RunWslc(std::format(L"session terminate {}", nonAdminName), ElevationType::NonElevated); result.Verify({.Stderr = L"", .ExitCode = 0}); // Verify session no longer shows up result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_FALSE(result.Stdout->find(L"wslc-cli\r\n") != std::wstring::npos); + VERIFY_IS_FALSE(result.Stdout->find(nonAdminName + L"\r\n") != std::wstring::npos); } WSLC_TEST_METHOD(WSLCE2E_Session_Terminate_MixedElevation) { + auto adminName = GetExpectedDefaultSessionName(true); + auto nonAdminName = GetExpectedDefaultSessionName(false); + // Run container list to create the default sessions if they do not already exist. auto result = RunWslc(L"container list", ElevationType::Elevated); result.Verify({.Stderr = L"", .ExitCode = 0}); @@ -231,22 +345,22 @@ class WSLCE2EGlobalTests result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli-admin") != std::wstring::npos); - VERIFY_IS_TRUE(result.Stdout->find(L"wslc-cli\r\n") != std::wstring::npos); + VERIFY_IS_TRUE(result.Stdout->find(adminName) != std::wstring::npos); + VERIFY_IS_TRUE(result.Stdout->find(nonAdminName + L"\r\n") != std::wstring::npos); // Attempt to terminate the admin session from the non-elevated process and fail. - result = RunWslc(L"session terminate wslc-cli-admin", ElevationType::NonElevated); + result = RunWslc(std::format(L"session terminate {}", adminName), ElevationType::NonElevated); result.Verify({.Stderr = L"The requested operation requires elevation. \r\nError code: ERROR_ELEVATION_REQUIRED\r\n", .ExitCode = 1}); // Terminate the non-elevated session from the elevated process. - result = RunWslc(L"session terminate wslc-cli", ElevationType::Elevated); + result = RunWslc(std::format(L"session terminate {}", nonAdminName), ElevationType::Elevated); result.Verify({.Stderr = L"", .ExitCode = 0}); // Verify non-elevated session no longer shows up result = RunWslc(L"session list"); result.Verify({.Stderr = L"", .ExitCode = 0}); VERIFY_IS_TRUE(result.Stdout.has_value()); - VERIFY_IS_FALSE(result.Stdout->find(L"wslc-cli\r\n") != std::wstring::npos); + VERIFY_IS_FALSE(result.Stdout->find(nonAdminName + L"\r\n") != std::wstring::npos); } WSLC_TEST_METHOD(WSLCE2E_Session_Targeting) @@ -329,7 +443,8 @@ class WSLCE2EGlobalTests { Log::Comment(L"Testing non-elevated interactive session with explicit session name"); // Non-Elevated session shell should attach to the wslc by name also. - auto session = RunWslcInteractive(L"session shell wslc-cli", ElevationType::NonElevated); + auto nonAdminName = GetExpectedDefaultSessionName(false); + auto session = RunWslcInteractive(std::format(L"session shell {}", nonAdminName), ElevationType::NonElevated); VERIFY_IS_TRUE(session.IsRunning(), L"Session should be running"); session.ExpectStdout(VT::SESSION_PROMPT); @@ -353,7 +468,8 @@ class WSLCE2EGlobalTests { Log::Comment(L"Testing elevated interactive session with explicit admin session name"); // Elevated session shell should attach to the wslc by name also. - auto session = RunWslcInteractive(L"session shell wslc-cli-admin", ElevationType::Elevated); + auto adminName = GetExpectedDefaultSessionName(true); + auto session = RunWslcInteractive(std::format(L"session shell {}", adminName), ElevationType::Elevated); VERIFY_IS_TRUE(session.IsRunning(), L"Session should be running"); session.ExpectStdout(VT::SESSION_PROMPT); diff --git a/test/windows/wslc/e2e/WSLCE2EHelpers.cpp b/test/windows/wslc/e2e/WSLCE2EHelpers.cpp index f0e0e2420..4448ce665 100644 --- a/test/windows/wslc/e2e/WSLCE2EHelpers.cpp +++ b/test/windows/wslc/e2e/WSLCE2EHelpers.cpp @@ -12,7 +12,7 @@ Module Name: --*/ #include "precomp.h" -#include "SessionModel.h" +#include "WSLCSessionDefaults.h" #include "ImageModel.h" #include "windows/Common.h" #include "WSLCExecutor.h" @@ -341,7 +341,14 @@ void EnsureSessionIsTerminated(const std::wstring& sessionName) std::wstring targetSession = sessionName; if (targetSession.empty()) { - targetSession = std::wstring{wsl::windows::wslc::models::SessionOptions::GetDefaultSessionName()}; + auto isElevated = wsl::windows::common::security::IsTokenElevated(wil::open_current_access_token(TOKEN_QUERY).get()); + auto baseName = isElevated ? wsl::windows::wslc::DefaultAdminSessionName : wsl::windows::wslc::DefaultSessionName; + + wchar_t username[256 + 1] = {}; + DWORD usernameLen = ARRAYSIZE(username); + THROW_IF_WIN32_BOOL_FALSE(GetUserNameW(username, &usernameLen)); + + targetSession = std::format(L"{}-{}", baseName, username); } auto listResult = RunWslc(L"session list"); diff --git a/test/windows/wslc/e2e/WSLCE2ESessionEnterTests.cpp b/test/windows/wslc/e2e/WSLCE2ESessionEnterTests.cpp index 49aad0f82..00c66a6de 100644 --- a/test/windows/wslc/e2e/WSLCE2ESessionEnterTests.cpp +++ b/test/windows/wslc/e2e/WSLCE2ESessionEnterTests.cpp @@ -16,14 +16,42 @@ Module Name: #include "WSLCCLITestHelpers.h" #include "WSLCExecutor.h" #include "WSLCE2EHelpers.h" -#include "SessionModel.h" +#include "WSLCSessionDefaults.h" +#include "WSLCUserSettings.h" using namespace WEX::Logging; -using wsl::windows::wslc::models::SessionOptions; - namespace WSLCE2ETests { +namespace { + + const std::filesystem::path& GetDefaultStoragePath() + { + auto isElevated = wsl::windows::common::security::IsTokenElevated(wil::open_current_access_token(TOKEN_QUERY).get()); + + const auto& userSettings = wsl::windows::wslc::settings::User(); + auto customPath = userSettings.Get(); + + static const std::filesystem::path basePath = + customPath.empty() ? (wsl::windows::common::filesystem::GetLocalAppDataPath(nullptr) / wsl::windows::wslc::DefaultStorageSubPath) + : std::filesystem::path{customPath}; + + // Session names are now qualified with the username (e.g. "wslc-cli-alice"). + wchar_t username[256 + 1] = {}; + DWORD usernameLen = ARRAYSIZE(username); + THROW_IF_WIN32_BOOL_FALSE(GetUserNameW(username, &usernameLen)); + + auto adminName = std::format(L"{}-{}", wsl::windows::wslc::DefaultAdminSessionName, username); + auto nonAdminName = std::format(L"{}-{}", wsl::windows::wslc::DefaultSessionName, username); + + static const std::filesystem::path storagePathNonAdmin = basePath / nonAdminName; + static const std::filesystem::path storagePathAdmin = basePath / adminName; + + return isElevated ? storagePathAdmin : storagePathNonAdmin; + } + +} // namespace + class WSLCE2ESessionEnterTests { WSLC_TEST_CLASS(WSLCE2ESessionEnterTests) @@ -43,7 +71,7 @@ class WSLCE2ESessionEnterTests constexpr auto sessionName = L"test-wslc-session-enter"; // Run an interactive session enter with an explicit name. - auto session = RunWslcInteractive(std::format(L"session enter \"{}\" --name {}", SessionOptions::GetStoragePath(), sessionName)); + auto session = RunWslcInteractive(std::format(L"session enter \"{}\" --name {}", GetDefaultStoragePath(), sessionName)); VERIFY_IS_TRUE(session.IsRunning(), L"Session should be running"); session.ExpectStdout(VT::SESSION_PROMPT); @@ -73,7 +101,7 @@ class WSLCE2ESessionEnterTests WSLC_TEST_METHOD(WSLCE2E_SessionEnter_WithoutName_GeneratesGuid) { - auto session = RunWslcInteractive(std::format(L"session enter \"{}\"", SessionOptions::GetStoragePath())); + auto session = RunWslcInteractive(std::format(L"session enter \"{}\"", GetDefaultStoragePath())); VERIFY_IS_TRUE(session.IsRunning(), L"Session should be running"); session.ExpectStderr("Created session: ");