Skip to content
6 changes: 6 additions & 0 deletions src/azure-cli/azure/cli/command_modules/vm/_vm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,3 +780,9 @@ class IdentityType(Enum):
USER_ASSIGNED = 'UserAssigned'
SYSTEM_ASSIGNED_USER_ASSIGNED = 'SystemAssigned, UserAssigned'
NONE = 'None'


class UpgradeMode(Enum):
AUTOMATIC = 'Automatic'
MANUAL = 'Manual'
ROLLING = 'Rolling'
6 changes: 4 additions & 2 deletions src/azure-cli/azure/cli/command_modules/vm/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,12 @@ def load_command_table(self, _):
g.custom_command('create', 'create_dedicated_host_group')
g.generic_update_command('update')

with self.command_group('vmss', compute_vmss_sdk, operation_group='virtual_machine_scale_sets') as g:
with self.command_group('vmss') as g:
g.custom_command('identity assign', 'assign_vmss_identity', validator=process_assign_identity_namespace)
g.custom_command('identity remove', 'remove_vmss_identity', validator=process_remove_identity_namespace, min_api='2017-12-01', is_preview=True)
g.custom_command('identity remove', 'remove_vmss_identity', validator=process_remove_identity_namespace, is_preview=True)
g.custom_show_command('identity show', 'show_vmss_identity')

with self.command_group('vmss', compute_vmss_sdk, operation_group='virtual_machine_scale_sets') as g:
g.custom_command('application set', 'set_vmss_applications', validator=process_set_applications_namespace, min_api='2021-07-01')
g.custom_command('application list', 'list_vmss_applications', min_api='2021-07-01')
g.custom_command('create', 'create_vmss', transform=DeploymentOutputLongRunningOperation(self.cli_ctx, 'Starting vmss create'), supports_no_wait=True, table_transformer=deployment_validate_table_format, validator=process_vmss_create_namespace, exception_handler=handle_template_based_exception)
Expand Down
175 changes: 92 additions & 83 deletions src/azure-cli/azure/cli/command_modules/vm/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,8 +842,8 @@ def show_vm_identity(cmd, resource_group_name, vm_name):


def show_vmss_identity(cmd, resource_group_name, vm_name):
client = _compute_client_factory(cmd.cli_ctx)
return client.virtual_machine_scale_sets.get(resource_group_name, vm_name).identity
vm = get_vmss_by_aaz(cmd, resource_group_name, vm_name)
return vm.get("identity", {}) if vm else None


