diff --git a/rosidl_generator_py/resource/_msg.py.em b/rosidl_generator_py/resource/_msg.py.em index 98285ecb..e5a86724 100644 --- a/rosidl_generator_py/resource/_msg.py.em +++ b/rosidl_generator_py/resource/_msg.py.em @@ -9,6 +9,7 @@ from rosidl_generator_py.generate_py_impl import get_type_annotation_default from rosidl_generator_py.generate_py_impl import get_setter_and_getter_type from rosidl_generator_py.generate_py_impl import SPECIAL_NESTED_BASIC_TYPES from rosidl_generator_py.generate_py_impl import value_to_py +from rosidl_generator_py.generate_py_impl import generate_check_fields from rosidl_parser.definition import AbstractGenericString from rosidl_parser.definition import AbstractNestedType from rosidl_parser.definition import AbstractSequence @@ -504,161 +505,7 @@ if isinstance(member.type, (Array, AbstractSequence)): ' please use a subclass of collections.abc.Sequence like list', DeprecationWarning) @[ end if]@ - if self._check_fields: -@[ if isinstance(member.type, AbstractNestedType) and isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in SPECIAL_NESTED_BASIC_TYPES]@ -@[ if isinstance(member.type, Array)]@ - if isinstance(value, numpy.ndarray): - assert value.dtype == @(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype']), \ - "The '@(member.name)' numpy.ndarray() must have the dtype of '@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'])'" - assert value.size == @(member.type.size), \ - "The '@(member.name)' numpy.ndarray() must have a size of @(member.type.size)" - self._@(member.name) = value - return -@[ elif isinstance(member.type, AbstractSequence)]@ - if isinstance(value, array.array): - assert value.typecode == '@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['type_code'])', \ - "The '@(member.name)' array.array() must have the type code of '@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['type_code'])'" -@[ if isinstance(member.type, BoundedSequence)]@ - assert len(value) <= @(member.type.maximum_size), \ - "The '@(member.name)' array.array() must have a size <= @(member.type.maximum_size)" -@[ end if]@ - self._@(member.name) = value - return -@[ end if]@ -@[ end if]@ -@[ if isinstance(type_, NamespacedType)]@ -@[ if ( - type_.name.endswith(ACTION_GOAL_SUFFIX) or - type_.name.endswith(ACTION_RESULT_SUFFIX) or - type_.name.endswith(ACTION_FEEDBACK_SUFFIX) - )]@ - from @('.'.join(type_.namespaces))._@(convert_camel_case_to_lower_case_underscore(type_.name.rsplit('_', 1)[0])) import @(type_.name) -@[ else]@ - from @('.'.join(type_.namespaces)) import @(type_.name) -@[ end if]@ -@[ end if]@ -@[ if isinstance(member.type, AbstractNestedType)]@ - from collections.abc import Sequence - from collections import UserString -@[ elif isinstance(type_, AbstractGenericString) and type_.has_maximum_size()]@ - from collections import UserString -@[ elif isinstance(type_, BasicType) and type_.typename in CHARACTER_TYPES]@ - from collections import UserString -@[ end if]@ - assert \ -@[ if isinstance(member.type, AbstractNestedType)]@ - ((isinstance(value, Sequence) or - isinstance(value, Set)) and - not isinstance(value, str) and - not isinstance(value, UserString) and -@{assert_msg_suffixes = ['sequence']}@ -@[ if isinstance(type_, AbstractGenericString) and type_.has_maximum_size()]@ - all(len(val) <= @(type_.maximum_size) for val in value) and -@{assert_msg_suffixes.append('and each string value not longer than %d' % type_.maximum_size)}@ -@[ end if]@ -@[ if isinstance(member.type, (Array, BoundedSequence))]@ -@[ if isinstance(member.type, BoundedSequence)]@ - len(value) <= @(member.type.maximum_size) and -@{assert_msg_suffixes.insert(1, 'with length <= %d' % member.type.maximum_size)}@ -@[ else]@ - len(value) == @(member.type.size) and -@{assert_msg_suffixes.insert(1, 'with length %d' % member.type.size)}@ -@[ end if]@ -@[ end if]@ - all(isinstance(v, @(get_python_type(type_))) for v in value) and -@{assert_msg_suffixes.append("and each value of type '%s'" % get_python_type(type_))}@ -@[ if isinstance(type_, BasicType) and type_.typename in SIGNED_INTEGER_TYPES]@ -@{ -nbits = int(type_.typename[3:]) -bound = 2**(nbits - 1) -}@ - all(val >= -@(bound) and val < @(bound) for val in value)), \ -@{assert_msg_suffixes.append('and each integer in [%d, %d]' % (-bound, bound - 1))}@ -@[ elif isinstance(type_, BasicType) and type_.typename in UNSIGNED_INTEGER_TYPES]@ -@{ -nbits = int(type_.typename[4:]) -bound = 2**nbits -}@ - all(val >= 0 and val < @(bound) for val in value)), \ -@{assert_msg_suffixes.append('and each unsigned integer in [0, %d]' % (bound - 1))}@ -@[ elif isinstance(type_, BasicType) and type_.typename == 'char']@ - all(ord(val) >= 0 and ord(val) < 256 for val in value)), \ -@{assert_msg_suffixes.append('and each char in [0, 255]')}@ -@[ elif isinstance(type_, BasicType) and type_.typename in FLOATING_POINT_TYPES]@ -@[ if type_.typename == "float"]@ -@{ -name = "float" -bound = 3.402823466e+38 -}@ - all(not (val < -@(bound) or val > @(bound)) or math.isinf(val) for val in value)), \ -@{assert_msg_suffixes.append('and each float in [%f, %f]' % (-bound, bound))}@ -@[ elif type_.typename == "double"]@ -@{ -name = "double" -bound = 1.7976931348623157e+308 -}@ - all(not (val < -@(bound) or val > @(bound)) or math.isinf(val) for val in value)), \ -@{assert_msg_suffixes.append('and each double in [%f, %f]' % (-bound, bound))}@ -@[ end if]@ -@[ else]@ - True), \ -@[ end if]@ - "The '@(member.name)' field must be @(' '.join(assert_msg_suffixes))" -@[ elif isinstance(member.type, AbstractGenericString) and member.type.has_maximum_size()]@ - (isinstance(value, (str, UserString)) and - len(value) <= @(member.type.maximum_size)), \ - "The '@(member.name)' field must be string value " \ - 'not longer than @(type_.maximum_size)' -@[ elif isinstance(type_, NamespacedType)]@ - isinstance(value, @(type_.name)), \ - "The '@(member.name)' field must be a sub message of type '@(type_.name)'" -@[ elif isinstance(type_, BasicType) and type_.typename == 'octet']@ - (isinstance(value, (bytes, bytearray, memoryview)) and - len(value) == 1), \ - "The '@(member.name)' field must be of type 'bytes' or 'ByteString' with length 1" -@[ elif isinstance(type_, BasicType) and type_.typename == 'char']@ - (isinstance(value, (str, UserString)) and - len(value) == 1 and ord(value) >= -128 and ord(value) < 128), \ - "The '@(member.name)' field must be of type 'str' or 'UserString' " \ - 'with length 1 and the character ord() in [-128, 127]' -@[ elif isinstance(type_, AbstractGenericString)]@ - isinstance(value, str), \ - "The '@(member.name)' field must be of type '@(get_python_type(type_))'" -@[ elif isinstance(type_, BasicType) and type_.typename in (BOOLEAN_TYPE, *FLOATING_POINT_TYPES, *INTEGER_TYPES)]@ - isinstance(value, @(get_python_type(type_))), \ - "The '@(member.name)' field must be of type '@(get_python_type(type_))'" -@[ if type_.typename in SIGNED_INTEGER_TYPES]@ -@{ -nbits = int(type_.typename[3:]) -bound = 2**(nbits - 1) -}@ - assert value >= -@(bound) and value < @(bound), \ - "The '@(member.name)' field must be an integer in [@(-bound), @(bound - 1)]" -@[ elif type_.typename in UNSIGNED_INTEGER_TYPES]@ -@{ -nbits = int(type_.typename[4:]) -bound = 2**nbits -}@ - assert value >= 0 and value < @(bound), \ - "The '@(member.name)' field must be an unsigned integer in [0, @(bound - 1)]" -@[ elif type_.typename in FLOATING_POINT_TYPES]@ -@[ if type_.typename == "float"]@ -@{ -name = "float" -bound = 3.402823466e+38 -}@ -@[ elif type_.typename == "double"]@ -@{ -name = "double" -bound = 1.7976931348623157e+308 -}@ -@[ end if]@ - assert not (value < -@(bound) or value > @(bound)) or math.isinf(value), \ - "The '@(member.name)' field must be a @(name) in [@(-bound), @(bound)]" -@[ end if]@ -@[ else]@ - False -@[ end if]@ +@(generate_check_fields(member)) @[ if isinstance(member.type, AbstractNestedType) and isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in SPECIAL_NESTED_BASIC_TYPES]@ @[ if isinstance(member.type, Array)]@ self._@(member.name) = numpy.array(value, dtype=@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'])) diff --git a/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py b/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py index c0cf7c9b..a1086e0f 100644 --- a/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py +++ b/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py @@ -29,6 +29,10 @@ from rosidl_parser.definition import ACTION_RESULT_SUFFIX from rosidl_parser.definition import Array from rosidl_parser.definition import BasicType +from rosidl_parser.definition import BOOLEAN_TYPE +from rosidl_parser.definition import BoundedSequence +from rosidl_parser.definition import BoundedString +from rosidl_parser.definition import BoundedWString from rosidl_parser.definition import CHARACTER_TYPES from rosidl_parser.definition import Constant from rosidl_parser.definition import FLOATING_POINT_TYPES @@ -39,6 +43,8 @@ from rosidl_parser.definition import Message from rosidl_parser.definition import NamespacedType from rosidl_parser.definition import Service +from rosidl_parser.definition import SIGNED_INTEGER_TYPES +from rosidl_parser.definition import UNSIGNED_INTEGER_TYPES from rosidl_parser.parser import parse_idl_file from rosidl_pycommon import convert_camel_case_to_lower_case_underscore from rosidl_pycommon import expand_template @@ -452,3 +458,237 @@ def get_setter_and_getter_type(member: Member, type_imports: set[str]) -> tuple[ type_annotations_getter = type_annotations_setter return type_annotations_setter, type_annotations_getter + + +class CodeWriter: + + def __init__(self, start_level: int = 0, indent_size: int = 4): + self._indent = ' ' * indent_size + self._level = start_level + self._lines: list[str] = [] + + def write(self, line: str = '') -> None: + self._lines.append(f'{self._indent * self._level}{line}') + + def indent(self) -> None: + self._level += 1 + + def dedent(self) -> None: + assert self._level > 0 + self._level -= 1 + + def get_value(self) -> str: + return '\n'.join(self._lines) + + +def generate_check_fields(member: Member) -> str: + cw = CodeWriter(start_level=2) + cw.write('if self._check_fields:') + cw.indent() + + type_ = member.type + if isinstance(type_, AbstractNestedType): + type_ = type_.value_type + + if ( + isinstance(member.type, AbstractNestedType) and + isinstance(member.type.value_type, BasicType) and + member.type.value_type.typename in SPECIAL_NESTED_BASIC_TYPES + ): + if isinstance(member.type, Array): + cw.write('if isinstance(value, numpy.ndarray):') + cw.indent() + + dtype = SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'] + cw.write(f'assert value.dtype == {dtype}, \\') + cw.indent() + cw.write(f'"The \'{member.name}\' numpy.ndarray() must have the dtype of \'{dtype}\'"') + cw.dedent() + size = member.type.size + cw.write(f'assert value.size == {size}, \\') + cw.indent() + cw.write(f'"The \'{member.name}\' numpy.ndarray() must have a size of {size}"') + cw.dedent() + + elif isinstance(member.type, AbstractSequence): + cw.write('if isinstance(value, array.array):') + cw.indent() + + type_code = SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['type_code'] + cw.write(f"assert value.typecode == '{type_code}', \\") + cw.indent() + cw.write(f'"The \'{member.name}\' array.array() ' + f'must have the type code of \'{type_code}\'"') + cw.dedent() + if isinstance(member.type, BoundedSequence): + max_size = member.type.maximum_size + cw.write(f'assert len(value) <= {max_size}, \\') + cw.indent() + cw.write(f'"The \'{member.name}\' array.array() must have a size <= {max_size}"') + cw.dedent() + + cw.write(f'self._{member.name} = value') + cw.write('return') + cw.dedent() + + if isinstance(type_, NamespacedType): + resolved_namespace = '.'.join(type_.namespaces) + if ( + type_.name.endswith(ACTION_GOAL_SUFFIX) or + type_.name.endswith(ACTION_RESULT_SUFFIX) or + type_.name.endswith(ACTION_FEEDBACK_SUFFIX) + ): + lower_case_action_name = convert_camel_case_to_lower_case_underscore( + type_.name.rsplit('_', 1)[0] + ) + cw.write(f'from {resolved_namespace}._{lower_case_action_name} import {type_.name}') + else: + cw.write(f'from {resolved_namespace} import {type_.name}') + + if isinstance(member.type, AbstractNestedType): + cw.write('from collections.abc import Sequence') + cw.write('from collections import UserString') + elif isinstance(type_, AbstractGenericString) and type_.has_maximum_size(): + cw.write('from collections import UserString') + elif isinstance(type_, BasicType) and type_.typename in CHARACTER_TYPES: + cw.write('from collections import UserString') + + cw.write('assert \\') + cw.indent() + if isinstance(member.type, AbstractNestedType): + cw.write('((isinstance(value, Sequence) or') + cw.write(' isinstance(value, Set)) and') + cw.write(' not isinstance(value, str) and') + cw.write(' not isinstance(value, UserString) and') + + assert_msg_suffixes: list[str] = ['sequence'] + + if isinstance(type_, (BoundedString, BoundedWString)) and type_.has_maximum_size(): + max_size_str = type_.maximum_size + cw.write(f' all(len(val) <= {max_size_str} for val in value) and') + assert_msg_suffixes.append(f'and each string value not longer than {max_size_str}') + + if isinstance(member.type, (Array, BoundedSequence)): + if isinstance(member.type, BoundedSequence): + cw.write(f' len(value) <= {member.type.maximum_size} and') + assert_msg_suffixes.insert(1, f'with length <= {member.type.maximum_size}') + else: + cw.write(f' len(value) == {member.type.size} and') + assert_msg_suffixes.insert(1, f'with length {member.type.size}') + + cw.write(f' all(isinstance(v, {get_python_type(type_)}) for v in value) and') + assert_msg_suffixes.append(f"and each value of type '{get_python_type(type_)}'") + + if ( + isinstance(type_, BasicType) and + type_.typename in (*SIGNED_INTEGER_TYPES, *UNSIGNED_INTEGER_TYPES) + ): + if type_.typename in SIGNED_INTEGER_TYPES: + name = 'integer' + nbits = int(type_.typename[3:]) + bound = 2**(nbits - 1) + lower_bound = -bound + elif type_.typename in UNSIGNED_INTEGER_TYPES: + name = 'unsigned integer' + nbits = int(type_.typename[4:]) + bound = 2**nbits + lower_bound = 0 + + cw.write(f' all(val >= {lower_bound} and val < {bound} for val in value)), \\') + assert_msg_suffixes.append(f'and each {name} in [{lower_bound}, {(bound - 1)}]') + elif isinstance(type_, BasicType) and type_.typename == 'char': + cw.write(' all(ord(val) >= 0 and ord(val) < 256 for val in value)), \\') + assert_msg_suffixes.append('and each char in [0, 255]') + elif isinstance(type_, BasicType) and type_.typename in FLOATING_POINT_TYPES: + if type_.typename == 'float': + name = 'float' + bound_str = '3.402823466e+38' + bound = '{0:.6f}'.format(float(bound_str)) + elif type_.typename == 'double': + name = 'double' + bound_str = '1.7976931348623157e+308' + bound = '{0:.6f}'.format(float(bound_str)) + elif type_.typename == 'long double': + name = 'long double' + bound_str = '1.189731495357231765e+4932' + bound = '{0:.6f}'.format(float(bound_str)) + + isinf_str = 'math.isinf(val) for val in value)' + cw.write(f' all(not (val < -{bound_str} or val > {bound_str}) or {isinf_str}), \\') + assert_msg_suffixes.append(f'and each {name} in [-{bound}, {bound}]') + + else: + cw.write(' True), \\') + + joined_assert_msg_suffixes = ' '.join(assert_msg_suffixes) + cw.write(f'"The \'{member.name}\' field must be {joined_assert_msg_suffixes}"') + + elif ( + isinstance(member.type, (BoundedString, BoundedWString)) and + member.type.has_maximum_size() + ): + cw.write('(isinstance(value, (str, UserString)) and') + cw.write(f' len(value) <= {member.type.maximum_size}), \\') + cw.write(f'"The \'{member.name}\' field must be string value " \\') + assert isinstance(type_, (BoundedString, BoundedWString)) + cw.write(f"'not longer than {type_.maximum_size}'") + elif isinstance(type_, NamespacedType): + cw.write(f'isinstance(value, {type_.name}), \\') + cw.write(f'"The \'{member.name}\' field must be a sub message of type \'{type_.name}\'"') + elif isinstance(type_, BasicType) and type_.typename == 'octet': + cw.write('(isinstance(value, (bytes, bytearray, memoryview)) and') + cw.write(' len(value) == 1), \\') + byte_alias = "'bytes' or 'ByteString'" + cw.write(f'"The \'{member.name}\' field must be of type {byte_alias} with length 1"') + elif isinstance(type_, BasicType) and type_.typename == 'char': + cw.write('(isinstance(value, (str, UserString)) and') + cw.write(' len(value) == 1 and ord(value) >= -128 and ord(value) < 128), \\') + cw.write(f'"The \'{member.name}\' field must be of type \'str\' or \'UserString\' " \\') + cw.write("'with length 1 and the character ord() in [-128, 127]'") + elif isinstance(type_, AbstractGenericString): + cw.write('isinstance(value, str), \\') + cw.write(f'"The \'{member.name}\' field must be of type \'{get_python_type(type_)}\'"') + elif ( + isinstance(type_, BasicType) and + type_.typename in (BOOLEAN_TYPE, *FLOATING_POINT_TYPES, *INTEGER_TYPES) + ): + cw.write(f'isinstance(value, {get_python_type(type_)}), \\') + cw.write(f'"The \'{member.name}\' field must be of type \'{get_python_type(type_)}\'"') + if type_.typename in (*SIGNED_INTEGER_TYPES, *UNSIGNED_INTEGER_TYPES): + if type_.typename in SIGNED_INTEGER_TYPES: + name = 'integer' + nbits = int(type_.typename[3:]) + bound = 2**(nbits - 1) + lower_bound = -bound + elif type_.typename in UNSIGNED_INTEGER_TYPES: + name = 'unsigned integer' + nbits = int(type_.typename[4:]) + bound = 2**nbits + lower_bound = 0 + + cw.dedent() + cw.write(f'assert value >= {lower_bound} and value < {bound}, \\') + cw.indent() + bounds_str = f'[{lower_bound}, {bound - 1}]' + cw.write(f'"The \'{member.name}\' field must be an {name} in {bounds_str}"') + + elif type_.typename in FLOATING_POINT_TYPES: + if type_.typename == 'float': + name = 'float' + bound = 3.402823466e+38 + elif type_.typename == 'double': + name = 'double' + bound = 1.7976931348623157e+308 + elif type_.typename == 'long double': + name = 'long double' + bound = 1.189731495357231765e+4932 + + cw.dedent() + inf_check = ' or math.isinf(value)' + cw.write(f'assert not (value < -{str(bound)} or value > {str(bound)}){inf_check}, \\') + cw.indent() + cw.write(f'"The \'{member.name}\' field must be a {name} in [-{bound}, {bound}]"') + else: + cw.write('False') + + return cw.get_value()