Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pkg/machine/hyperv/hutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ package hyperv
import (
"errors"

"github.com/containers/podman/v6/pkg/machine/windows"
"github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
syswindows "golang.org/x/sys/windows"
)

var (
Expand All @@ -17,7 +18,7 @@ var (
)

func HasHyperVAdminRights() bool {
sid, err := windows.CreateWellKnownSid(windows.WinBuiltinHyperVAdminsSid)
sid, err := syswindows.CreateWellKnownSid(syswindows.WinBuiltinHyperVAdminsSid)
if err != nil {
return false
}
Expand All @@ -27,7 +28,7 @@ func HasHyperVAdminRights() bool {
// token of the calling thread. If the thread is not impersonating,
// the function duplicates the thread's primary token to create an
// impersonation token."
token := windows.Token(0)
token := syswindows.Token(0)
member, err := token.IsMember(sid)
if err != nil {
logrus.Warnf("Token Membership Error: %s", err)
Expand All @@ -36,3 +37,8 @@ func HasHyperVAdminRights() bool {

return member
}

// HasHyperVPermissions checks if the user has either admin rights or Hyper-V admin rights.
func HasHyperVPermissions() bool {
return windows.HasAdminRights() || HasHyperVAdminRights()
}
4 changes: 2 additions & 2 deletions pkg/machine/provider/platform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestSupportedProviders(t *testing.T) {
case "darwin":
assert.Equal(t, []define.VMType{define.AppleHvVirt, define.LibKrun}, SupportedProviders())
case "windows":
assert.Equal(t, []define.VMType{define.WSLVirt, define.HyperVVirt}, SupportedProviders())
assert.ElementsMatch(t, []define.VMType{define.WSLVirt, define.HyperVVirt}, SupportedProviders())
case "linux":
assert.Equal(t, []define.VMType{define.QemuVirt}, SupportedProviders())
}
Expand All @@ -28,7 +28,7 @@ func TestInstalledProviders(t *testing.T) {
case "windows":
provider, err := Get()
assert.NoError(t, err)
assert.Contains(t, installed, provider)
assert.Contains(t, installed, provider.VMType())
case "linux":
assert.Equal(t, []define.VMType{define.QemuVirt}, installed)
}
Expand Down
18 changes: 16 additions & 2 deletions pkg/machine/provider/platform_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ import (
"go.podman.io/common/pkg/config"
)

// Variable to hold permission check function for testing purposes.
var (
hasHyperVPermissionsFunc = hyperv.HasHyperVPermissions
)

func Get() (vmconfigs.VMProvider, error) {
cfg, err := config.Default()
if err != nil {
Expand All @@ -39,17 +44,26 @@ func GetByVMType(resolvedVMType define.VMType) (vmconfigs.VMProvider, error) {
case define.WSLVirt:
return new(wsl.WSLStubber), nil
case define.HyperVVirt:
// Check permissions before returning the Hyper-V provider.
// Working with Hyper-V requires users to be at least members of the Hyper-V admin group.
// Init and remove actions have custom use cases and they are checked on the stubber.
if !hasHyperVPermissionsFunc() {
return nil, hyperv.ErrHypervUserNotInAdminGroup
}
return new(hyperv.HyperVStubber), nil
default:
}
return nil, fmt.Errorf("unsupported virtualization provider: `%s`", resolvedVMType.String())
}

func GetAll() []vmconfigs.VMProvider {
return []vmconfigs.VMProvider{
providers := []vmconfigs.VMProvider{
new(wsl.WSLStubber),
new(hyperv.HyperVStubber),
}
if hasHyperVPermissionsFunc() {
providers = append(providers, new(hyperv.HyperVStubber))
}
return providers
}

// SupportedProviders returns the providers that are supported on the host operating system
Expand Down
113 changes: 113 additions & 0 deletions pkg/machine/provider/platform_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//go:build windows

package provider

import (
"testing"

"github.com/containers/podman/v6/pkg/machine/define"
"github.com/containers/podman/v6/pkg/machine/hyperv"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// helper to setup mocks and ensure cleanup
func mockPermissions(t *testing.T, hasHyperVPermissions bool) {
origHyperVPermissionsFunc := hasHyperVPermissionsFunc
t.Cleanup(func() {
hasHyperVPermissionsFunc = origHyperVPermissionsFunc
})

hasHyperVPermissionsFunc = func() bool { return hasHyperVPermissions }
}

func TestGetByVMType_HyperV(t *testing.T) {
tests := []struct {
name string
hasHyperVPermissions bool
expectError bool
}{
{
name: "WithHyperVPermissions",
hasHyperVPermissions: true,
expectError: false,
},
{
name: "WithoutPermissions",
hasHyperVPermissions: false,
expectError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockPermissions(t, tt.hasHyperVPermissions)

provider, err := GetByVMType(define.HyperVVirt)

if tt.expectError {
assert.Error(t, err)
assert.Equal(t, err.Error(), hyperv.ErrHypervUserNotInAdminGroup.Error())
assert.Nil(t, provider)
} else {
require.NoError(t, err)
assert.NotNil(t, provider)
assert.Equal(t, define.HyperVVirt, provider.VMType())
}
})
}
}

func TestGetAll_HyperV_Inclusion(t *testing.T) {
tests := []struct {
name string
hasHyperVPermissions bool
expectHyperV bool
}{
{"WithHyperVPermissions", true, true},
{"WithoutHyperVPermissions", false, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockPermissions(t, tt.hasHyperVPermissions)

providers := GetAll()

// Check for HyperV presence
hasHyperV := false
for _, p := range providers {
if p.VMType() == define.HyperVVirt {
hasHyperV = true
break
}
}

assert.Equal(t, tt.expectHyperV, hasHyperV, "Hyper-V provider presence mismatch")

// WSL should always be present in these scenarios
hasWSL := false
for _, p := range providers {
if p.VMType() == define.WSLVirt {
hasWSL = true
break
}
}
assert.True(t, hasWSL, "GetAll should always include WSL provider")
})
}
}

func TestGetByVMType_WSL_AlwaysWorks(t *testing.T) {
provider, err := GetByVMType(define.WSLVirt)
require.NoError(t, err)
assert.NotNil(t, provider)
assert.Equal(t, define.WSLVirt, provider.VMType())
}

func TestGetByVMType_UnsupportedProvider(t *testing.T) {
provider, err := GetByVMType(define.QemuVirt)
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported virtualization provider")
assert.Nil(t, provider)
}
2 changes: 1 addition & 1 deletion winmake.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function Local-Unit {
Build-Ginkgo
$skippackages = 'hack,internal\domain\infra\abi,internal\domain\infra\tunnel,libpod\lock\shm,pkg\api\handlers\libpod,pkg\api\handlers\utils,pkg\bindings,'
$skippackages += 'pkg\domain\infra\abi,pkg\emulation,pkg\machine\apple,pkg\machine\applehv,pkg\machine\e2e,pkg\machine\libkrun,'
$skippackages += 'pkg\machine\provider,pkg\machine\proxyenv,pkg\machine\qemu,pkg\specgen\generate,pkg\systemd,test\e2e,test\utils,cmd\rootlessport,'
$skippackages += 'pkg\machine\proxyenv,pkg\machine\qemu,pkg\specgen\generate,pkg\systemd,test\e2e,test\utils,cmd\rootlessport,'
$skippackages += 'pkg\pidhandle'
if ($null -eq $ENV:GINKGOTIMEOUT) { $ENV:GINKGOTIMEOUT = '--timeout=15m' }
Run-Command "./bin/ginkgo.exe -vv -r --tags `"$remotetags`" ${ENV:GINKGOTIMEOUT} --trace --no-color --skip-package `"$skippackages`""
Expand Down