diff --git a/pkg/machine/hyperv/hutil.go b/pkg/machine/hyperv/hutil.go index 65c6d73c830..72cf15dc71e 100644 --- a/pkg/machine/hyperv/hutil.go +++ b/pkg/machine/hyperv/hutil.go @@ -5,8 +5,9 @@ package hyperv import ( "errors" + "github.com/containers/podman/v6/pkg/machine/windows" "github.com/sirupsen/logrus" - "golang.org/x/sys/windows" + syswindows "golang.org/x/sys/windows" ) var ( @@ -17,7 +18,7 @@ var ( ) func HasHyperVAdminRights() bool { - sid, err := windows.CreateWellKnownSid(windows.WinBuiltinHyperVAdminsSid) + sid, err := syswindows.CreateWellKnownSid(syswindows.WinBuiltinHyperVAdminsSid) if err != nil { return false } @@ -27,7 +28,7 @@ func HasHyperVAdminRights() bool { // token of the calling thread. If the thread is not impersonating, // the function duplicates the thread's primary token to create an // impersonation token." - token := windows.Token(0) + token := syswindows.Token(0) member, err := token.IsMember(sid) if err != nil { logrus.Warnf("Token Membership Error: %s", err) @@ -36,3 +37,8 @@ func HasHyperVAdminRights() bool { return member } + +// HasHyperVPermissions checks if the user has either admin rights or Hyper-V admin rights. +func HasHyperVPermissions() bool { + return windows.HasAdminRights() || HasHyperVAdminRights() +} diff --git a/pkg/machine/provider/platform_test.go b/pkg/machine/provider/platform_test.go index 20cf1a92cf2..bbdb5a59d11 100644 --- a/pkg/machine/provider/platform_test.go +++ b/pkg/machine/provider/platform_test.go @@ -13,7 +13,7 @@ func TestSupportedProviders(t *testing.T) { case "darwin": assert.Equal(t, []define.VMType{define.AppleHvVirt, define.LibKrun}, SupportedProviders()) case "windows": - assert.Equal(t, []define.VMType{define.WSLVirt, define.HyperVVirt}, SupportedProviders()) + assert.ElementsMatch(t, []define.VMType{define.WSLVirt, define.HyperVVirt}, SupportedProviders()) case "linux": assert.Equal(t, []define.VMType{define.QemuVirt}, SupportedProviders()) } @@ -28,7 +28,7 @@ func TestInstalledProviders(t *testing.T) { case "windows": provider, err := Get() assert.NoError(t, err) - assert.Contains(t, installed, provider) + assert.Contains(t, installed, provider.VMType()) case "linux": assert.Equal(t, []define.VMType{define.QemuVirt}, installed) } diff --git a/pkg/machine/provider/platform_windows.go b/pkg/machine/provider/platform_windows.go index d2e6522d15a..2068c54467e 100644 --- a/pkg/machine/provider/platform_windows.go +++ b/pkg/machine/provider/platform_windows.go @@ -15,6 +15,11 @@ import ( "go.podman.io/common/pkg/config" ) +// Variable to hold permission check function for testing purposes. +var ( + hasHyperVPermissionsFunc = hyperv.HasHyperVPermissions +) + func Get() (vmconfigs.VMProvider, error) { cfg, err := config.Default() if err != nil { @@ -39,6 +44,12 @@ func GetByVMType(resolvedVMType define.VMType) (vmconfigs.VMProvider, error) { case define.WSLVirt: return new(wsl.WSLStubber), nil case define.HyperVVirt: + // Check permissions before returning the Hyper-V provider. + // Working with Hyper-V requires users to be at least members of the Hyper-V admin group. + // Init and remove actions have custom use cases and they are checked on the stubber. + if !hasHyperVPermissionsFunc() { + return nil, hyperv.ErrHypervUserNotInAdminGroup + } return new(hyperv.HyperVStubber), nil default: } @@ -46,10 +57,13 @@ func GetByVMType(resolvedVMType define.VMType) (vmconfigs.VMProvider, error) { } func GetAll() []vmconfigs.VMProvider { - return []vmconfigs.VMProvider{ + providers := []vmconfigs.VMProvider{ new(wsl.WSLStubber), - new(hyperv.HyperVStubber), } + if hasHyperVPermissionsFunc() { + providers = append(providers, new(hyperv.HyperVStubber)) + } + return providers } // SupportedProviders returns the providers that are supported on the host operating system diff --git a/pkg/machine/provider/platform_windows_test.go b/pkg/machine/provider/platform_windows_test.go new file mode 100644 index 00000000000..a66a0af0025 --- /dev/null +++ b/pkg/machine/provider/platform_windows_test.go @@ -0,0 +1,113 @@ +//go:build windows + +package provider + +import ( + "testing" + + "github.com/containers/podman/v6/pkg/machine/define" + "github.com/containers/podman/v6/pkg/machine/hyperv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// helper to setup mocks and ensure cleanup +func mockPermissions(t *testing.T, hasHyperVPermissions bool) { + origHyperVPermissionsFunc := hasHyperVPermissionsFunc + t.Cleanup(func() { + hasHyperVPermissionsFunc = origHyperVPermissionsFunc + }) + + hasHyperVPermissionsFunc = func() bool { return hasHyperVPermissions } +} + +func TestGetByVMType_HyperV(t *testing.T) { + tests := []struct { + name string + hasHyperVPermissions bool + expectError bool + }{ + { + name: "WithHyperVPermissions", + hasHyperVPermissions: true, + expectError: false, + }, + { + name: "WithoutPermissions", + hasHyperVPermissions: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPermissions(t, tt.hasHyperVPermissions) + + provider, err := GetByVMType(define.HyperVVirt) + + if tt.expectError { + assert.Error(t, err) + assert.Equal(t, err.Error(), hyperv.ErrHypervUserNotInAdminGroup.Error()) + assert.Nil(t, provider) + } else { + require.NoError(t, err) + assert.NotNil(t, provider) + assert.Equal(t, define.HyperVVirt, provider.VMType()) + } + }) + } +} + +func TestGetAll_HyperV_Inclusion(t *testing.T) { + tests := []struct { + name string + hasHyperVPermissions bool + expectHyperV bool + }{ + {"WithHyperVPermissions", true, true}, + {"WithoutHyperVPermissions", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPermissions(t, tt.hasHyperVPermissions) + + providers := GetAll() + + // Check for HyperV presence + hasHyperV := false + for _, p := range providers { + if p.VMType() == define.HyperVVirt { + hasHyperV = true + break + } + } + + assert.Equal(t, tt.expectHyperV, hasHyperV, "Hyper-V provider presence mismatch") + + // WSL should always be present in these scenarios + hasWSL := false + for _, p := range providers { + if p.VMType() == define.WSLVirt { + hasWSL = true + break + } + } + assert.True(t, hasWSL, "GetAll should always include WSL provider") + }) + } +} + +func TestGetByVMType_WSL_AlwaysWorks(t *testing.T) { + provider, err := GetByVMType(define.WSLVirt) + require.NoError(t, err) + assert.NotNil(t, provider) + assert.Equal(t, define.WSLVirt, provider.VMType()) +} + +func TestGetByVMType_UnsupportedProvider(t *testing.T) { + provider, err := GetByVMType(define.QemuVirt) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported virtualization provider") + assert.Nil(t, provider) +} diff --git a/winmake.ps1 b/winmake.ps1 index 7e75f37b5ea..af03699cf69 100644 --- a/winmake.ps1 +++ b/winmake.ps1 @@ -78,7 +78,7 @@ function Local-Unit { Build-Ginkgo $skippackages = 'hack,internal\domain\infra\abi,internal\domain\infra\tunnel,libpod\lock\shm,pkg\api\handlers\libpod,pkg\api\handlers\utils,pkg\bindings,' $skippackages += 'pkg\domain\infra\abi,pkg\emulation,pkg\machine\apple,pkg\machine\applehv,pkg\machine\e2e,pkg\machine\libkrun,' - $skippackages += 'pkg\machine\provider,pkg\machine\proxyenv,pkg\machine\qemu,pkg\specgen\generate,pkg\systemd,test\e2e,test\utils,cmd\rootlessport,' + $skippackages += 'pkg\machine\proxyenv,pkg\machine\qemu,pkg\specgen\generate,pkg\systemd,test\e2e,test\utils,cmd\rootlessport,' $skippackages += 'pkg\pidhandle' if ($null -eq $ENV:GINKGOTIMEOUT) { $ENV:GINKGOTIMEOUT = '--timeout=15m' } Run-Command "./bin/ginkgo.exe -vv -r --tags `"$remotetags`" ${ENV:GINKGOTIMEOUT} --trace --no-color --skip-package `"$skippackages`""