diff --git a/aks-flex-node-sudoers b/aks-flex-node-sudoers index f830e4d..2a27817 100644 --- a/aks-flex-node-sudoers +++ b/aks-flex-node-sudoers @@ -38,6 +38,11 @@ aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl status node-problem-det aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl check kubelet aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl check containerd aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl check node-problem-detector +aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl enable openvpn@* +aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl start openvpn@* +aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl stop openvpn@* +aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl restart openvpn@* +aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl status openvpn@* aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl is-active * aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl is-enabled * aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/systemctl list-unit-files * @@ -82,10 +87,11 @@ aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/bin/apt -y remove * aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/bin/dpkg -i * aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/bin/dpkg --purge * aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/bin/lsof * +aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/bin/apt install -y openvpn # Directory and file operations for Kubernetes paths - simplified for compatibility aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/mkdir *, /usr/bin/mkdir * -aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/mkdir -p /etc/kubernetes/*, /bin/mkdir -p /var/lib/kubelet/*, /bin/mkdir -p /var/lib/cni/*, /bin/mkdir -p /etc/containerd/*, /bin/mkdir -p /opt/cni/bin, /bin/mkdir -p /etc/cni/net.d, /bin/mkdir -p /etc/systemd/system/kubelet.service.d, /bin/mkdir -p /etc/sysctl.d, /bin/mkdir -p /etc/modules-load.d +aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/mkdir -p /etc/kubernetes/*, /bin/mkdir 
-p /var/lib/kubelet/*, /bin/mkdir -p /var/lib/cni/*, /bin/mkdir -p /etc/containerd/*, /bin/mkdir -p /opt/cni/bin, /bin/mkdir -p /etc/cni/net.d, /bin/mkdir -p /etc/systemd/system/kubelet.service.d, /bin/mkdir -p /etc/sysctl.d, /bin/mkdir -p /etc/modules-load.d, /bin/mkdir -p /etc/aks-flex-node/*, /bin/mkdir -p /etc/openvpn aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/cp *, /bin/mv *, /bin/rm * aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/chmod *, /bin/chown * aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/ln *, /usr/bin/ln * @@ -94,6 +100,7 @@ aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/bin/curl *, /usr/bin/wget * aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/tar *, /usr/bin/unzip * aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/ls *, /usr/bin/ls * aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/bin/test *, /bin/test * +aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/cat *, /usr/bin/cat * # System configuration for Kubernetes aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/sysctl --system @@ -117,9 +124,15 @@ aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/cat /etc/kubernetes/* aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/cat /var/lib/kubelet/kubeconfig -# Network operations for troubleshooting +# Network operations for troubleshooting and VPN gateway management aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/ip route aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/ip addr +aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/ip route add * +aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/ip route del * +aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/ip route delete * +aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/ip route show * +aks-flex-node ALL=(root) NOPASSWD:SETENV: /usr/sbin/iptables * +aks-flex-node ALL=(root) NOPASSWD:SETENV: /sbin/iptables * aks-flex-node ALL=(root) NOPASSWD:SETENV: /bin/netstat -rn # Read-only Kubernetes API check for node readiness (used by status collector) diff --git a/assets/img/README/image.png b/assets/img/README/image.png new 
file mode 100644 index 0000000..13bea37 Binary files /dev/null and b/assets/img/README/image.png differ diff --git a/go.mod b/go.mod index 6cecbc6..264017a 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,9 @@ require ( require ( github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice v1.0.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/go.sum b/go.sum index 2a8019e..e350db6 100644 --- a/go.sum +++ b/go.sum @@ -8,12 +8,18 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDo github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2 h1:qiir/pptnHqp6hV8QwV+IExYIf6cPsXBfUDUXQ27t2Y= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2/go.mod h1:jVRrRDLCOuif95HDYC23ADTMlvahB7tMdl519m9Iyjc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice v1.0.0 h1:figxyQZXzZQIcP3njhC68bYUiTw45J8/SsHaLW8Ax0M= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice v1.0.0/go.mod h1:TmlMW4W5OvXOmOyKNnor8nlMMiO1ctIyzmHme/VHsrA= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0 h1:5n7dPVqsWfVKw+ZiEKSd3Kzu7gwBkbEBkeXb8rgaE9Q= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0/go.mod 
h1:HcZY0PHPo/7d75p99lB6lK0qYOP4vLRJUBpiehYXtLQ= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcompute/armhybridcompute v1.2.0 h1:7UuAn4ljE+H3GQ7qts3c7oAaMRvge68EgyckoNP/1Ro= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/hybridcompute/armhybridcompute v1.2.0/go.mod h1:F2eDq/BGK2LOEoDtoHbBOphaPqcjT0K/Y5Am8vf7+0w= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0 h1:QM6sE5k2ZT/vI5BEe0r7mqjsUSnhVBFbOsVkEuaEfiA= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0/go.mod h1:243D9iHbcQXoFUtgHJwL7gl2zx1aDuDMjvBZVGr2uW0= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0 h1:bXwSugBiSbgtz7rOtbfGf+woewp4f06orW9OP5BjHLA= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0/go.mod h1:Y/HgrePTmGy9HjdSGTqZNa+apUpTVIEVKXJyARP2lrk= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.1.1 h1:7CBQ+Ei8SP2c6ydQTGCCrS35bDxgTMfoP2miAwK++OU= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.1.1/go.mod h1:c/wcGeGx5FUPbM/JltUYHZcKmigwyVLJlDq+4HdtXaw= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= diff --git a/pkg/bootstrapper/bootstrapper.go b/pkg/bootstrapper/bootstrapper.go index d719a4c..776a67a 100644 --- a/pkg/bootstrapper/bootstrapper.go +++ b/pkg/bootstrapper/bootstrapper.go @@ -14,6 +14,7 @@ import ( "go.goms.io/aks/AKSFlexNode/pkg/components/runc" "go.goms.io/aks/AKSFlexNode/pkg/components/services" "go.goms.io/aks/AKSFlexNode/pkg/components/system_configuration" + "go.goms.io/aks/AKSFlexNode/pkg/components/vpn_gateway" "go.goms.io/aks/AKSFlexNode/pkg/config" ) @@ -34,6 +35,7 @@ func (b 
*Bootstrapper) Bootstrap(ctx context.Context) (*ExecutionResult, error) // Define the bootstrap steps in order - using modules directly steps := []Executor{ arc.NewInstaller(b.logger), // Setup Arc + vpn_gateway.NewInstaller(b.logger), // Setup VPN Gateway (if enabled) services.NewUnInstaller(b.logger), // Stop kubelet before setup system_configuration.NewInstaller(b.logger), // Configure system (early) runc.NewInstaller(b.logger), // Install runc @@ -59,6 +61,7 @@ func (b *Bootstrapper) Unbootstrap(ctx context.Context) (*ExecutionResult, error containerd.NewUnInstaller(b.logger), // Uninstall containerd binary runc.NewUnInstaller(b.logger), // Uninstall runc binary system_configuration.NewUnInstaller(b.logger), // Clean system settings + vpn_gateway.NewUnInstaller(b.logger), // Clean VPN Gateway arc.NewUnInstaller(b.logger), // Uninstall Arc (after cleanup) } diff --git a/pkg/components/cni/cni_setup_installer.go b/pkg/components/cni/cni_setup_installer.go index e7445c5..8429f29 100644 --- a/pkg/components/cni/cni_setup_installer.go +++ b/pkg/components/cni/cni_setup_installer.go @@ -89,13 +89,6 @@ func (i *Installer) IsCompleted(ctx context.Context) bool { } } - // Validate Step 3: Bridge configuration - configPath := filepath.Join(DefaultCNIConfDir, bridgeConfigFile) - if !utils.FileExistsAndValid(configPath) { - i.logger.Debug("Bridge configuration file not found") - return false - } - i.logger.Debug("CNI setup validation passed - all components properly configured") return true } diff --git a/pkg/components/kubelet/kubelet_installer.go b/pkg/components/kubelet/kubelet_installer.go index 3239a9d..00febbe 100644 --- a/pkg/components/kubelet/kubelet_installer.go +++ b/pkg/components/kubelet/kubelet_installer.go @@ -208,42 +208,58 @@ func (i *Installer) createKubeletDefaultsFile() error { labels = append(labels, fmt.Sprintf("%s=%s", key, value)) } + // Build kubelet flags dynamically + kubeletFlags := []string{ + fmt.Sprintf("--v=%d", 
i.config.Node.Kubelet.Verbosity), + "--address=0.0.0.0", + "--anonymous-auth=false", + "--authentication-token-webhook=true", + "--authorization-mode=Webhook", + "--cgroup-driver=systemd", + "--cgroups-per-qos=true", + "--enforce-node-allocatable=pods", + fmt.Sprintf("--cluster-dns=%s", i.config.Node.Kubelet.DNSServiceIP), + "--cluster-domain=cluster.local", + "--event-qps=0", + fmt.Sprintf("--eviction-hard=%s", mapToEvictionThresholds(i.config.Node.Kubelet.EvictionHard, ",")), + fmt.Sprintf("--kube-reserved=%s", mapToKeyValuePairs(i.config.Node.Kubelet.KubeReserved, ",")), + fmt.Sprintf("--image-gc-high-threshold=%d", i.config.Node.Kubelet.ImageGCHighThreshold), + fmt.Sprintf("--image-gc-low-threshold=%d", i.config.Node.Kubelet.ImageGCLowThreshold), + fmt.Sprintf("--max-pods=%d", i.config.Node.MaxPods), + "--node-status-update-frequency=10s", + fmt.Sprintf("--pod-infra-container-image=%s", i.config.Containerd.PauseImage), + "--pod-max-pids=-1", + "--protect-kernel-defaults=true", + "--read-only-port=0", + "--resolv-conf=/run/systemd/resolve/resolv.conf", + "--streaming-connection-idle-timeout=4h", + "--tls-cipher-suites=TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,TLS_RSA_WITH_AES_256_GCM_SHA384,TLS_RSA_WITH_AES_128_GCM_SHA256", + } + + // Add VPN node IP if VPN gateway is enabled and connected + if vpnIP := i.getVPNInterfaceIP(); vpnIP != "" { + kubeletFlags = append(kubeletFlags, fmt.Sprintf("--node-ip=%s", vpnIP)) + i.logger.Infof("Configuring kubelet to use VPN interface IP: %s", vpnIP) + } + + // Format flags with proper line continuation + flagsFormatted := make([]string, len(kubeletFlags)) + for i, flag := range kubeletFlags { + flagsFormatted[i] = fmt.Sprintf(" %s \\", flag) + } + // Remove trailing backslash from last flag + if len(flagsFormatted) > 0 { + lastFlag := 
flagsFormatted[len(flagsFormatted)-1] + flagsFormatted[len(flagsFormatted)-1] = strings.TrimSuffix(lastFlag, " \\") + } + kubeletDefaults := fmt.Sprintf(`KUBELET_NODE_LABELS="%s" KUBELET_CONFIG_FILE_FLAGS="" KUBELET_FLAGS="\ - --v=%d \ - --address=0.0.0.0 \ - --anonymous-auth=false \ - --authentication-token-webhook=true \ - --authorization-mode=Webhook \ - --cgroup-driver=systemd \ - --cgroups-per-qos=true \ - --enforce-node-allocatable=pods \ - --cluster-dns=%s \ - --cluster-domain=cluster.local \ - --event-qps=0 \ - --eviction-hard=%s \ - --kube-reserved=%s \ - --image-gc-high-threshold=%d \ - --image-gc-low-threshold=%d \ - --max-pods=%d \ - --node-status-update-frequency=10s \ - --pod-max-pids=-1 \ - --protect-kernel-defaults=true \ - --read-only-port=0 \ - --resolv-conf=/run/systemd/resolve/resolv.conf \ - --streaming-connection-idle-timeout=4h \ - --rotate-certificates=true \ - --tls-cipher-suites=TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,TLS_RSA_WITH_AES_256_GCM_SHA384,TLS_RSA_WITH_AES_128_GCM_SHA256 \ +%s \ "`, strings.Join(labels, ","), - i.config.Node.Kubelet.Verbosity, - i.config.Node.Kubelet.DNSServiceIP, - mapToEvictionThresholds(i.config.Node.Kubelet.EvictionHard, ","), - mapToKeyValuePairs(i.config.Node.Kubelet.KubeReserved, ","), - i.config.Node.Kubelet.ImageGCHighThreshold, - i.config.Node.Kubelet.ImageGCLowThreshold, - i.config.Node.MaxPods) + strings.Join(flagsFormatted, "\n")) // Ensure /etc/default directory exists if err := utils.RunSystemCommand("mkdir", "-p", etcDefaultDir); err != nil { @@ -709,3 +725,28 @@ func mapToEvictionThresholds(m map[string]string, separator string) string { } return strings.Join(pairs, separator) } + +// getVPNInterfaceIP returns the IP address of the VPN interface if VPN is enabled and connected +func (i *Installer) 
getVPNInterfaceIP() string { + // Check if VPN gateway is enabled in configuration + if !i.config.IsVPNGatewayEnabled() { + return "" + } + + // Get VPN interface using the generic utility function + vpnInterface, err := utils.GetVPNInterface() + if err != nil { + i.logger.Debugf("VPN interface not found: %v", err) + return "" + } + + // Get IP address of the VPN interface + ip, err := utils.GetVPNInterfaceIP(vpnInterface) + if err != nil { + i.logger.Debugf("Failed to get VPN interface IP: %v", err) + return "" + } + + i.logger.Infof("Found VPN interface %s with IP: %s", vpnInterface, ip) + return ip +} diff --git a/pkg/components/vpn_gateway/README.md b/pkg/components/vpn_gateway/README.md new file mode 100644 index 0000000..a4da174 --- /dev/null +++ b/pkg/components/vpn_gateway/README.md @@ -0,0 +1,29 @@ +# VPN Gateway Component + +This component provides VPN connectivity for AKS Flex Node using OpenVPN over Point-to-Site (P2S) connections. It's designed for scenarios where a limited number of clients need secure access to a virtual network. + +![VPN Gateway Architecture](../../../assets/img/README/image.png) + +## Overview + +The VPN Gateway component enables secure connectivity between AKS Flex Nodes and Azure Virtual Networks through: + +- **Certificate-based authentication** using self-generated root and client certificates +- **OpenVPN SSL tunnel** for encrypted communication +- **Automatic IP management** to update node IPs when VPN interface changes +- **Azure integration** for seamless VPN gateway configuration + +## Steps +1. Prepare Azure Resources +- Create a GatewaySubnet within the AKS VNet +- Deploy a Route-based Azure VPN Gateway into the GatewaySubnet + +2. Prepare Certificates +- root certificate: will be uploaded to Azure as a "trusted" cert (a Base64 encoded X.509 .cer file.) +- client certificates: generated from the root certificate and to be installed on each client computer for client authentication + +3. 
// Constants shared by the VPN Gateway component. Values that must stay in
// sync with each other are derived from a single source instead of being
// repeated as independent literals.
const (
	// defaultVPNGatewayName is the name of the Azure VPN Gateway resource.
	defaultVPNGatewayName = "vpn-gateway"

	// vpnGatewayName is kept as an alias of defaultVPNGatewayName so both
	// identifiers always refer to the same gateway.
	vpnGatewayName = defaultVPNGatewayName

	// gatewayPublicIPName is the name of the public IP attached to the gateway.
	gatewayPublicIPName = "vpn-gateway-ip"

	// Azure VPN Gateway configuration
	vpnClientRootCertName = "VPNClientRootCert"
	gatewaySubnetName     = "GatewaySubnet"
	gatewaySubnetPrefix   = 27 // /27 subnet for GatewaySubnet

	// p2sConfigName is the Point-to-Site configuration name.
	p2sConfigName = "P2SConfig"

	// Directory paths
	systemConfigDir  = "/etc/aks-flex-node"
	certificatesDir  = systemConfigDir + "/certs" // certificates live under the system config dir
	openVPNConfigDir = "/etc/openvpn"

	// OpenVPN service naming: the templated unit openvpn@<name> is paired with
	// the config file <name>.conf, so both are derived from openVPNServiceName.
	openVPNServiceName     = "vpnconfig"
	openVPNServiceTemplate = "openvpn@" + openVPNServiceName
	openVPNConfigFileName  = openVPNServiceName + ".conf"

	// File names
	vpnConfigFileName     = "vpn-config.ovpn"
	vpnClientCertFileName = "vpn-client.crt"
	vpnClientKeyFileName  = "vpn-client.key"
	vpnRootCertFileName   = "vpn-root-ca.crt"

	// File permissions
	certificatesDirPerm = 0700
	configDirPerm       = 0755
	privateKeyFilePerm  = 0600
	certificateFilePerm = 0644

	// Certificate configuration
	certificateKeySize    = 2048
	certificateValidYears = 10
	certificateCommonName = "VPN CA"

	// PEM block types
	rsaPrivateKeyType = "RSA PRIVATE KEY"
	certificateType   = "CERTIFICATE"

	// Timeouts and intervals
	gatewayProvisioningTimeout = 30 * time.Minute // VPN Gateway provisioning timeout
	gatewayStatusCheckInterval = 30 * time.Second // Polling interval for gateway status
	vpnConnectionTimeout       = 1 * time.Minute  // VPN connection establishment timeout
	vpnConnectionCheckInterval = 2 * time.Second  // Interval for VPN connection checks

	// System paths for validation
	systemEtcPrefix = "/etc/"
	systemUsrPrefix = "/usr/"
	systemVarPrefix = "/var/"

	// Temporary file patterns
	tempVPNConfigPattern = "vpnconfig-*.ovpn"
	tempVPNCertPattern   = "vpn-cert-*.tmp"
	tempVPNZipPattern    = "vpnconfig-*.zip"
	tempVPNExtractPrefix = "vpnconfig-"
)

// GetVPNClientCertPath returns the full path to the VPN client certificate file.
func GetVPNClientCertPath() string {
	return filepath.Join(certificatesDir, vpnClientCertFileName)
}

// GetVPNClientKeyPath returns the full path to the VPN client private key file.
func GetVPNClientKeyPath() string {
	return filepath.Join(certificatesDir, vpnClientKeyFileName)
}

// GetVPNRootCertPath returns the full path to the VPN root CA certificate file.
func GetVPNRootCertPath() string {
	return filepath.Join(certificatesDir, vpnRootCertFileName)
}

// GetOpenVPNConfigPath returns the full path to the OpenVPN configuration file,
// i.e. the file read by the openvpn@<openVPNServiceName> unit.
func GetOpenVPNConfigPath() string {
	return filepath.Join(openVPNConfigDir, openVPNConfigFileName)
}

// GetVPNConfigPath returns the full path to the VPN configuration file in the
// system config directory.
func GetVPNConfigPath() string {
	return filepath.Join(systemConfigDir, vpnConfigFileName)
}
"go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// Installer handles VPN Gateway installation operations +type Installer struct { + config *config.Config + logger *logrus.Logger + vnetClient *armnetwork.VirtualNetworksClient + subnetsClient *armnetwork.SubnetsClient + vgwClient *armnetwork.VirtualNetworkGatewaysClient + publicIPClient *armnetwork.PublicIPAddressesClient +} + +// NewInstaller creates a new VPN Gateway installer +func NewInstaller(logger *logrus.Logger) *Installer { + return &Installer{ + config: config.GetConfig(), + logger: logger, + } +} + +// Validate validates prerequisites for VPN Gateway installation +func (i *Installer) Validate(ctx context.Context) error { + if !i.config.IsVPNGatewayEnabled() { + i.logger.Info("VPN Gateway setup is not enabled in configuration, skipping Validate...") + return nil + } + + if i.config.Azure.VPNGateway.P2SGatewayCIDR == "" { + return fmt.Errorf("P2S Gateway CIDR is not configured") + } + + if i.config.Azure.VPNGateway.PodCIDR == "" { + return fmt.Errorf("pod CIDR is not configured - this is required for VPN network routing") + } + + if i.config.Azure.VPNGateway.VNetID == "" { + return fmt.Errorf("VNet ID for VPN Gateway is not configured") + } + + // Validate that VNet ID is a proper Azure resource ID + if err := utils.ValidateAzureResourceID(i.config.Azure.VPNGateway.VNetID, "virtualNetworks"); err != nil { + return fmt.Errorf("invalid VNet ID: %w", err) + } + + return nil +} + +// GetName returns the step name +func (i *Installer) GetName() string { + return "VPNGatewayInstaller" +} + +type vnetResourceInfo struct { + vnetID string + location string + resourceGroupName string + subscriptionID string + vnet *armnetwork.VirtualNetwork +} + +// Execute performs VPN Gateway setup as part of the bootstrap process +// This method handles the whole VPN Gateway creation and configuration flow: +// 1. VPN Gateway provisioning +// 2. Certificate generation and upload +// 3. 
VPN client configuration download +// 4. VPN connection establishment +func (i *Installer) Execute(ctx context.Context) error { + i.logger.Info("Starting VPN Gateway setup for bootstrap process") + + // Set up Azure clients + if err := i.setUpClients(ctx); err != nil { + i.logger.Errorf("Failed to set up Azure clients: %v", err) + return fmt.Errorf("vpn gateway setup failed at client setup: %w", err) + } + + // Discover the VNet used by AKS cluster nodes - it can be either BYO VNet or AKS managed VNet + // The VPN Gateway will be created in this VNet to establish connectivity between the flex node and AKS cluster nodes + vnetInfo, err := i.getNodeVNet(ctx) + if err != nil { + i.logger.Errorf("Failed to get AKS managed VNet: %v", err) + return fmt.Errorf("vpn gateway setup failed at VNet discovery: %w", err) + } + + // Provision VPN Gateway in the AKS Node VNet + _, err = i.provisionGateway(ctx, vnetInfo) + if err != nil { + i.logger.Errorf("Failed to provision VPN Gateway: %v", err) + return fmt.Errorf("vpn gateway setup failed at gateway provisioning: %w", err) + } + + // Check if VPN connection is already working before setting up certificates + if i.isVPNConnected() { + i.logger.Info("VPN connection is already established, skipping certificate setup and connection establishment") + } else { + // Setup VPN Gateway certificates (root and client) + i.logger.Info("Setting up VPN certificates") + if err := i.setupCertificates(ctx, vnetInfo); err != nil { + i.logger.Errorf("Failed to setup certificates: %v", err) + return fmt.Errorf("vpn gateway setup failed at certificate setup: %w", err) + } + i.logger.Info("VPN certificates setup completed") + + // Download VPN configuration + i.logger.Info("Downloading VPN client configuration") + configPath, err := i.downloadVPNConfig(ctx, vnetInfo) + if err != nil { + i.logger.Errorf("Failed to download VPN configuration: %v", err) + return fmt.Errorf("vpn gateway setup failed at config download: %w", err) + } + 
i.logger.Infof("VPN configuration downloaded to: %s", configPath) + + // Establish VPN connection using the downloaded configuration + i.logger.Info("Establishing VPN connection") + connected, err := i.establishVPNConnection(ctx, configPath) + if err != nil { + i.logger.Errorf("Failed to establish VPN connection: %v", err) + return fmt.Errorf("vpn gateway setup failed at connection establishment: %w", err) + } + if !connected { + return fmt.Errorf("vpn gateway setup failed: VPN connection could not be established") + } + i.logger.Info("VPN connection established successfully") + } + + // Always configure network routes and iptables rules + i.logger.Info("Configuring VPN network routing") + if err := i.configureVPNNetworking(ctx, vnetInfo); err != nil { + i.logger.Errorf("Failed to configure VPN networking: %v", err) + return fmt.Errorf("vpn gateway setup failed at network configuration: %w", err) + } + i.logger.Info("VPN networking configuration completed") + + i.logger.Info("VPN Gateway setup completed successfully") + return nil +} + +// setUpAKSClients sets up Azure Container Service clients using the target cluster subscription ID +func (i *Installer) setUpClients(ctx context.Context) error { + cred, err := auth.NewAuthProvider().UserCredential(config.GetConfig()) + if err != nil { + return fmt.Errorf("failed to get authentication credential: %w", err) + } + + vnetID := i.config.GetVPNGatewayVNetID() + if vnetID == "" { + return fmt.Errorf("failed to get VNet ID from configuration") + } + vnetSub := utils.GetSubscriptionIDFromResourceID(vnetID) + + clientFactory, err := armnetwork.NewClientFactory(vnetSub, cred, nil) + if err != nil { + return fmt.Errorf("failed to create Azure Network client factory: %w", err) + } + + i.vnetClient = clientFactory.NewVirtualNetworksClient() + i.subnetsClient = clientFactory.NewSubnetsClient() + i.vgwClient = clientFactory.NewVirtualNetworkGatewaysClient() + i.publicIPClient = clientFactory.NewPublicIPAddressesClient() + return 
nil +} + +// IsCompleted checks if VPN Gateway setup has been completed +func (i *Installer) IsCompleted(ctx context.Context) bool { + if !i.config.IsVPNGatewayEnabled() { + i.logger.Info("VPN Gateway setup is disabled in configuration, skipping installation...") + return true + } + + i.logger.Debug("Checking VPN Gateway setup completion status") + + // Check if VPN is connected + if !i.isVPNConnected() { + i.logger.Debug("VPN is not connected") + return false + } + + // // Check if network configuration is applied (automatically discovered) + // if !i.isNetworkConfigured(ctx) { + // i.logger.Debug("VPN network configuration not applied") + // return false + // } + + // // Check if VPN Gateway exists in Azure + // if gateway, err := i.getVPNGateway(ctx); err != nil || gateway == nil { + // i.logger.Debugf("VPN Gateway not found or not accessible: %v", err) + // return false + // } + + // i.logger.Debug("VPN Gateway setup appears to be completed") + return false +} + +// provisionGateway handles VPN Gateway provisioning with idempotency +func (i *Installer) provisionGateway(ctx context.Context, vnetInfo vnetResourceInfo) (*armnetwork.VirtualNetworkGateway, error) { + // Check if VPN Gateway already exists + if gateway, err := i.getVPNGateway(ctx, vnetInfo); err == nil && gateway != nil { + i.logger.Infof("VPN Gateway already exists: %s", to.String(gateway.Name)) + return gateway, nil + } + + i.logger.Infof("Provisioning VPN Gateway in VNet: %s", vnetInfo.vnetID) + + // Ensure GatewaySubnet exists + if err := i.ensureGatewaySubnet(ctx, vnetInfo); err != nil { + return nil, fmt.Errorf("failed to ensure gateway subnet: %w", err) + } + + // Create Public IP for VPN Gateway + publicIP, err := i.createPublicIPForVPNGateway(ctx, vnetInfo) + if err != nil { + return nil, fmt.Errorf("failed to create public IP: %w", err) + } + + // Create VPN Gateway in the GatewaySubnet + gateway, err := i.createVPNGateway(ctx, vnetInfo, publicIP) + if err != nil { + return nil, 
fmt.Errorf("failed to create VPN Gateway: %w", err) + } + + i.logger.Infof("Successfully provisioned VPN Gateway: %s", to.String(gateway.Name)) + return gateway, nil +} + +// createPublicIPForVPNGateway creates a public IP for the VPN Gateway +func (i *Installer) createPublicIPForVPNGateway(ctx context.Context, vnetInfo vnetResourceInfo) (string, error) { + i.logger.Infof("Ensuring public IP exists: %s", gatewayPublicIPName) + + // Prepare Public IP parameters + allocationMethod := armnetwork.IPAllocationMethodStatic + skuName := armnetwork.PublicIPAddressSKUNameStandard + skuTier := armnetwork.PublicIPAddressSKUTierRegional + + publicIPParams := armnetwork.PublicIPAddress{ + Location: &vnetInfo.location, + SKU: &armnetwork.PublicIPAddressSKU{ + Name: &skuName, + Tier: &skuTier, + }, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: &allocationMethod, + }, + Zones: []*string{ + &[]string{"1"}[0], + }, + } + + // Create Public IP - this is a long-running operation + poller, err := i.publicIPClient.BeginCreateOrUpdate(ctx, vnetInfo.resourceGroupName, gatewayPublicIPName, publicIPParams, nil) + if err != nil { + return "", fmt.Errorf("failed to start public IP creation: %w", err) + } + + i.logger.Info("Public IP creation initiated. 
Waiting for completion...") + + // Wait for completion + result, err := poller.PollUntilDone(ctx, nil) + if err != nil { + return "", fmt.Errorf("failed to create public IP: %w", err) + } + + i.logger.Infof("Successfully created public IP: %s", to.String(result.ID)) + return to.String(result.ID), nil +} + +// setupCertificates handles certificate generation and upload +func (i *Installer) setupCertificates(ctx context.Context, vnetInfo vnetResourceInfo) error { + i.logger.Info("Setting up VPN root certificates...") + certData, err := i.generateCertificates() + if err != nil { + return fmt.Errorf("failed to generate VPN certificates: %w", err) + } + + i.logger.Info("Uploading VPN root certificate to Azure VPN Gateway...") + if err := i.uploadCertificateToAzure(ctx, certData, vnetInfo); err != nil { + i.logger.Warnf("Certificate upload failed: %v", err) + return fmt.Errorf("failed to upload certificate to Azure: %w", err) + } + i.logger.Info("Certificate uploaded to Azure VPN Gateway successfully") + + return nil +} + +// downloadVPNConfig downloads and saves the VPN configuration +func (i *Installer) downloadVPNConfig(ctx context.Context, vnetInfo vnetResourceInfo) (string, error) { + i.logger.Info("Downloading VPN client configuration...") + configData, err := i.downloadVPNClientConfig(ctx, defaultVPNGatewayName, vnetInfo.resourceGroupName) + if err != nil { + return "", fmt.Errorf("failed to download VPN client configuration: %w", err) + } + + // Save configuration to file + configPath, err := i.saveVPNConfig(configData) + if err != nil { + return "", fmt.Errorf("failed to save VPN config: %w", err) + } + + return configPath, nil +} + +// establishVPNConnection establishes the VPN connection +func (i *Installer) establishVPNConnection(ctx context.Context, configPath string) (bool, error) { + i.logger.Info("Setting up OpenVPN with downloaded configuration...") + if err := i.setupOpenVPN(configPath); err != nil { + return false, fmt.Errorf("failed to setup OpenVPN: 
%w", err) + } + + i.logger.Info("Waiting for VPN connection to establish...") + if err := i.waitForVPNConnection(vpnConnectionTimeout); err != nil { + return false, fmt.Errorf("VPN connection failed to establish: %w", err) + } + + i.logger.Info("VPN connection established successfully") + return true, nil +} + +// waitForVPNConnection waits for VPN connection to be established +func (i *Installer) waitForVPNConnection(timeout time.Duration) error { + i.logger.Infof("Waiting up to %v for VPN connection...", timeout) + + start := time.Now() + for time.Since(start) < timeout { + if i.isVPNConnected() { + i.logger.Info("VPN connection established successfully") + return nil + } + + i.logger.Debug("VPN not connected yet, waiting...") + time.Sleep(vpnConnectionCheckInterval) + } + + return fmt.Errorf("VPN connection timeout after %v", timeout) +} + +// saveVPNConfig saves VPN configuration to the appropriate directory +func (i *Installer) saveVPNConfig(configData string) (string, error) { + configPath := GetVPNConfigPath() + + // Save VPN config to the persistent location atomically + if err := utils.WriteFileAtomicSystem(configPath, []byte(configData), certificateFilePerm); err != nil { + return "", fmt.Errorf("failed to save VPN config file: %w", err) + } + + i.logger.Infof("VPN configuration saved to: %s", configPath) + return configPath, nil +} + +// calculateGatewaySubnetCIDR calculates an appropriate GatewaySubnet CIDR +func (i *Installer) calculateGatewaySubnetCIDR(ctx context.Context, vnetInfo vnetResourceInfo) (string, error) { + i.logger.Infof("Calculating GatewaySubnet CIDR for VNet: %s", vnetInfo.vnetID) + + // proactive checks, should not happen + if vnetInfo.vnet.Properties == nil || + vnetInfo.vnet.Properties.AddressSpace == nil || + len(vnetInfo.vnet.Properties.AddressSpace.AddressPrefixes) == 0 { + return "", fmt.Errorf("VNet has no address prefixes") + } + + // Try each address prefix until we find one with available space + var lastErr error + for idx, 
prefix := range vnetInfo.vnet.Properties.AddressSpace.AddressPrefixes { + if prefix == nil { + continue + } + + vnetCIDR := *prefix + i.logger.Infof("Trying VNet address prefix %d: %s", idx+1, vnetCIDR) + + // Calculate an available /27 subnet for GatewaySubnet in this address prefix + gatewaySubnetCIDR, err := i.calculateAvailableSubnetInRange(vnetCIDR, vnetInfo.vnet.Properties.Subnets, gatewaySubnetPrefix) + if err != nil { + i.logger.Warnf("No available space in address prefix %s: %v", vnetCIDR, err) + lastErr = err + continue + } + + i.logger.Infof("Successfully calculated GatewaySubnet CIDR: %s in address prefix: %s", gatewaySubnetCIDR, vnetCIDR) + return gatewaySubnetCIDR, nil + } + + // If we get here, no address prefix had available space + return "", fmt.Errorf("no available space for GatewaySubnet in any VNet address prefix. Last error: %w", lastErr) +} + +// calculateAvailableSubnetInRange finds an available subnet within the VNet address space +func (i *Installer) calculateAvailableSubnetInRange(vnetCIDR string, existingSubnets []*armnetwork.Subnet, prefixLength int) (string, error) { + // Parse VNet CIDR + _, vnetNet, err := net.ParseCIDR(vnetCIDR) + if err != nil { + return "", fmt.Errorf("failed to parse VNet CIDR %s: %w", vnetCIDR, err) + } + + // Convert existing subnets to IPNet for overlap checking + var existingNets []*net.IPNet + for _, subnet := range existingSubnets { + if subnet.Properties.AddressPrefix != nil && *subnet.Properties.AddressPrefix != "" { + _, subnetNet, err := net.ParseCIDR(*subnet.Properties.AddressPrefix) + if err != nil { + i.logger.Warnf("Failed to parse existing subnet CIDR %s: %v", *subnet.Properties.AddressPrefix, err) + continue + } + existingNets = append(existingNets, subnetNet) + } + } + + // Calculate subnet size + subnetSize := uint32(1 << (32 - prefixLength)) + + // Try to find an available subnet range + vnetIP := vnetNet.IP.To4() + if vnetIP == nil { + return "", fmt.Errorf("only IPv4 networks are supported") 
+ } + + // Convert IP to uint32 for easier calculation + vnetStart := uint32(vnetIP[0])<<24 | uint32(vnetIP[1])<<16 | uint32(vnetIP[2])<<8 | uint32(vnetIP[3]) + prefixLen := i.getNetworkPrefixLength(vnetNet) + if prefixLen < 0 || prefixLen > 32 { + return "", fmt.Errorf("invalid network prefix length: %d", prefixLen) + } + vnetPrefixLength := uint32(prefixLen) + vnetMask := uint32(0xFFFFFFFF) << (32 - vnetPrefixLength) + vnetEnd := vnetStart | (^vnetMask) + + // Start from a high address in the VNet range to avoid conflicts with existing subnets + startAddress := vnetEnd - subnetSize + 1 + startAddress = startAddress &^ (subnetSize - 1) // Align to subnet boundary + + for currentAddr := startAddress; currentAddr >= vnetStart; currentAddr -= subnetSize { + // Create candidate subnet + candidateIP := net.IPv4( + byte(currentAddr>>24), + byte(currentAddr>>16), + byte(currentAddr>>8), + byte(currentAddr), + ) + + candidateNet := &net.IPNet{ + IP: candidateIP, + Mask: net.CIDRMask(prefixLength, 32), + } + + // Check if this subnet overlaps with any existing subnet + overlaps := false + for _, existing := range existingNets { + if i.subnetsOverlap(candidateNet, existing) { + overlaps = true + break + } + } + + if !overlaps { + return candidateNet.String(), nil + } + } + + return "", fmt.Errorf("no available /%d subnet found in VNet %s", prefixLength, vnetCIDR) +} + +// getNetworkPrefixLength returns the prefix length of a network +func (i *Installer) getNetworkPrefixLength(network *net.IPNet) int { + ones, _ := network.Mask.Size() + return ones +} + +// subnetsOverlap checks if two subnets overlap +func (i *Installer) subnetsOverlap(subnet1, subnet2 *net.IPNet) bool { + return subnet1.Contains(subnet2.IP) || subnet2.Contains(subnet1.IP) || + subnet1.Contains(i.getLastIP(subnet2)) || subnet2.Contains(i.getLastIP(subnet1)) +} + +// getLastIP returns the last IP address in a subnet +func (i *Installer) getLastIP(network *net.IPNet) net.IP { + ip := network.IP.To4() + if ip 
== nil { + return nil + } + + // Convert to uint32 + ipInt := uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) + + // Get network mask + ones, bits := network.Mask.Size() + mask := uint32(0xFFFFFFFF) << (bits - ones) + + // Calculate last IP + lastIPInt := ipInt | (^mask) + + return net.IPv4( + byte(lastIPInt>>24), + byte(lastIPInt>>16), + byte(lastIPInt>>8), + byte(lastIPInt), + ) +} + +// AKS nodes can be in either BYO VNet or AKS managed VNet +func (i *Installer) getNodeVNet(ctx context.Context) (vnetResourceInfo, error) { + // First try to discover BYO VNet from agent pools + vnetID := i.config.GetVPNGatewayVNetID() + // Get VNet details + vnetResp, err := i.vnetClient.Get(ctx, + utils.GetResourceGroupFromResourceID(vnetID), + utils.GetResourceNameFromResourceID(vnetID), nil) + if err != nil { + return vnetResourceInfo{}, fmt.Errorf("failed to get VNet details for VNet ID %s: %w", vnetID, err) + } + + vnet := &vnetResp.VirtualNetwork + vnetInfo := vnetResourceInfo{ + vnetID: to.String(vnet.ID), + location: to.String(vnet.Location), + resourceGroupName: utils.GetResourceGroupFromResourceID(to.String(vnet.ID)), + subscriptionID: utils.GetSubscriptionIDFromResourceID(to.String(vnet.ID)), + vnet: vnet, + } + + return vnetInfo, nil +} + +// getVPNGateway finds a VPN Gateway by name using Azure SDK +func (i *Installer) getVPNGateway(ctx context.Context, vnetInfo vnetResourceInfo) (*armnetwork.VirtualNetworkGateway, error) { + // Get the specific VPN Gateway by name + resp, err := i.vgwClient.Get(ctx, vnetInfo.resourceGroupName, defaultVPNGatewayName, nil) + if err != nil { + if strings.Contains(err.Error(), "NotFound") { + i.logger.Infof("VPN Gateway '%s' not found in resource group '%s'", defaultVPNGatewayName, vnetInfo.resourceGroupName) + return nil, errors.New("NotFound") // VPN Gateway not found + } + return nil, fmt.Errorf("failed to get VPN Gateway '%s' in resource group '%s': %w", defaultVPNGatewayName, vnetInfo.resourceGroupName, err) 
+ } + + // Verify it's a VPN Gateway (GatewayType == "Vpn") + if resp.Properties != nil && + resp.Properties.GatewayType != nil && + *resp.Properties.GatewayType == armnetwork.VirtualNetworkGatewayTypeVPN { + + i.logger.Infof("Found VPN Gateway '%s' with GatewayType 'Vpn' in resource group '%s'", defaultVPNGatewayName, vnetInfo.resourceGroupName) + return &resp.VirtualNetworkGateway, nil + } + + i.logger.Infof("Gateway '%s' found but is not a VPN Gateway (GatewayType: %v)", defaultVPNGatewayName, resp.Properties.GatewayType) + return nil, errors.New("NotFound") // Gateway exists but is not a VPN Gateway +} + +// createVPNGateway creates a VPN Gateway +func (i *Installer) createVPNGateway(ctx context.Context, vnetInfo vnetResourceInfo, publicIPID string) (*armnetwork.VirtualNetworkGateway, error) { + i.logger.Infof("Creating VPN Gateway: %s in resource group: %s", vpnGatewayName, vnetInfo.resourceGroupName) + + // Construct gateway subnet ID + gatewaySubnetID := fmt.Sprintf("%s/subnets/%s", vnetInfo.vnetID, gatewaySubnetName) + + // Prepare VPN Gateway configuration + vpnGwSKU := armnetwork.VirtualNetworkGatewaySKUNameVPNGw2AZ + vpnGwTier := armnetwork.VirtualNetworkGatewaySKUTierVPNGw2AZ + gatewayType := armnetwork.VirtualNetworkGatewayTypeVPN + vpnType := armnetwork.VPNTypeRouteBased + enableBgp := false + activeActive := false + + // IP Configuration name + ipConfigName := p2sConfigName + + // VPN Client Configuration + p2sGatewayCIDR := i.config.Azure.VPNGateway.P2SGatewayCIDR + vpnClientProtocol := armnetwork.VPNClientProtocolOpenVPN + + gatewayParams := armnetwork.VirtualNetworkGateway{ + Location: &vnetInfo.location, + Properties: &armnetwork.VirtualNetworkGatewayPropertiesFormat{ + SKU: &armnetwork.VirtualNetworkGatewaySKU{ + Name: &vpnGwSKU, + Tier: &vpnGwTier, + }, + GatewayType: &gatewayType, + VPNType: &vpnType, + EnableBgp: &enableBgp, + Active: &activeActive, + IPConfigurations: []*armnetwork.VirtualNetworkGatewayIPConfiguration{ + { + Name: 
&ipConfigName, + Properties: &armnetwork.VirtualNetworkGatewayIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.SubResource{ + ID: &publicIPID, + }, + Subnet: &armnetwork.SubResource{ + ID: &gatewaySubnetID, + }, + }, + }, + }, + VPNClientConfiguration: &armnetwork.VPNClientConfiguration{ + VPNClientAddressPool: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{&p2sGatewayCIDR}, + }, + VPNClientProtocols: []*armnetwork.VPNClientProtocol{&vpnClientProtocol}, + }, + }, + } + + // Create VPN Gateway - this is a long-running operation + poller, err := i.vgwClient.BeginCreateOrUpdate(ctx, vnetInfo.resourceGroupName, vpnGatewayName, gatewayParams, nil) + if err != nil { + return nil, fmt.Errorf("failed to start VPN Gateway creation: %w", err) + } + + i.logger.Info("VPN Gateway creation initiated. Waiting for completion (this may take 20-30 minutes)...") + + resp, err := poller.PollUntilDone(ctx, nil) + if err != nil { + return nil, fmt.Errorf("failed to create VPN Gateway: %w", err) + } + + i.logger.Infof("VPN Gateway creation completed: %s", *resp.Name) + return &resp.VirtualNetworkGateway, nil +} + +// ensureGatewaySubnet creates GatewaySubnet if it doesn't exist +func (i *Installer) ensureGatewaySubnet(ctx context.Context, vnetInfo vnetResourceInfo) error { + // Check if GatewaySubnet already exists + for _, subnet := range vnetInfo.vnet.Properties.Subnets { + if strings.EqualFold(to.String(subnet.Name), gatewaySubnetName) { + i.logger.Infof("GatewaySubnet already exists in VNet %s", vnetInfo.vnetID) + return nil + } + } + + // Calculate a CIDR for GatewaySubnet to ensure no + gatewaySubnetCIDR, err := i.calculateGatewaySubnetCIDR(ctx, vnetInfo) + if err != nil { + return fmt.Errorf("failed to calculate gateway subnet CIDR: %w", err) + } + + i.logger.Infof("Creating GatewaySubnet with CIDR: %s", gatewaySubnetCIDR) + + gatewaySubnetParams := armnetwork.Subnet{ + Properties: &armnetwork.SubnetPropertiesFormat{ + AddressPrefix: &gatewaySubnetCIDR, + }, 
+ } + + // Create the subnet - this is a long-running operation + poller, err := i.subnetsClient.BeginCreateOrUpdate(ctx, vnetInfo.resourceGroupName, to.String(vnetInfo.vnet.Name), gatewaySubnetName, gatewaySubnetParams, nil) + if err != nil { + return fmt.Errorf("failed to start GatewaySubnet creation: %w", err) + } + + i.logger.Info("GatewaySubnet creation initiated. Waiting for completion...") + + // Wait for completion + result, err := poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create GatewaySubnet: %w", err) + } + + i.logger.Infof("Successfully created GatewaySubnet: %s", *result.Name) + return nil +} diff --git a/pkg/components/vpn_gateway/vpn_gateway_uninstaller.go b/pkg/components/vpn_gateway/vpn_gateway_uninstaller.go new file mode 100644 index 0000000..6dd5dfe --- /dev/null +++ b/pkg/components/vpn_gateway/vpn_gateway_uninstaller.go @@ -0,0 +1,409 @@ +package vpn_gateway + +import ( + "context" + "fmt" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4" + "github.com/sirupsen/logrus" + + "go.goms.io/aks/AKSFlexNode/pkg/auth" + "go.goms.io/aks/AKSFlexNode/pkg/config" + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// UnInstaller handles VPN Gateway cleanup operations +type UnInstaller struct { + config *config.Config + logger *logrus.Logger + vnetClient *armnetwork.VirtualNetworksClient + subnetsClient *armnetwork.SubnetsClient + vgwClient *armnetwork.VirtualNetworkGatewaysClient + publicIPClient *armnetwork.PublicIPAddressesClient +} + +// NewUnInstaller creates a new VPN Gateway uninstaller +func NewUnInstaller(logger *logrus.Logger) *UnInstaller { + cfg := config.GetConfig() + return &UnInstaller{ + config: cfg, + logger: logger, + } +} + +// GetName returns the cleanup step name +func (u *UnInstaller) GetName() string { + return "VPNGatewayCleanup" +} + +// Execute performs VPN Gateway cleanup as part of the unbootstrap process +// This method is resilient to failures and 
continues cleanup even if some operations fail +func (u *UnInstaller) Execute(ctx context.Context) error { + u.logger.Info("Starting VPN Gateway cleanup for unbootstrap process") + + // Set up Azure clients + if err := u.setUpClients(ctx); err != nil { + u.logger.Errorf("Failed to set up Azure clients: %v", err) + return fmt.Errorf("vpn gateway setup failed at client setup: %w", err) + } + + // Step 1: Disconnect VPN connection + u.logger.Info("Step 1: Disconnecting VPN connection") + if err := u.disconnectVPN(); err != nil { + u.logger.Warnf("Failed to disconnect VPN (continuing cleanup): %v", err) + } else { + u.logger.Info("Successfully disconnected VPN connection") + } + + // Step 2: Clean up VPN networking (routes and iptables rules) + u.logger.Info("Step 2: Cleaning up VPN networking configuration") + if err := u.cleanupVPNNetworking(); err != nil { + u.logger.Warnf("Failed to cleanup VPN networking (continuing cleanup): %v", err) + } else { + u.logger.Info("Successfully cleaned up VPN networking configuration") + } + + // Step 3: Clean up VPN configuration files and certificates + u.logger.Info("Step 3: Cleaning up VPN configuration files and certificates") + if err := u.cleanupVPNFiles(); err != nil { + u.logger.Warnf("Failed to cleanup VPN files (continuing cleanup): %v", err) + } else { + u.logger.Info("Successfully cleaned up VPN configuration files and certificates") + } + + // Note: We don't delete the VPN Gateway from Azure as it's expensive to recreate + // and might be shared with other resources. The VPN Gateway will be left in Azure. 
+ u.logger.Info("VPN Gateway resources in Azure are preserved for potential reuse") + + if err := u.cleanupAzureResources(ctx); err != nil { + u.logger.Warnf("Failed to cleanup Azure VPN Gateway resources: %v", err) + } else { + u.logger.Info("Successfully cleaned up Azure VPN Gateway resources") + } + + u.logger.Info("VPN Gateway cleanup for unbootstrap completed") + return nil +} + +// setUpAKSClients sets up Azure Container Service clients using the target cluster subscription ID +func (u *UnInstaller) setUpClients(ctx context.Context) error { + cred, err := auth.NewAuthProvider().UserCredential(config.GetConfig()) + if err != nil { + return fmt.Errorf("failed to get authentication credential: %w", err) + } + + vnetID := u.config.GetVPNGatewayVNetID() + if vnetID == "" { + return fmt.Errorf("failed to get VNet ID from configuration") + } + vnetSub := utils.GetSubscriptionIDFromResourceID(vnetID) + + clientFactory, err := armnetwork.NewClientFactory(vnetSub, cred, nil) + if err != nil { + return fmt.Errorf("failed to create Azure Network client factory: %w", err) + } + + u.vnetClient = clientFactory.NewVirtualNetworksClient() + u.subnetsClient = clientFactory.NewSubnetsClient() + u.vgwClient = clientFactory.NewVirtualNetworkGatewaysClient() + u.publicIPClient = clientFactory.NewPublicIPAddressesClient() + return nil +} + +// IsCompleted checks if VPN Gateway cleanup has been completed +func (u *UnInstaller) IsCompleted(ctx context.Context) bool { + if !u.config.IsVPNGatewayEnabled() { + u.logger.Info("VPN Gateway is not enabled in configuration; skipping cleanup") + return true + } + return false +} + +// disconnectVPN stops the VPN connection and OpenVPN service +func (u *UnInstaller) disconnectVPN() error { + u.logger.Info("Disconnecting VPN connection") + + // Stop OpenVPN service + if err := utils.StopService(openVPNServiceTemplate); err != nil { + u.logger.Warnf("Failed to stop OpenVPN service: %v", err) + // Continue with other cleanup steps + } + + // Kill 
any remaining OpenVPN processes + if err := utils.RunSystemCommand("pkill", "-f", "openvpn"); err != nil { + u.logger.Warnf("Failed to kill OpenVPN processes: %v", err) + // Continue with other cleanup steps + } + + u.logger.Info("VPN disconnection completed") + return nil +} + +// cleanupVPNFiles removes VPN configuration files and certificates +func (u *UnInstaller) cleanupVPNFiles() error { + u.logger.Info("Cleaning up VPN configuration files and certificates") + + // Use the utility function for file removal + filesToCleanup := []string{ + GetVPNConfigPath(), + GetVPNClientCertPath(), + GetVPNClientKeyPath(), + GetVPNRootCertPath(), + } + + if errors := utils.RemoveFiles(filesToCleanup, u.logger); len(errors) > 0 { + for _, err := range errors { + u.logger.Warnf("File removal error: %v", err) + } + } + + // Try to remove the certificates directory using the utility function + if errors := utils.RemoveDirectories([]string{certificatesDir}, u.logger); len(errors) > 0 { + for _, err := range errors { + u.logger.Debugf("Directory removal error: %v", err) + } + } + + u.logger.Info("VPN files and certificates cleanup completed") + return nil +} + +// cleanupVPNNetworking removes IP routes and iptables rules configured for VPN +func (u *UnInstaller) cleanupVPNNetworking() error { + u.logger.Info("Cleaning up VPN networking configuration (routes and iptables rules)") + + // Get VPN interface to clean up routes + vpnInterface, err := utils.GetVPNInterface() + if err != nil { + u.logger.Infof("No VPN interface found, skipping route cleanup: %v", err) + // Continue with iptables cleanup even if no VPN interface + } else { + u.logger.Infof("Found VPN interface: %s, cleaning up routes", vpnInterface) + + // Get all routes via the VPN interface and remove them + if err := u.cleanupVPNRoutes(vpnInterface); err != nil { + u.logger.Warnf("Failed to clean up VPN routes: %v", err) + } + } + + // Clean up iptables MASQUERADE rules + if err := u.cleanupIPTablesRules(); err != nil { 
+ u.logger.Warnf("Failed to clean up iptables rules: %v", err) + } + + u.logger.Info("VPN networking cleanup completed") + return nil +} + +// cleanupVPNRoutes removes all routes that go through the VPN interface +func (u *UnInstaller) cleanupVPNRoutes(vpnInterface string) error { + u.logger.Infof("Cleaning up routes via interface: %s", vpnInterface) + + // Get current routing table + output, err := utils.RunCommandWithOutput("ip", "route", "show") + if err != nil { + return fmt.Errorf("failed to get current routes: %w", err) + } + + // Parse routes and find ones using our VPN interface + routes := strings.Split(output, "\n") + routesRemoved := 0 + + for _, route := range routes { + route = strings.TrimSpace(route) + if route == "" { + continue + } + + // Check if this route uses our VPN interface + if strings.Contains(route, "dev "+vpnInterface) { + // Extract the destination from the route (first part before whitespace) + parts := strings.Fields(route) + if len(parts) > 0 { + dest := parts[0] + u.logger.Infof("Removing route: %s", dest) + + // Remove the route + if err := utils.RunSystemCommand("ip", "route", "del", dest); err != nil { + u.logger.Warnf("Failed to remove route %s: %v", dest, err) + // Continue with other routes + } else { + routesRemoved++ + u.logger.Infof("Removed route: %s", dest) + } + } + } + } + + u.logger.Infof("Removed %d VPN routes", routesRemoved) + return nil +} + +// cleanupIPTablesRules removes iptables MASQUERADE rules added for VPN +func (u *UnInstaller) cleanupIPTablesRules() error { + u.logger.Info("Cleaning up iptables MASQUERADE rules") + + // Get current NAT table rules + output, err := utils.RunCommandWithOutput("iptables", "-t", "nat", "-L", "POSTROUTING", "-n", "--line-numbers") + if err != nil { + return fmt.Errorf("failed to list iptables rules: %w", err) + } + + // Parse output to find MASQUERADE rules + lines := strings.Split(output, "\n") + rulesRemoved := 0 + + // Process lines in reverse order to maintain line numbers 
when deleting + for i := len(lines) - 1; i >= 0; i-- { + line := strings.TrimSpace(lines[i]) + if line == "" || strings.HasPrefix(line, "Chain") || strings.HasPrefix(line, "num") { + continue + } + + // Look for MASQUERADE rules (likely involving our VPN interface or VNet CIDRs) + if strings.Contains(line, "MASQUERADE") { + // Extract the line number (first field) + parts := strings.Fields(line) + if len(parts) > 0 { + lineNum := parts[0] + + // Check if this rule involves VPN (look for tun interface references) + if strings.Contains(line, "tun") { + u.logger.Infof("Removing iptables MASQUERADE rule: %s", line) + + // Remove the rule by line number + if err := utils.RunSystemCommand("iptables", "-t", "nat", "-D", "POSTROUTING", lineNum); err != nil { + u.logger.Warnf("Failed to remove iptables rule %s: %v", lineNum, err) + // Continue with other rules + } else { + rulesRemoved++ + u.logger.Infof("Removed iptables rule: %s", line) + } + } + } + } + } + + u.logger.Infof("Removed %d iptables MASQUERADE rules", rulesRemoved) + return nil +} + +// cleanupAzureResources removes VPN Gateway resources from Azure +// This includes: VPN Gateway, Public IP, and GatewaySubnet +func (u *UnInstaller) cleanupAzureResources(ctx context.Context) error { + u.logger.Info("Cleaning up VPN Gateway resources from Azure") + + vnetID := u.config.GetVPNGatewayVNetID() + resourceGroupName := utils.GetResourceGroupFromResourceID(vnetID) + vnetName := utils.GetResourceNameFromResourceID(vnetID) + + // Step 1: Delete VPN Gateway (this must be done first as it depends on other resources) + u.logger.Infof("Deleting VPN Gateway: %s (this may take 10-20 minutes)", defaultVPNGatewayName) + if err := u.deleteVPNGateway(ctx, resourceGroupName); err != nil { + u.logger.Warnf("Failed to delete VPN Gateway: %v", err) + // Continue with other cleanup even if VPN Gateway deletion fails + } else { + u.logger.Info("VPN Gateway successfully deleted from Azure") + } + + // Step 2: Delete Public IP + 
u.logger.Infof("Deleting Public IP: %s", gatewayPublicIPName) + if err := u.deletePublicIP(ctx, resourceGroupName); err != nil { + u.logger.Warnf("Failed to delete Public IP: %v", err) + // Continue with other cleanup + } else { + u.logger.Info("Public IP successfully deleted from Azure") + } + + // Step 3: Delete GatewaySubnet (this should be done last) + u.logger.Infof("Deleting GatewaySubnet: %s", gatewaySubnetName) + if err := u.deleteGatewaySubnet(ctx, resourceGroupName, vnetName); err != nil { + u.logger.Warnf("Failed to delete GatewaySubnet: %v", err) + // Continue - this is not critical + } else { + u.logger.Info("GatewaySubnet successfully deleted from Azure") + } + + u.logger.Info("Azure VPN Gateway resources cleanup completed") + return nil +} + +// deleteVPNGateway deletes the VPN Gateway +func (u *UnInstaller) deleteVPNGateway(ctx context.Context, resourceGroupName string) error { + // Check if VPN Gateway exists before trying to delete + gateway, err := u.vgwClient.Get(ctx, resourceGroupName, defaultVPNGatewayName, nil) + if err != nil { + // If gateway doesn't exist, consider it already deleted + u.logger.Infof("VPN Gateway %s not found, may already be deleted", defaultVPNGatewayName) + return nil + } + + if gateway.Properties == nil || gateway.Properties.ProvisioningState == nil { + u.logger.Warn("VPN Gateway found but has incomplete properties") + return nil + } + + u.logger.Infof("Found VPN Gateway %s in state: %s", defaultVPNGatewayName, *gateway.Properties.ProvisioningState) + + poller, err := u.vgwClient.BeginDelete(ctx, resourceGroupName, defaultVPNGatewayName, nil) + if err != nil { + return fmt.Errorf("failed to start VPN Gateway deletion: %w", err) + } + + // Wait for deletion to complete + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("VPN Gateway deletion failed: %w", err) + } + + return nil +} + +// deletePublicIP deletes the Public IP used by VPN Gateway +func (u *UnInstaller) deletePublicIP(ctx 
context.Context, resourceGroupName string) error { + // Check if Public IP exists before trying to delete + _, err := u.publicIPClient.Get(ctx, resourceGroupName, gatewayPublicIPName, nil) + if err != nil { + // If Public IP doesn't exist, consider it already deleted + u.logger.Infof("Public IP %s not found, may already be deleted", gatewayPublicIPName) + return nil + } + + poller, err := u.publicIPClient.BeginDelete(ctx, resourceGroupName, gatewayPublicIPName, nil) + if err != nil { + return fmt.Errorf("failed to start Public IP deletion: %w", err) + } + + // Wait for deletion to complete + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("public IP deletion failed: %w", err) + } + + return nil +} + +// deleteGatewaySubnet deletes the GatewaySubnet +func (u *UnInstaller) deleteGatewaySubnet(ctx context.Context, resourceGroupName, vnetName string) error { + // Check if GatewaySubnet exists before trying to delete + _, err := u.subnetsClient.Get(ctx, resourceGroupName, vnetName, gatewaySubnetName, nil) + if err != nil { + // If subnet doesn't exist, consider it already deleted + u.logger.Infof("GatewaySubnet %s not found, may already be deleted", gatewaySubnetName) + return nil + } + + poller, err := u.subnetsClient.BeginDelete(ctx, resourceGroupName, vnetName, gatewaySubnetName, nil) + if err != nil { + return fmt.Errorf("failed to start GatewaySubnet deletion: %w", err) + } + + // Wait for deletion to complete + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("GatewaySubnet deletion failed: %w", err) + } + + return nil +} diff --git a/pkg/components/vpn_gateway/vpn_operations.go b/pkg/components/vpn_gateway/vpn_operations.go new file mode 100644 index 0000000..9f0cd03 --- /dev/null +++ b/pkg/components/vpn_gateway/vpn_operations.go @@ -0,0 +1,794 @@ +package vpn_gateway + +import ( + "archive/zip" + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + 
"encoding/base64" + "encoding/pem" + "fmt" + "io" + "math/big" + "net" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4" + "go.goms.io/aks/AKSFlexNode/pkg/utils" +) + +// VPN certificate and connection management functions for the Installer + +// generateCertificates generates VPN certificates for P2S connection +func (i *Installer) generateCertificates() (string, error) { + _, err := i.setupCertificateDirectory() + if err != nil { + return "", err + } + + // Check if certificates already exist + if certBase64, exists := i.loadExistingCertificate(); exists { + return certBase64, nil + } + + i.logger.Info("Generating new VPN certificates...") + certBase64, err := i.generateNewCertificate() + if err != nil { + return "", err + } + return certBase64, nil // new certificate generated +} + +// setupCertificateDirectory creates and configures the certificate directory +func (i *Installer) setupCertificateDirectory() (string, error) { + certDir := certificatesDir + if err := utils.RunSystemCommand("mkdir", "-p", certDir); err != nil { + return "", fmt.Errorf("failed to create certificates directory: %w", err) + } + + // Set proper permissions on the directory + if err := utils.RunSystemCommand("chmod", fmt.Sprintf("%o", certificatesDirPerm), certDir); err != nil { + return "", fmt.Errorf("failed to set permissions on certificates directory: %w", err) + } + + i.logger.Infof("Using system certificate directory: %s", certDir) + + return certDir, nil +} + +// loadExistingCertificate checks for existing certificates and returns root certificate base64 data if found +func (i *Installer) loadExistingCertificate() (string, bool) { + certPath := GetVPNClientCertPath() + keyPath := GetVPNClientKeyPath() + rootCertPath := GetVPNRootCertPath() + + // Check if all required files exist + if _, err := os.Stat(certPath); err == nil { + if _, err := os.Stat(keyPath); err == nil { + if _, err := 
os.Stat(rootCertPath); err == nil { + i.logger.Info("VPN certificates already exist, using existing certificates") + + // Read existing root certificate and return base64 data for Azure comparison + rootCertData, err := os.ReadFile(rootCertPath) + if err != nil { + i.logger.Warnf("Failed to read existing root certificate: %v", err) + return "", false + } + + // Parse certificate to get DER data for base64 encoding + block, _ := pem.Decode(rootCertData) + if block == nil { + i.logger.Warnf("Failed to parse existing root certificate PEM") + return "", false + } + + rootCertBase64 := base64.StdEncoding.EncodeToString(block.Bytes) + i.logger.Info("Using existing VPN certificates") + return rootCertBase64, true + } + } + } + + return "", false +} + +// generateNewCertificate creates a proper CA hierarchy with root CA and client certificate +func (i *Installer) generateNewCertificate() (string, error) { + // Generate root CA private key + rootPrivateKey, err := rsa.GenerateKey(rand.Reader, certificateKeySize) + if err != nil { + return "", fmt.Errorf("failed to generate root CA private key: %w", err) + } + + // Create root CA certificate + rootCertDER, err := i.createRootCACertificate(rootPrivateKey) + if err != nil { + return "", fmt.Errorf("failed to create root CA certificate: %w", err) + } + + // Generate client private key + clientPrivateKey, err := rsa.GenerateKey(rand.Reader, certificateKeySize) + if err != nil { + return "", fmt.Errorf("failed to generate client private key: %w", err) + } + + // Create client certificate signed by root CA + clientCertDER, err := i.createClientCertificate(clientPrivateKey, rootPrivateKey, rootCertDER) + if err != nil { + return "", fmt.Errorf("failed to create client certificate: %w", err) + } + + // Save certificate and key to files + clientKeyPath := GetVPNClientKeyPath() + clientCertPath := GetVPNClientCertPath() + rootCertPath := GetVPNRootCertPath() + + // Save the client private key + if err := 
i.savePrivateKey(clientKeyPath, clientPrivateKey); err != nil { + return "", err + } + + // Save the client certificate + if err := i.saveCertificate(clientCertPath, clientCertDER); err != nil { + return "", err + } + + // Save the root CA certificate (for Azure upload) + if err := i.saveCertificate(rootCertPath, rootCertDER); err != nil { + return "", err + } + + // Return base64-encoded root certificate for upload to Azure + rootCertBase64 := base64.StdEncoding.EncodeToString(rootCertDER) + + i.logger.Info("VPN certificate hierarchy generated successfully (root CA + client cert)") + + return rootCertBase64, nil +} + +// createRootCACertificate generates a root CA certificate +func (i *Installer) createRootCACertificate(privateKey *rsa.PrivateKey) ([]byte, error) { + // Generate SubjectKeyIdentifier for the root CA + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal public key: %w", err) + } + subjectKeyID := sha256.Sum256(publicKeyBytes) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: certificateCommonName, + }, + NotBefore: time.Now().Add(-10 * time.Minute), + NotAfter: time.Now().Add(certificateValidYears * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, // This is a CA certificate + SubjectKeyId: subjectKeyID[:], // Required for chain validation + } + + // Self-signed root CA: template is both the certificate to create and the issuer + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create root CA certificate: %w", err) + } + + return certDER, nil +} + +// createClientCertificate generates a client certificate signed by the root CA +func (i *Installer) createClientCertificate(clientPrivateKey, 
rootPrivateKey *rsa.PrivateKey, rootCertDER []byte) ([]byte, error) { + // Parse the root certificate to use as issuer + rootCert, err := x509.ParseCertificate(rootCertDER) + if err != nil { + return nil, fmt.Errorf("failed to parse root certificate: %w", err) + } + + // Ensure the root certificate has SubjectKeyId (should be set from createRootCACertificate) + if len(rootCert.SubjectKeyId) == 0 { + return nil, fmt.Errorf("root certificate is missing SubjectKeyId") + } + + // Generate SubjectKeyIdentifier for the client certificate + clientPublicKeyBytes, err := x509.MarshalPKIXPublicKey(&clientPrivateKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal client public key: %w", err) + } + clientSubjectKeyID := sha256.Sum256(clientPublicKeyBytes) + + template := x509.Certificate{ + SerialNumber: big.NewInt(2), // Different serial number for client cert + Subject: pkix.Name{ + CommonName: "VPN Client", + }, + NotBefore: time.Now().Add(-10 * time.Minute), + NotAfter: time.Now().Add(certificateValidYears * 365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: false, // This is NOT a CA certificate + SubjectKeyId: clientSubjectKeyID[:], // Required for chain validation + AuthorityKeyId: rootCert.SubjectKeyId, // MUST match root's SubjectKeyId exactly + } + + i.logger.Infof("Creating client certificate with AuthorityKeyId matching root SubjectKeyId: %x", rootCert.SubjectKeyId) + + // Client certificate signed by root CA + certDER, err := x509.CreateCertificate(rand.Reader, &template, rootCert, &clientPrivateKey.PublicKey, rootPrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate: %w", err) + } + + return certDER, nil +} + +// savePrivateKey saves the private key to file with proper permissions +func (i *Installer) savePrivateKey(keyPath string, privateKey 
*rsa.PrivateKey) error { + privateKeyPEM := &pem.Block{ + Type: rsaPrivateKeyType, + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + + // Encode PEM to bytes + var keyBuffer bytes.Buffer + if err := pem.Encode(&keyBuffer, privateKeyPEM); err != nil { + return fmt.Errorf("failed to encode private key: %w", err) + } + + // Write file using system-level file operations + if err := utils.WriteFileAtomicSystem(keyPath, keyBuffer.Bytes(), privateKeyFilePerm); err != nil { + return fmt.Errorf("failed to create key file: %w", err) + } + + return nil +} + +// saveCertificate saves the certificate to file with proper permissions +func (i *Installer) saveCertificate(certPath string, certDER []byte) error { + certPEM := &pem.Block{ + Type: certificateType, + Bytes: certDER, + } + + // Encode PEM to bytes + var certBuffer bytes.Buffer + if err := pem.Encode(&certBuffer, certPEM); err != nil { + return fmt.Errorf("failed to encode certificate: %w", err) + } + + // Write file using system-level file operations + if err := utils.WriteFileAtomicSystem(certPath, certBuffer.Bytes(), certificateFilePerm); err != nil { + return fmt.Errorf("failed to create certificate file: %w", err) + } + + return nil +} + +// processVPNConfig processes VPN config file and embeds certificates using sed (matching working implementation) +func (i *Installer) processVPNConfig(sourcePath, destPath string) error { + i.logger.Info("Processing VPN configuration with certificate data...") + + // Ensure certificates exist - use the same directory logic as GenerateCertificates + _, err := i.setupCertificateDirectory() + if err != nil { + return fmt.Errorf("failed to setup certificate directory: %w", err) + } + + clientCertPath := GetVPNClientCertPath() + clientKeyPath := GetVPNClientKeyPath() + + // Check if certificates exist, if not generate them + if _, err := os.Stat(clientCertPath); os.IsNotExist(err) { + i.logger.Info("Client certificates not found, generating new certificates...") + if _, err := 
i.generateCertificates(); err != nil { + return fmt.Errorf("failed to generate certificates: %w", err) + } + } + + // Create a temporary working copy of the VPN config + tempConfig, err := os.CreateTemp("", tempVPNConfigPattern) + if err != nil { + return fmt.Errorf("failed to create temporary config file: %w", err) + } + tempConfigPath := tempConfig.Name() + _ = tempConfig.Close() + defer func() { + if err := utils.RunCleanupCommand(tempConfigPath); err != nil { + i.logger.Warnf("Failed to clean up temp config file %s: %v", tempConfigPath, err) + } + }() + + // Copy source to temp file + if err := utils.RunSystemCommand("cp", sourcePath, tempConfigPath); err != nil { + return fmt.Errorf("failed to copy VPN config to temp file: %w", err) + } + + // Read certificate and key content using system commands + certContent, err := utils.RunCommandWithOutput("cat", clientCertPath) + if err != nil { + return fmt.Errorf("failed to read certificate file: %w", err) + } + + keyContent, err := utils.RunCommandWithOutput("cat", clientKeyPath) + if err != nil { + return fmt.Errorf("failed to read private key file: %w", err) + } + + // Read the config file content + configContent, err := utils.RunCommandWithOutput("cat", tempConfigPath) + if err != nil { + return fmt.Errorf("failed to read temp config file: %w", err) + } + + // Replace placeholders with actual content (trim to remove any trailing whitespace) + processedConfig := strings.ReplaceAll(configContent, "$CLIENTCERTIFICATE", strings.TrimSpace(certContent)) + processedConfig = strings.ReplaceAll(processedConfig, "$PRIVATEKEY", strings.TrimSpace(keyContent)) + + // Write processed config back to temp file using system command + if err := utils.WriteFileAtomicSystem(tempConfigPath, []byte(processedConfig), 0600); err != nil { + return fmt.Errorf("failed to write processed config: %w", err) + } + + // Copy processed config to final destination + if strings.HasPrefix(destPath, systemEtcPrefix) || strings.HasPrefix(destPath, 
systemUsrPrefix) || strings.HasPrefix(destPath, systemVarPrefix) { + // Create destination directory if it doesn't exist + destDir := filepath.Dir(destPath) + if err := utils.RunSystemCommand("mkdir", "-p", destDir); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + // Copy temp file to destination with sudo + if err := utils.RunSystemCommand("cp", tempConfigPath, destPath); err != nil { + return fmt.Errorf("failed to copy VPN config to %s: %w", destPath, err) + } + + // Set proper permissions and ownership with sudo + if err := utils.RunSystemCommand("chmod", "600", destPath); err != nil { + i.logger.Warnf("Failed to set permissions on VPN config: %v", err) + } + + if err := utils.RunSystemCommand("chown", "root:root", destPath); err != nil { + i.logger.Warnf("Failed to set ownership on VPN config: %v", err) + } + } else { + // Destination is in user directory, copy directly + if err := utils.RunSystemCommand("cp", tempConfigPath, destPath); err != nil { + return fmt.Errorf("failed to copy VPN config to %s: %w", destPath, err) + } + } + + i.logger.Info("VPN configuration processed successfully") + return nil +} + +// setupOpenVPN installs and configures OpenVPN +func (i *Installer) setupOpenVPN(configPath string) error { + i.logger.Info("Setting up OpenVPN...") + + // Check if VPN connection is already established + if i.isVPNConnected() { + i.logger.Info("VPN connection is already established, skipping OpenVPN setup") + return nil + } + + // Install OpenVPN + // Check if OpenVPN is already installed + if err := utils.RunSystemCommand("which", "openvpn"); err == nil { + i.logger.Info("OpenVPN is already installed, skipping installation") + } else { + i.logger.Info("Installing OpenVPN...") + if err := utils.RunSystemCommand("apt", "install", "-y", "openvpn"); err != nil { + return fmt.Errorf("failed to install OpenVPN: %w", err) + } + } + + // Always ensure certificates are embedded in the OpenVPN config + destPath := 
GetOpenVPNConfigPath() + + // If configPath is provided, process it; otherwise process existing config + sourceConfigPath := configPath + if sourceConfigPath == "" { + sourceConfigPath = destPath // Process existing config in place + } + + // Copy and process VPN config file + if sourceConfigPath != "" { + vpnConfigDir := openVPNConfigDir + if err := utils.RunSystemCommand("mkdir", "-p", vpnConfigDir); err != nil { + return fmt.Errorf("failed to create OpenVPN config directory: %w", err) + } + + // Process VPN config with certificate data + if err := i.processVPNConfig(sourceConfigPath, destPath); err != nil { + return fmt.Errorf("failed to process VPN config: %w", err) + } + + // Enable and restart OpenVPN service to ensure it uses the updated configuration + if err := utils.EnableService(openVPNServiceTemplate); err != nil { + return fmt.Errorf("failed to enable OpenVPN service: %w", err) + } + + i.logger.Info("Restarting OpenVPN service to apply updated configuration...") + if err := utils.RestartService(openVPNServiceTemplate); err != nil { + return fmt.Errorf("failed to restart OpenVPN service: %w", err) + } + + // Give OpenVPN a moment to start before checking status + time.Sleep(2 * time.Second) + + // Check if OpenVPN service started successfully + if !utils.IsServiceActive(openVPNServiceTemplate) { + i.logger.Warn("OpenVPN service is not active, please check the service status for details") + } else { + i.logger.Info("OpenVPN service restarted successfully") + } + } + return nil +} + +// configureVPNNetworking configures routes and iptables rules for VPN gateway connectivity +func (i *Installer) configureVPNNetworking(ctx context.Context, vnetInfo vnetResourceInfo) error { + i.logger.Info("Configuring VPN network routes and iptables rules...") + + // Get Pod CIDR from user configuration and extract all AKS VNet CIDRs from vnetInfo + podCIDR, aksVNetCIDRs, err := i.getNetworkConfiguration(vnetInfo) + if err != nil { + return fmt.Errorf("failed to get 
network configuration: %w", err) + } + + i.logger.Infof("Using AKS VNet CIDRs: %v, Pod CIDR: %s", aksVNetCIDRs, podCIDR) + + // Get VPN interface + vpnInterface, err := utils.GetVPNInterface() + if err != nil { + return fmt.Errorf("failed to get VPN interface: %w", err) + } + + i.logger.Infof("Configuring networking for VPN interface: %s", vpnInterface) + + // Add route for AKS VNet via VPN gateway + // The gateway IP is typically the first IP in the P2S CIDR range + 1 + gatewayIP, err := i.calculateGatewayIP() + if err != nil { + return fmt.Errorf("failed to calculate gateway IP: %w", err) + } + + // Add routes for all VNet CIDRs + for _, vnetCIDR := range aksVNetCIDRs { + i.logger.Infof("Adding route: %s via %s dev %s", vnetCIDR, gatewayIP, vpnInterface) + if err := i.addIPRoute(vnetCIDR, gatewayIP, vpnInterface); err != nil { + return fmt.Errorf("failed to add route for AKS VNet CIDR %s: %w", vnetCIDR, err) + } + } + + // Add route for AKS pod network (required for flex pod to aks pod communication) + // This enables Flex node pods to reach AKS cluster pods (like DNS services) + i.logger.Infof("Adding route for AKS pod network: %s via %s dev %s", podCIDR, gatewayIP, vpnInterface) + if err := i.addIPRoute(podCIDR, gatewayIP, vpnInterface); err != nil { + return fmt.Errorf("failed to add route for AKS pod CIDR %s: %w", podCIDR, err) + } + + i.logger.Info("VPN network configuration completed successfully") + return nil +} + +// getNetworkConfiguration gets Pod CIDR from user config and extracts AKS VNet CIDRs from vnetInfo +func (i *Installer) getNetworkConfiguration(vnetInfo vnetResourceInfo) (string, []string, error) { + // Get Pod CIDR from user configuration (required) + if i.config.Azure.VPNGateway == nil || i.config.Azure.VPNGateway.PodCIDR == "" { + return "", nil, fmt.Errorf("pod CIDR is required in VPN configuration when enabled, please set it") + } + podCIDR := i.config.GetVPNGatewayPodCIDR() + + // Extract all AKS VNet CIDRs from the already discovered 
VNet info + // Using all VNet CIDRs ensures we can reach all subnets including AKS nodes + aksVNetCIDRs, err := i.getVNetCIDRsFromInfo(vnetInfo) + if err != nil { + return "", nil, fmt.Errorf("failed to get AKS VNet CIDRs: %w", err) + } + + return podCIDR, aksVNetCIDRs, nil +} + +// getVNetCIDRsFromInfo extracts all VNet CIDRs from vnetInfo +func (i *Installer) getVNetCIDRsFromInfo(vnetInfo vnetResourceInfo) ([]string, error) { + // Extract all VNet CIDRs from the address space + if vnetInfo.vnet == nil || + vnetInfo.vnet.Properties == nil || + vnetInfo.vnet.Properties.AddressSpace == nil || + len(vnetInfo.vnet.Properties.AddressSpace.AddressPrefixes) == 0 { + return nil, fmt.Errorf("VNet has no address prefixes") + } + + // Extract all address prefixes as VNet CIDRs + var vnetCIDRs []string + for _, prefix := range vnetInfo.vnet.Properties.AddressSpace.AddressPrefixes { + if prefix != nil { + vnetCIDRs = append(vnetCIDRs, *prefix) + } + } + + if len(vnetCIDRs) == 0 { + return nil, fmt.Errorf("VNet has no valid address prefixes") + } + + i.logger.Infof("Using VNet CIDRs: %v", vnetCIDRs) + return vnetCIDRs, nil +} + +// calculateGatewayIP calculates the gateway IP from P2S CIDR +func (i *Installer) calculateGatewayIP() (string, error) { + p2sCIDR := i.config.Azure.VPNGateway.P2SGatewayCIDR + if p2sCIDR == "" { + return "", fmt.Errorf("P2S Gateway CIDR not configured") + } + + // Parse the P2S CIDR (e.g., "192.168.100.0/24") + _, network, err := net.ParseCIDR(p2sCIDR) + if err != nil { + return "", fmt.Errorf("failed to parse P2S CIDR %s: %w", p2sCIDR, err) + } + + // Gateway IP is typically the first usable IP in the range + // For 192.168.100.0/24, the gateway would be 192.168.100.1 + ip := network.IP.To4() + if ip == nil { + return "", fmt.Errorf("only IPv4 networks are supported") + } + + // Increment the network address by 1 to get the gateway IP + gatewayIP := net.IPv4(ip[0], ip[1], ip[2], ip[3]+1) + return gatewayIP.String(), nil +} + +// addIPRoute adds an IP 
route if it doesn't already exist +func (i *Installer) addIPRoute(vnetCIDR, gatewayIP, vpnInterface string) error { + // Try to add the route, capture combined output to check for "File exists" error + output, err := utils.RunCommandWithOutput("ip", "route", "add", vnetCIDR, "via", gatewayIP, "dev", vpnInterface) + if err != nil { + // Check if route already exists by looking for "File exists" in the output or error + if strings.Contains(output, "File exists") || strings.Contains(err.Error(), "File exists") { + i.logger.Infof("Route to %s already exists, skipping", vnetCIDR) + return nil + } + return fmt.Errorf("failed to add route for AKS VNet CIDR %s: %s (exit code: %v)", vnetCIDR, strings.TrimSpace(output), err) + } + + i.logger.Infof("Added route: %s via %s dev %s", vnetCIDR, gatewayIP, vpnInterface) + return nil +} + +// isVPNConnected checks if VPN connection is active +func (i *Installer) isVPNConnected() bool { + iface, err := utils.GetVPNInterface() + if err != nil { + return false + } + + ip, err := utils.GetVPNInterfaceIP(iface) + return err == nil && ip != "" +} + +// uploadCertificateToAzure uploads the root certificate to Azure VPN Gateway using Azure SDK +func (i *Installer) uploadCertificateToAzure(ctx context.Context, certData string, vnetInfo vnetResourceInfo) error { + // Get the current VPN Gateway to update its configuration + gateway, err := i.vgwClient.Get(ctx, vnetInfo.resourceGroupName, defaultVPNGatewayName, nil) + if err != nil { + return fmt.Errorf("failed to get VPN Gateway: %w", err) + } + + // Check if VPN client configuration exists + // Look for our specific certificate by name and data + var existingCertFound bool + var existingCertMatches bool + + // Check if VPN client configuration exists and has certificates + if gateway.Properties.VPNClientConfiguration != nil && + gateway.Properties.VPNClientConfiguration.VPNClientRootCertificates != nil { + for _, cert := range 
gateway.Properties.VPNClientConfiguration.VPNClientRootCertificates { + if cert.Properties != nil && cert.Properties.PublicCertData != nil { + // Compare certificate data directly + if *cert.Properties.PublicCertData == certData { + i.logger.Infof("VPN certificate already exists on Azure VPN Gateway with name '%s', skipping upload", *cert.Name) + return nil // Certificate already exists and matches, no need to upload + } + } + // Track if any certificate exists (regardless of name) + if cert.Name != nil { + existingCertFound = true + } + } + } + + if existingCertFound && !existingCertMatches { + i.logger.Info("Adding new VPN root certificate alongside existing certificates") + } else if !existingCertFound { + i.logger.Info("No existing VPN root certificates found, uploading first certificate") + } + + // Ensure the VPN client configuration section exists with required address pool + if gateway.Properties.VPNClientConfiguration == nil { + p2sGatewayCIDR := i.config.Azure.VPNGateway.P2SGatewayCIDR + vpnClientProtocol := armnetwork.VPNClientProtocolOpenVPN + + gateway.Properties.VPNClientConfiguration = &armnetwork.VPNClientConfiguration{ + VPNClientAddressPool: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{&p2sGatewayCIDR}, + }, + VPNClientProtocols: []*armnetwork.VPNClientProtocol{&vpnClientProtocol}, + } + } + + // Create root certificate parameters with unique name based on certificate data + certHash := fmt.Sprintf("%x", sha256.Sum256([]byte(certData)))[:8] // Use first 8 chars of hash + certName := fmt.Sprintf("%s-%s", vpnClientRootCertName, certHash) + + i.logger.Infof("Adding VPN root certificate with name: %s", certName) + + rootCert := armnetwork.VPNClientRootCertificate{ + Name: &certName, + Properties: &armnetwork.VPNClientRootCertificatePropertiesFormat{ + PublicCertData: &certData, + }, + } + + // Append the new certificate to existing certificates instead of replacing them + if gateway.Properties.VPNClientConfiguration.VPNClientRootCertificates 
== nil { + gateway.Properties.VPNClientConfiguration.VPNClientRootCertificates = []*armnetwork.VPNClientRootCertificate{} + } + + // Always append the new certificate (with unique name, no conflicts) + gateway.Properties.VPNClientConfiguration.VPNClientRootCertificates = + append(gateway.Properties.VPNClientConfiguration.VPNClientRootCertificates, &rootCert) + + // Update the VPN Gateway with the new certificate configuration + poller, err := i.vgwClient.BeginCreateOrUpdate(ctx, vnetInfo.resourceGroupName, defaultVPNGatewayName, gateway.VirtualNetworkGateway, nil) + if err != nil { + return fmt.Errorf("failed to start VPN Gateway update: %w", err) + } + + // Wait for the operation to complete + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to update VPN Gateway with certificate: %w", err) + } + + if existingCertFound && !existingCertMatches { + i.logger.Info("VPN certificate added successfully - now have multiple certificates available") + } else { + i.logger.Info("VPN certificate uploaded to Azure successfully") + } + return nil + +} + +// downloadVPNClientConfig downloads the VPN client configuration from Azure using Azure SDK +func (i *Installer) downloadVPNClientConfig(ctx context.Context, gatewayName, resourceGroup string) (string, error) { + i.logger.Info("Downloading VPN client configuration from Azure VPN Gateway...") + + // Generate VPN client configuration + authMethod := armnetwork.AuthenticationMethodEAPTLS + req := armnetwork.VPNClientParameters{ + AuthenticationMethod: &authMethod, + } + + poller, err := i.vgwClient.BeginGenerateVPNProfile(ctx, resourceGroup, gatewayName, req, nil) + if err != nil { + return "", fmt.Errorf("failed to start VPN client config generation: %w", err) + } + + // Wait for the operation to complete + result, err := poller.PollUntilDone(ctx, nil) + if err != nil { + return "", fmt.Errorf("failed to generate VPN client config: %w", err) + } + + if result.Value == nil || *result.Value == 
"" { + return "", fmt.Errorf("no VPN client configuration URL returned from Azure") + } + + downloadURL := *result.Value + i.logger.Infof("VPN client configuration URL: %s", downloadURL) + + // Download the configuration file + configData, err := i.downloadConfigFromURL(downloadURL) + if err != nil { + return "", fmt.Errorf("failed to download VPN client configuration: %w", err) + } + + return configData, nil + +} + +// downloadConfigFromURL downloads the VPN client configuration from the provided URL +func (i *Installer) downloadConfigFromURL(urlStr string) (string, error) { + // Validate URL to prevent SSRF attacks + parsedURL, err := url.Parse(urlStr) + if err != nil { + return "", fmt.Errorf("invalid URL: %w", err) + } + + // Only allow HTTPS URLs for security + if parsedURL.Scheme != "https" { + return "", fmt.Errorf("only HTTPS URLs are allowed, got: %s", parsedURL.Scheme) + } + + // Validate that this is an Azure Blob Storage URL for VPN configuration + host := strings.ToLower(parsedURL.Host) + if !strings.HasSuffix(host, ".blob.core.windows.net") { + return "", fmt.Errorf("URL must be from Azure Blob Storage: %s", parsedURL.Host) + } + + // Create temporary file for ZIP download + tempZipFile, err := os.CreateTemp("", tempVPNZipPattern) + if err != nil { + return "", fmt.Errorf("failed to create temporary ZIP file: %w", err) + } + tempZipPath := tempZipFile.Name() + _ = tempZipFile.Close() // Close file handle so utils.DownloadFile can write to it + + defer func() { + if err := utils.RunCleanupCommand(tempZipPath); err != nil { + i.logger.Warnf("Failed to clean up temp ZIP file %s: %v", tempZipPath, err) + } + }() + + // Download the ZIP file using utils.DownloadFile + i.logger.Info("Downloading VPN configuration ZIP file...") + if err := utils.DownloadFile(urlStr, tempZipPath); err != nil { + return "", fmt.Errorf("failed to download VPN config ZIP: %w", err) + } + + // Extract and return the OpenVPN configuration + return 
i.extractOpenVPNConfig(tempZipPath) +} + +// extractOpenVPNConfig extracts the OpenVPN configuration from a ZIP file +func (i *Installer) extractOpenVPNConfig(zipPath string) (string, error) { + // Open the ZIP file + reader, err := zip.OpenReader(zipPath) + if err != nil { + return "", fmt.Errorf("failed to open VPN config ZIP: %w", err) + } + defer func() { _ = reader.Close() }() + + // Look for any .ovpn file in the ZIP (handle different path separators) + i.logger.Info("Examining ZIP contents:") + for _, file := range reader.File { + i.logger.Infof(" File: %s", file.Name) + fileName := strings.ToLower(file.Name) + // Check for .ovpn files in OpenVPN directory (handle both / and \ separators) + if strings.HasSuffix(fileName, ".ovpn") && + (strings.Contains(fileName, "openvpn/") || strings.Contains(fileName, "openvpn\\")) { + // Extract and read the config file with size limits + fileReader, err := file.Open() + if err != nil { + return "", fmt.Errorf("failed to open OpenVPN config: %w", err) + } + defer func() { _ = fileReader.Close() }() + + // Add size limit of 1MB (config files are typically small) + const maxFileSize = 1024 * 1024 // 1MB + limitedReader := io.LimitReader(fileReader, maxFileSize) + + configData, err := io.ReadAll(limitedReader) + if err != nil { + return "", fmt.Errorf("failed to read OpenVPN config: %w", err) + } + + i.logger.Info("VPN configuration extracted successfully") + return string(configData), nil + } + } + + return "", fmt.Errorf("OpenVPN configuration file (.ovpn) not found in ZIP") +} diff --git a/pkg/config/structs.go b/pkg/config/structs.go index 3c1cc4a..609a361 100644 --- a/pkg/config/structs.go +++ b/pkg/config/structs.go @@ -30,6 +30,7 @@ type AzureConfig struct { ManagedIdentity *ManagedIdentityConfig `json:"managedIdentity,omitempty"` // Optional managed identity authentication BootstrapToken *BootstrapTokenConfig `json:"bootstrapToken,omitempty"` // Optional bootstrap token authentication Arc *ArcConfig `json:"arc"` // 
Azure Arc machine configuration + VPNGateway *VPNConfig `json:"vpnGateway"` // Azure VPN gateway configuration for P2S connectivity TargetCluster *TargetClusterConfig `json:"targetCluster"` // Target AKS cluster configuration } @@ -72,6 +73,15 @@ type ArcConfig struct { Location string `json:"location"` // Azure region for Arc machine } +// VPNConfig holds configuration settings for the VPN gateway. +type VPNConfig struct { + Enabled bool `json:"enabled"` + P2SGatewayCIDR string `json:"p2sGatewayCIDR"` + GatewaySKU string `json:"gatewaySKU"` + PodCIDR string `json:"podCIDR,omitempty"` // Pod network CIDR (e.g., "10.244.0.0/16") - required for routing + VNetID string `json:"vnetID,omitempty"` // Azure VNet resource ID (AKS managed or BYO VNet where AKS nodes reside) +} + // AgentConfig holds agent-specific operational configuration. type AgentConfig struct { LogLevel string `json:"logLevel"` // Logging level: debug, info, warning, error @@ -130,7 +140,7 @@ type KubernetesPathsConfig struct { KubeletDir string `json:"kubeletDir"` } -// CNIPathsConfig holds file system paths related to CNI plugins and configurations. 
+// CNIConfig holds configuration settings for CNI plugins and networking type CNIConfig struct { Version string `json:"version"` } @@ -196,6 +206,14 @@ func (cfg *Config) GetTargetClusterResourceGroup() string { return "" } +// GetTargetClusterNodeResourceGroup returns the target AKS cluster node resource group from configuration +func (cfg *Config) GetTargetClusterNodeResourceGroup() string { + if cfg.Azure.TargetCluster != nil && cfg.Azure.TargetCluster.NodeResourceGroup != "" { + return cfg.Azure.TargetCluster.NodeResourceGroup + } + return "" +} + // GetTargetClusterLocation returns the target AKS cluster location from configuration func (cfg *Config) GetTargetClusterLocation() string { if cfg.Azure.TargetCluster != nil && cfg.Azure.TargetCluster.Location != "" { @@ -256,3 +274,28 @@ func (cfg *Config) GetKubernetesVersion() string { func (cfg *Config) IsARCEnabled() bool { return cfg.Azure.Arc != nil && cfg.Azure.Arc.Enabled } + +// IsVPNGatewayEnabled checks if VPN Gateway is enabled in the configuration +func (cfg *Config) IsVPNGatewayEnabled() bool { + if cfg.Azure.VPNGateway != nil && + cfg.Azure.VPNGateway.Enabled { + return true + } + return false +} + +// GetVPNGatewayVNetID returns the VNet ID for the VPN Gateway from configuration +func (cfg *Config) GetVPNGatewayVNetID() string { + if cfg.Azure.VPNGateway != nil { + return cfg.Azure.VPNGateway.VNetID + } + return "" +} + +// GetVPNGatewayPodCIDR returns the Pod CIDR for the VPN Gateway from configuration +func (cfg *Config) GetVPNGatewayPodCIDR() string { + if cfg.Azure.VPNGateway != nil { + return cfg.Azure.VPNGateway.PodCIDR + } + return "" +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 80fe053..9e4f976 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -20,7 +20,7 @@ import ( // sudoCommandLists holds the command lists for sudo determination var ( - alwaysNeedsSudo = []string{"apt", "apt-get", "dpkg", "systemctl", "mount", "umount", "modprobe", "sysctl", "azcmagent", 
"usermod", "kubectl", "swapoff"} + alwaysNeedsSudo = []string{"apt", "apt-get", "dpkg", "systemctl", "mount", "umount", "modprobe", "sysctl", "azcmagent", "usermod", "kubectl", "pkill", "swapoff", "iptables", "ip"} conditionalSudo = []string{"mkdir", "cp", "chmod", "chown", "mv", "tar", "rm", "bash", "install", "ln", "cat"} systemPaths = []string{"/etc/", "/usr/", "/var/", "/opt/", "/boot/", "/sys/"} ) @@ -98,6 +98,15 @@ func IsServiceActive(serviceName string) bool { return strings.TrimSpace(output) == "active" } +// IsServiceEnabled checks if a systemd service is enabled +func IsServiceEnabled(serviceName string) bool { + output, err := RunCommandWithOutput("systemctl", "is-enabled", serviceName) + if err != nil { + return false + } + return strings.TrimSpace(output) == "enabled" +} + // ServiceExists checks if a systemd service unit file exists func ServiceExists(serviceName string) bool { err := RunSystemCommand("systemctl", "list-unit-files", serviceName+".service") @@ -114,6 +123,11 @@ func DisableService(serviceName string) error { return RunSystemCommand("systemctl", "disable", serviceName) } +// EnableService enables a systemd service +func EnableService(serviceName string) error { + return RunSystemCommand("systemctl", "enable", serviceName) +} + // EnableAndStartService enables and starts a systemd service func EnableAndStartService(serviceName string) error { return RunSystemCommand("systemctl", "enable", "--now", serviceName) @@ -468,3 +482,133 @@ func ExtractClusterInfo(kubeconfigData []byte) (string, string, error) { caCertDataB64 := base64.StdEncoding.EncodeToString(cluster.CertificateAuthorityData) return cluster.Server, caCertDataB64, nil } + +// GetVPNInterface returns the first available VPN interface (tun0, tun1, etc.) 
+func GetVPNInterface() (string, error) { + const vpnInterfacePrefix = "tun" + const maxVPNInterfaces = 10 + + // Check for tun interfaces + for j := 0; j < maxVPNInterfaces; j++ { + iface := fmt.Sprintf("%s%d", vpnInterfacePrefix, j) + if _, err := os.Stat(fmt.Sprintf("/sys/class/net/%s", iface)); err == nil { + return iface, nil + } + } + return "", fmt.Errorf("no VPN interface found") +} + +// GetVPNInterfaceIP returns the IP address of the given VPN interface +func GetVPNInterfaceIP(iface string) (string, error) { + output, err := RunCommandWithOutput("ip", "addr", "show", iface) + if err != nil { + return "", fmt.Errorf("failed to get interface %s info: %w", iface, err) + } + + // Parse IP address from output + lines := strings.Split(output, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.Contains(line, "inet ") && !strings.Contains(line, "inet6") { + fields := strings.Fields(line) + if len(fields) >= 2 { + ip := strings.Split(fields[1], "/")[0] + return ip, nil + } + } + } + return "", fmt.Errorf("no IP address found for interface %s", iface) +} + +// ValidateAzureResourceID validates that the provided resource ID follows Azure resource ID format +// Expected format: /subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/{resourceProvider}/{resourceType}/{resourceName} +func ValidateAzureResourceID(resourceID, expectedResourceType string) error { + if resourceID == "" { + return fmt.Errorf("resource ID cannot be empty") + } + + // Azure resource IDs must start with /subscriptions/ + if !strings.HasPrefix(resourceID, "/subscriptions/") { + return fmt.Errorf("resource ID must start with '/subscriptions/', got: %s", resourceID) + } + + // Split the resource ID into parts + parts := strings.Split(resourceID, "/") + + // Azure resource ID should have at least 9 parts: + // ["", "subscriptions", "{subscriptionId}", "resourceGroups", "{resourceGroupName}", "providers", "{resourceProvider}", "{resourceType}", 
"{resourceName}"] + if len(parts) < 9 { + return fmt.Errorf("resource ID has invalid format, expected at least 9 segments, got %d: %s", len(parts), resourceID) + } + + // Validate the fixed parts of the resource ID format + expectedSegments := map[int]string{ + 1: "subscriptions", + 3: "resourceGroups", + 5: "providers", + } + + for index, expectedValue := range expectedSegments { + if index >= len(parts) || parts[index] != expectedValue { + return fmt.Errorf("resource ID segment %d should be '%s', got '%s': %s", index, expectedValue, parts[index], resourceID) + } + } + + // Validate that required segments are not empty + requiredSegments := map[int]string{ + 2: "subscription ID", + 4: "resource group name", + 6: "resource provider", + 7: "resource type", + 8: "resource name", + } + + for index, segmentName := range requiredSegments { + if index >= len(parts) || strings.TrimSpace(parts[index]) == "" { + return fmt.Errorf("resource ID %s cannot be empty: %s", segmentName, resourceID) + } + } + + // Validate the resource type matches expected type + if expectedResourceType != "" && parts[7] != expectedResourceType { + return fmt.Errorf("expected resource type '%s', got '%s': %s", expectedResourceType, parts[7], resourceID) + } + + // Validate that it's a Microsoft.Network provider for VNet + if expectedResourceType == "virtualNetworks" && parts[6] != "Microsoft.Network" { + return fmt.Errorf("VNet resource must use Microsoft.Network provider, got '%s': %s", parts[6], resourceID) + } + + return nil +} + +// GetResourceGroupFromResourceID extracts the resource group name from an Azure resource ID +func GetResourceGroupFromResourceID(resourceID string) string { + parts := strings.Split(resourceID, "/") + for i, part := range parts { + if strings.EqualFold(part, "resourceGroups") && i+1 < len(parts) { + return parts[i+1] + } + } + return "" +} + +// GetResourceNameFromResourceID extracts the resource name from an Azure resource ID +func 
GetResourceNameFromResourceID(resourceID string) string { + parts := strings.Split(resourceID, "/") + if len(parts) > 0 { + return parts[len(parts)-1] + } + return "" +} + +// GetSubscriptionIDFromResourceID extracts the subscription ID from an Azure resource ID +func GetSubscriptionIDFromResourceID(resourceID string) string { + parts := strings.Split(resourceID, "/") + for i, part := range parts { + if strings.EqualFold(part, "subscriptions") && i+1 < len(parts) { + return parts[i+1] + } + } + return "" +}