Skip to content

Commit 4681bd4

Browse files
committed
add windows platform tests
Signed-off-by: lstocchi <lstocchi@redhat.com>
1 parent b6768d3 commit 4681bd4

File tree

2 files changed

+141
-3
lines changed

2 files changed

+141
-3
lines changed

pkg/machine/provider/platform_windows.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ import (
1515
"go.podman.io/common/pkg/config"
1616
)
1717

18+
// Variables to hold permission check functions for testing purposes.
19+
var (
20+
hasAdminRightsFunc = windows.HasAdminRights
21+
hasHyperVAdminRightsFunc = hyperv.HasHyperVAdminRights
22+
)
23+
1824
func Get() (vmconfigs.VMProvider, error) {
1925
cfg, err := config.Default()
2026
if err != nil {
@@ -42,7 +48,7 @@ func GetByVMType(resolvedVMType define.VMType) (vmconfigs.VMProvider, error) {
4248
// Check permissions before returning the Hyper-V provider.
4349
// Working with Hyper-V requires users to be at least members of the Hyper-V admin group.
4450
// Init and remove actions have custom use cases and they are checked on the stubber.
45-
if !windows.HasAdminRights() && !hyperv.HasHyperVAdminRights() {
51+
if !hasAdminRightsFunc() && !hasHyperVAdminRightsFunc() {
4652
return nil, hyperv.ErrHypervUserNotInAdminGroup
4753
}
4854
return new(hyperv.HyperVStubber), nil
@@ -55,7 +61,7 @@ func GetAll() []vmconfigs.VMProvider {
5561
providers := []vmconfigs.VMProvider{
5662
new(wsl.WSLStubber),
5763
}
58-
if windows.HasAdminRights() || hyperv.HasHyperVAdminRights() {
64+
if hasAdminRightsFunc() || hasHyperVAdminRightsFunc() {
5965
providers = append(providers, new(hyperv.HyperVStubber))
6066
}
6167
return providers
@@ -92,7 +98,7 @@ func HasPermsForProvider(provider define.VMType) bool {
9298
case define.AppleHvVirt:
9399
return false
94100
case define.HyperVVirt:
95-
return windows.HasAdminRights()
101+
return hasAdminRightsFunc()
96102
}
97103

98104
return true
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
//go:build windows
2+
3+
package provider
4+
5+
import (
6+
"testing"
7+
8+
"github.com/containers/podman/v6/pkg/machine/define"
9+
"github.com/containers/podman/v6/pkg/machine/hyperv"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
// helper to setup mocks and ensure cleanup
15+
func mockPermissions(t *testing.T, admin, hyperV bool) {
16+
origAdmin, origHyperV := hasAdminRightsFunc, hasHyperVAdminRightsFunc
17+
t.Cleanup(func() {
18+
hasAdminRightsFunc = origAdmin
19+
hasHyperVAdminRightsFunc = origHyperV
20+
})
21+
22+
hasAdminRightsFunc = func() bool { return admin }
23+
hasHyperVAdminRightsFunc = func() bool { return hyperV }
24+
}
25+
26+
func TestGetByVMType_HyperV(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
isAdmin bool
30+
isHyperV bool
31+
expectError bool
32+
}{
33+
{
34+
name: "WithAdminRights",
35+
isAdmin: true,
36+
isHyperV: false,
37+
expectError: false,
38+
},
39+
{
40+
name: "WithHyperVAdminRights",
41+
isAdmin: false,
42+
isHyperV: true,
43+
expectError: false,
44+
},
45+
{
46+
name: "WithBothRights",
47+
isAdmin: true,
48+
isHyperV: true,
49+
expectError: false,
50+
},
51+
{
52+
name: "WithoutPermissions",
53+
isAdmin: false,
54+
isHyperV: false,
55+
expectError: true,
56+
},
57+
}
58+
59+
for _, tt := range tests {
60+
t.Run(tt.name, func(t *testing.T) {
61+
mockPermissions(t, tt.isAdmin, tt.isHyperV)
62+
63+
provider, err := GetByVMType(define.HyperVVirt)
64+
65+
if tt.expectError {
66+
assert.Error(t, err)
67+
assert.Equal(t, err.Error(), hyperv.ErrHypervUserNotInAdminGroup.Error())
68+
assert.Nil(t, provider)
69+
} else {
70+
require.NoError(t, err)
71+
assert.NotNil(t, provider)
72+
assert.Equal(t, define.HyperVVirt, provider.VMType())
73+
}
74+
})
75+
}
76+
}
77+
78+
func TestGetAll_HyperV_Inclusion(t *testing.T) {
79+
tests := []struct {
80+
name string
81+
isAdmin bool
82+
isHyperV bool
83+
expectHyperV bool
84+
}{
85+
{"WithAdminRights", true, false, true},
86+
{"WithHyperVRights", false, true, true},
87+
{"NoRights", false, false, false},
88+
}
89+
90+
for _, tt := range tests {
91+
t.Run(tt.name, func(t *testing.T) {
92+
mockPermissions(t, tt.isAdmin, tt.isHyperV)
93+
94+
providers := GetAll()
95+
96+
// Check for HyperV presence
97+
hasHyperV := false
98+
for _, p := range providers {
99+
if p.VMType() == define.HyperVVirt {
100+
hasHyperV = true
101+
break
102+
}
103+
}
104+
105+
assert.Equal(t, tt.expectHyperV, hasHyperV, "Hyper-V provider presence mismatch")
106+
107+
// WSL should always be present in these scenarios
108+
hasWSL := false
109+
for _, p := range providers {
110+
if p.VMType() == define.WSLVirt {
111+
hasWSL = true
112+
break
113+
}
114+
}
115+
assert.True(t, hasWSL, "GetAll should always include WSL provider")
116+
})
117+
}
118+
}
119+
120+
func TestGetByVMType_WSL_AlwaysWorks(t *testing.T) {
121+
provider, err := GetByVMType(define.WSLVirt)
122+
require.NoError(t, err)
123+
assert.NotNil(t, provider)
124+
assert.Equal(t, define.WSLVirt, provider.VMType())
125+
}
126+
127+
func TestGetByVMType_UnsupportedProvider(t *testing.T) {
128+
provider, err := GetByVMType(define.QemuVirt)
129+
assert.Error(t, err)
130+
assert.Contains(t, err.Error(), "unsupported virtualization provider")
131+
assert.Nil(t, provider)
132+
}

0 commit comments

Comments
 (0)