def assign_vm_identity(cmd, resource_group_name, vm_name, assign_identity=None, identity_role=None,
Expand Down Expand Up @@ -2564,45 +2564,6 @@ def list_vm_extension_images(


# region VirtualMachines Identity
def _remove_identities(cmd, resource_group_name, name, identities, getter, setter):
from ._vm_utils import MSI_LOCAL_ID
ResourceIdentityType = cmd.get_models('ResourceIdentityType', operation_group='virtual_machines')
remove_system_assigned_identity = False
if MSI_LOCAL_ID in identities:
remove_system_assigned_identity = True
identities.remove(MSI_LOCAL_ID)
resource = getter(cmd, resource_group_name, name)
if resource.identity is None:
return None
emsis_to_remove = []
if identities:
existing_emsis = {x.lower() for x in (resource.identity.user_assigned_identities or {}).keys()}
emsis_to_remove = {x.lower() for x in identities}
non_existing = emsis_to_remove.difference(existing_emsis)
if non_existing:
raise CLIError("'{}' are not associated with '{}'".format(','.join(non_existing), name))
if not list(existing_emsis - emsis_to_remove): # if all emsis are gone, we need to update the type
if resource.identity.type == ResourceIdentityType.user_assigned:
resource.identity.type = ResourceIdentityType.none
elif resource.identity.type == ResourceIdentityType.system_assigned_user_assigned:
resource.identity.type = ResourceIdentityType.system_assigned

resource.identity.user_assigned_identities = None
if remove_system_assigned_identity:
resource.identity.type = (ResourceIdentityType.none
if resource.identity.type == ResourceIdentityType.system_assigned
else ResourceIdentityType.user_assigned)

if emsis_to_remove:
if resource.identity.type not in [ResourceIdentityType.none, ResourceIdentityType.system_assigned]:
resource.identity.user_assigned_identities = {}
for identity in emsis_to_remove:
resource.identity.user_assigned_identities[identity] = None

result = LongRunningOperation(cmd.cli_ctx)(setter(resource_group_name, name, resource))
return result.identity


def _remove_identities_by_aaz(cmd, resource_group_name, name, identities, getter, setter):
from ._vm_utils import MSI_LOCAL_ID

Expand Down Expand Up @@ -2647,6 +2608,10 @@ def _remove_identities_by_aaz(cmd, resource_group_name, name, identities, getter
existing_identity['type'] = IdentityType.NONE.value

result = LongRunningOperation(cmd.cli_ctx)(setter(resource_group_name, name, resource))

if not result:
return None

return result.get('identity') or None


Expand Down Expand Up @@ -3589,49 +3554,64 @@ def reset_linux_ssh(cmd, resource_group_name, vm_name, no_wait=False):
# region VirtualMachineScaleSets
def assign_vmss_identity(cmd, resource_group_name, vmss_name, assign_identity=None, identity_role=None,
identity_role_id=None, identity_scope=None):
VirtualMachineScaleSetIdentity, UpgradeMode, ResourceIdentityType, VirtualMachineScaleSetUpdate = cmd.get_models(
'VirtualMachineScaleSetIdentity', 'UpgradeMode', 'ResourceIdentityType', 'VirtualMachineScaleSetUpdate')
IdentityUserAssignedIdentitiesValue = cmd.get_models(
'VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue') or cmd.get_models('UserAssignedIdentitiesValue')
from azure.cli.core.commands.arm import assign_identity as assign_identity_helper
client = _compute_client_factory(cmd.cli_ctx)
_, _, external_identities, enable_local_identity = _build_identities_info(assign_identity)
identity, _, external_identities, enable_local_identity = _build_identities_info(assign_identity)
from ._vm_utils import assign_identity as assign_identity_helper, UpgradeMode

command_args = {'resource_group': resource_group_name, 'vm_scale_set_name': vmss_name}

def getter():
return client.virtual_machine_scale_sets.get(resource_group_name, vmss_name)
return get_vmss_by_aaz(cmd, resource_group_name, vmss_name)

def setter(vmss, external_identities=external_identities):

if vmss.identity and vmss.identity.type == ResourceIdentityType.system_assigned_user_assigned:
identity_types = ResourceIdentityType.system_assigned_user_assigned
elif vmss.identity and vmss.identity.type == ResourceIdentityType.system_assigned and external_identities:
identity_types = ResourceIdentityType.system_assigned_user_assigned
elif vmss.identity and vmss.identity.type == ResourceIdentityType.user_assigned and enable_local_identity:
identity_types = ResourceIdentityType.system_assigned_user_assigned
if vmss.get('identity', {}).get('type', None) == IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value:
identity_types = IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value
elif vmss.get('identity', {}).get('type', None) == IdentityType.SYSTEM_ASSIGNED.value and external_identities:
identity_types = IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value
elif vmss.get('identity', {}).get('type', None) == IdentityType.USER_ASSIGNED.value and enable_local_identity:
identity_types = IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value
elif external_identities and enable_local_identity:
identity_types = ResourceIdentityType.system_assigned_user_assigned
identity_types = IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value
elif external_identities:
identity_types = ResourceIdentityType.user_assigned
identity_types = IdentityType.USER_ASSIGNED.value
else:
identity_types = ResourceIdentityType.system_assigned
vmss.identity = VirtualMachineScaleSetIdentity(type=identity_types)
if external_identities:
vmss.identity.user_assigned_identities = {}
for identity in external_identities:
vmss.identity.user_assigned_identities[identity] = IdentityUserAssignedIdentitiesValue()
vmss_patch = VirtualMachineScaleSetUpdate()
vmss_patch.identity = vmss.identity
poller = client.virtual_machine_scale_sets.begin_update(resource_group_name, vmss_name, vmss_patch)
return LongRunningOperation(cmd.cli_ctx)(poller)
identity_types = IdentityType.SYSTEM_ASSIGNED.value

if identity_types == IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value:
command_args['mi_system_assigned'] = "True"
command_args['mi_user_assigned'] = []
elif identity_types == IdentityType.USER_ASSIGNED.value:
command_args['mi_user_assigned'] = []
else:
command_args['mi_system_assigned'] = "True"
command_args['mi_user_assigned'] = []

if vmss.get('identity', {}).get('userAssignedIdentities', None):
for key in vmss.get('identity').get('userAssignedIdentities').keys():
command_args['mi_user_assigned'].append(key)

if identity.get('userAssignedIdentities'):
for key in identity.get('userAssignedIdentities', {}).keys():
if key not in command_args['mi_user_assigned']:
command_args['mi_user_assigned'].append(key)

from .operations.vmss import VMSSPatch
update_vmss_identity = VMSSPatch(cli_ctx=cmd.cli_ctx)(command_args=command_args)
LongRunningOperation(cmd.cli_ctx)(update_vmss_identity)
result = update_vmss_identity.result()
return result

assign_identity_helper(cmd.cli_ctx, getter, setter, identity_role=identity_role_id, identity_scope=identity_scope)
vmss = client.virtual_machine_scale_sets.get(resource_group_name, vmss_name)
if vmss.upgrade_policy.mode == UpgradeMode.manual:

vmss = getter()
if vmss.get('upgradePolicy', {}).get('mode', '') == UpgradeMode.MANUAL.value:
logger.warning("With manual upgrade mode, you will need to run 'az vmss update-instances -g %s -n %s "
"--instance-ids *' to propagate the change", resource_group_name, vmss_name)

return _construct_identity_info(identity_scope, identity_role, vmss.identity.principal_id,
vmss.identity.user_assigned_identities)
return _construct_identity_info(
identity_scope,
identity_role,
vmss.get('identity').get('principalId') if vmss.get('identity') else None,
vmss.get('identity').get('userAssignedIdentities') if vmss.get('identity') else None)


# pylint: disable=too-many-locals, too-many-statements
Expand Down Expand Up @@ -4201,6 +4181,24 @@ def get_vmss(cmd, resource_group_name, name, instance_id=None, include_user_data
return client.virtual_machine_scale_sets.get(resource_group_name, name)


def get_vmss_by_aaz(cmd, resource_group_name, name, instance_id=None, include_user_data=False):
from .operations.vmss import VMSSShow
from .operations.vmss_vms import VMSSVMSShow

command_args = {
'resource_group': resource_group_name,
'vm_scale_set_name': name,
}

if include_user_data:
command_args['expand'] = 'userData'

if instance_id is not None:
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_vmss_by_aaz accepts an instance_id argument but the instance_id is never added to command_args before calling VMSSVMSShow, so any future caller that passes instance_id will not get the expected per-instance result. Please include the instance_id in command_args when instance_id is not None, consistent with how get_vmss_modified_by_aaz and other helpers wire through this parameter.

Suggested change
if instance_id is not None:
if instance_id is not None:
command_args['instance_id'] = instance_id

Copilot uses AI. Check for mistakes.
command_args['instance_id'] = instance_id
return VMSSVMSShow(cli_ctx=cmd.cli_ctx)(command_args=command_args)
return VMSSShow(cli_ctx=cmd.cli_ctx)(command_args=command_args)


def _check_vmss_hyper_v_generation(cli_ctx, vmss):
hyper_v_generation = get_hyper_v_generation_from_vmss(
cli_ctx, vmss.virtual_machine_profile.storage_profile.image_reference, vmss.location)
Expand Down Expand Up @@ -5405,24 +5403,35 @@ def vmss_run_command_show(cmd,

# region VirtualMachineScaleSets Identity
def remove_vmss_identity(cmd, resource_group_name, vmss_name, identities=None):
client = _compute_client_factory(cmd.cli_ctx)
def setter(resource_group_name, vmss_name, vmss):
command_args = {
'resource_group': resource_group_name,
'vm_scale_set_name': vmss_name
}

def _get_vmss(_, resource_group_name, vmss_name):
return client.virtual_machine_scale_sets.get(resource_group_name, vmss_name)
if vmss.get('identity') and vmss['identity'].get('type') == IdentityType.USER_ASSIGNED.value:
# NOTE: The literal 'UserAssigned' is intentionally appended as a marker for
# VMSSIdentityRemove._format_content, which uses it to apply special handling
# for purely user-assigned identities. It is not a real identity resource ID.
command_args['mi_user_assigned'] = \
list(vmss.get('identity', {}).get('userAssignedIdentities', {}).keys()) + ['UserAssigned']
elif vmss.get('identity') and vmss['identity'].get('type') == IdentityType.SYSTEM_ASSIGNED.value:
command_args['mi_user_assigned'] = []
command_args['mi_system_assigned'] = 'True'
elif vmss.get('identity') and vmss['identity'].get('type') == IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value:
command_args['mi_user_assigned'] = list(vmss.get('identity', {}).get('userAssignedIdentities', {}).keys())
command_args['mi_system_assigned'] = 'True'
else:
command_args['mi_user_assigned'] = []

def _set_vmss(resource_group_name, name, vmss_instance):
VirtualMachineScaleSetUpdate = cmd.get_models('VirtualMachineScaleSetUpdate',
operation_group='virtual_machine_scale_sets')
vmss_update = VirtualMachineScaleSetUpdate(identity=vmss_instance.identity)
return client.virtual_machine_scale_sets.begin_update(resource_group_name, vmss_name, vmss_update)
from .operations.vmss import VMSSIdentityRemove
return VMSSIdentityRemove(cli_ctx=cmd.cli_ctx)(command_args=command_args)

if identities is None:
from ._vm_utils import MSI_LOCAL_ID
identities = [MSI_LOCAL_ID]

return _remove_identities(cmd, resource_group_name, vmss_name, identities,
_get_vmss,
_set_vmss)
return _remove_identities_by_aaz(cmd, resource_group_name, vmss_name, identities, get_vmss_by_aaz, setter)
# endregion


Expand Down
78 changes: 77 additions & 1 deletion src/azure-cli/azure/cli/command_modules/vm/operations/vmss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint: disable=no-self-use, line-too-long, protected-access, too-few-public-methods, unused-argument, too-many-statements, too-many-branches, too-many-locals
import json
from knack.log import get_logger

from ..aaz.latest.vmss import (ListInstances as _VMSSListInstances,
Start as _Start,
Create as _VMSSCreate,
Show as _VMSSShow)
Show as _VMSSShow,
Patch as _VMSSPatch)
from azure.cli.core.aaz import AAZUndefined, has_value
from .._vm_utils import IdentityType

logger = get_logger(__name__)

Expand Down Expand Up @@ -66,6 +69,79 @@ def _output(self, *args, **kwargs):
return result


class VMSSPatch(_VMSSPatch):

def _output(self, *args, **kwargs):
# Resolve flatten conflict
# When the type field conflicts, the type in inner layer is ignored and the outer layer is applied
if has_value(self.ctx.vars.instance.properties.virtual_machine_profile.extension_profile.extensions):
for extension in self.ctx.vars.instance.properties.virtual_machine_profile.extension_profile.extensions:
if has_value(extension.type):
extension.type = AAZUndefined

result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True)
return result


class VMSSIdentityRemove(_VMSSPatch):
def _output(self, *args, **kwargs):
# Resolve flatten conflict
# When the type field conflicts, the type in inner layer is ignored and the outer layer is applied
if has_value(self.ctx.vars.instance.properties.virtual_machine_profile.extension_profile.extensions):
for extension in self.ctx.vars.instance.properties.virtual_machine_profile.extension_profile.extensions:
if has_value(extension.type):
extension.type = AAZUndefined

result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True)

identity = result.get('identity')
if not identity:
return result

if not identity.get('userAssignedIdentities'):
return result

return result

class VirtualMachineScaleSetsUpdate(_VMSSPatch.VirtualMachineScaleSetsUpdate):
def _format_content(self, content):
if isinstance(content, str):
content = json.loads(content)

if not content.get('identity'):
content['identity'] = {
'userAssignedIdentities': None,
'type': IdentityType.NONE.value
}
return json.dumps(content)

identities = content.get('identity', {}).get('userAssignedIdentities')
if identities:
if 'UserAssigned' in identities.keys():
identities.pop('UserAssigned')

for key in list(identities.keys()):
identities[key] = None

return json.dumps(content)

def __call__(self, *args, **kwargs):
request = self.make_request()
request.data = self._format_content(request.data)
session = self.client.send_request(request=request, stream=False, **kwargs)
if session.http_response.status_code in [200, 202]:
return self.client.build_lro_polling(
self.ctx.args.no_wait,
session,
self.on_200,
self.on_error,
lro_options={"final-state-via": "azure-async-operation"},
path_format_arguments=self.url_parameters,
)

return self.on_error(session.http_response)


def convert_show_result_to_snake_case(result):
new_result = {}
if "extendedLocation" in result:
Expand Down
Loading
Loading