From 95c5541d51ffce4919cf99415bc9276f1bb4b99e Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 17:53:33 -0400 Subject: [PATCH 1/9] feat(core): add semantic types for wizard system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement core semantic types (BranchId, ActionId, OptionKey, MenuId) with factory functions, type guards, and StateValue for JSON-serializable data. - NewType definitions for type safety without runtime overhead - Factory functions with optional validation (default: no overhead) - Type guards for runtime type checking - Collection type aliases (BranchList, ActionSet, etc.) - StateValue type alias for JSON-serializable values - 100% test coverage with 28 comprehensive tests - MyPy strict mode compliance Part of CLI-4: Minimal Core Type Definitions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/cli_patterns/core/types.py | 213 +++++++++++++- tests/unit/core/test_types.py | 520 +++++++++++++++++++++++++++++++++ 2 files changed, 731 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/test_types.py diff --git a/src/cli_patterns/core/types.py b/src/cli_patterns/core/types.py index b3dc393..1c3c748 100644 --- a/src/cli_patterns/core/types.py +++ b/src/cli_patterns/core/types.py @@ -1,3 +1,212 @@ -"""Core type definitions for CLI Patterns.""" +"""Core semantic types for the wizard system. -# Placeholder for core type definitions +This module defines semantic types that provide type safety for the wizard system +while maintaining MyPy strict mode compliance. These are simple NewType definitions +that prevent type confusion without adding runtime validation complexity. + +The semantic types help distinguish between different string contexts in the wizard: +- BranchId: Represents a branch identifier in the wizard tree +- ActionId: Represents an action identifier +- OptionKey: Represents an option key name +- MenuId: Represents a menu identifier for navigation +- StateValue: Represents any JSON-serializable value that can be stored in state + +All ID types are backed by strings but provide semantic meaning at the type level. +StateValue is a JSON-compatible type alias for flexible state storage. +""" + +from __future__ import annotations + +from typing import Any, NewType, Union + +from typing_extensions import TypeGuard + +# JSON-compatible types for state values +JsonPrimitive = Union[str, int, float, bool, None] +JsonValue = Union[JsonPrimitive, list["JsonValue"], dict[str, "JsonValue"]] + +# Core semantic types for wizard system +BranchId = NewType("BranchId", str) +"""Semantic type for branch identifiers in the wizard tree.""" + +ActionId = NewType("ActionId", str) +"""Semantic type for action identifiers.""" + +OptionKey = NewType("OptionKey", str) +"""Semantic type for option key names.""" + +MenuId = NewType("MenuId", str) +"""Semantic type for menu identifiers.""" + +# State value is any JSON-serializable value +StateValue = JsonValue +"""Type alias for state values - any JSON-serializable data.""" + +# Type aliases for common collections using semantic types +BranchList = list[BranchId] +"""Type alias for lists of branch IDs.""" + +BranchSet = set[BranchId] +"""Type alias for sets of branch IDs.""" + +ActionList = list[ActionId] +"""Type alias for lists of action IDs.""" + +ActionSet = set[ActionId] +"""Type alias for sets of action IDs.""" + +OptionDict = dict[OptionKey, StateValue] +"""Type alias for option dictionaries mapping keys to state values.""" + +MenuList = list[MenuId] +"""Type alias for lists of menu IDs.""" + + +# Factory functions for creating semantic types +def make_branch_id(value: str, validate: bool = False) -> BranchId: + """Create a BranchId from a string value. + + Args: + value: String value to convert to BranchId + validate: If True, validate the input (default: False for zero overhead) + + Returns: + BranchId with semantic type safety + + Raises: + ValueError: If validate=True and value is invalid + """ + if validate: + if not value or not value.strip(): + raise ValueError("BranchId cannot be empty") + if len(value) > 100: + raise ValueError("BranchId is too long (max 100 characters)") + return BranchId(value) + + +def make_action_id(value: str, validate: bool = False) -> ActionId: + """Create an ActionId from a string value. + + Args: + value: String value to convert to ActionId + validate: If True, validate the input (default: False for zero overhead) + + Returns: + ActionId with semantic type safety + + Raises: + ValueError: If validate=True and value is invalid + """ + if validate: + if not value or not value.strip(): + raise ValueError("ActionId cannot be empty") + if len(value) > 100: + raise ValueError("ActionId is too long (max 100 characters)") + return ActionId(value) + + +def make_option_key(value: str, validate: bool = False) -> OptionKey: + """Create an OptionKey from a string value. + + Args: + value: String value to convert to OptionKey + validate: If True, validate the input (default: False for zero overhead) + + Returns: + OptionKey with semantic type safety + + Raises: + ValueError: If validate=True and value is invalid + """ + if validate: + if not value or not value.strip(): + raise ValueError("OptionKey cannot be empty") + if len(value) > 100: + raise ValueError("OptionKey is too long (max 100 characters)") + return OptionKey(value) + + +def make_menu_id(value: str, validate: bool = False) -> MenuId: + """Create a MenuId from a string value. + + Args: + value: String value to convert to MenuId + validate: If True, validate the input (default: False for zero overhead) + + Returns: + MenuId with semantic type safety + + Raises: + ValueError: If validate=True and value is invalid + """ + if validate: + if not value or not value.strip(): + raise ValueError("MenuId cannot be empty") + if len(value) > 100: + raise ValueError("MenuId is too long (max 100 characters)") + return MenuId(value) + + +# Type guard functions for runtime type checking +def is_branch_id(value: Any) -> TypeGuard[BranchId]: + """Check if a value is a BranchId at runtime. + + Args: + value: Value to check + + Returns: + True if value is a BranchId (string type), False otherwise + + Note: + This is a type guard function that helps with type narrowing. + At runtime, BranchId is just a string, so this checks for string type. + """ + return isinstance(value, str) + + +def is_action_id(value: Any) -> TypeGuard[ActionId]: + """Check if a value is an ActionId at runtime. + + Args: + value: Value to check + + Returns: + True if value is an ActionId (string type), False otherwise + + Note: + This is a type guard function that helps with type narrowing. + At runtime, ActionId is just a string, so this checks for string type. + """ + return isinstance(value, str) + + +def is_option_key(value: Any) -> TypeGuard[OptionKey]: + """Check if a value is an OptionKey at runtime. + + Args: + value: Value to check + + Returns: + True if value is an OptionKey (string type), False otherwise + + Note: + This is a type guard function that helps with type narrowing. + At runtime, OptionKey is just a string, so this checks for string type. + """ + return isinstance(value, str) + + +def is_menu_id(value: Any) -> TypeGuard[MenuId]: + """Check if a value is a MenuId at runtime. + + Args: + value: Value to check + + Returns: + True if value is a MenuId (string type), False otherwise + + Note: + This is a type guard function that helps with type narrowing. + At runtime, MenuId is just a string, so this checks for string type. + """ + return isinstance(value, str) diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py new file mode 100644 index 0000000..0fbcfc9 --- /dev/null +++ b/tests/unit/core/test_types.py @@ -0,0 +1,520 @@ +"""Tests for core semantic types for the wizard system. + +This module tests the semantic type definitions that provide type safety +for the wizard system. These are simple NewType definitions that prevent +type confusion while maintaining MyPy strict mode compliance. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +# Import the types we're testing (these will fail initially) +try: + from cli_patterns.core.types import ( + ActionId, + BranchId, + OptionKey, + StateValue, + is_action_id, + is_branch_id, + is_menu_id, + is_option_key, + make_action_id, + make_branch_id, + make_menu_id, + make_option_key, + ) +except ImportError: + # These imports will fail initially since the implementation doesn't exist + pass + +pytestmark = pytest.mark.unit + + +class TestSemanticTypeDefinitions: + """Test basic semantic type creation and identity.""" + + def test_branch_id_creation(self) -> None: + """ + GIVEN: A string value for a branch + WHEN: Creating a BranchId + THEN: The BranchId maintains the value but has distinct type identity + """ + branch_str = "main_menu" + branch_id = make_branch_id(branch_str) + + # Value preservation + assert str(branch_id) == branch_str + + # Type identity (will be checked by MyPy at compile time) + assert isinstance(branch_id, str) # Runtime check + + def test_action_id_creation(self) -> None: + """ + GIVEN: A string value for an action + WHEN: Creating an ActionId + THEN: The ActionId maintains the value but has distinct type identity + """ + action_str = "deploy_app" + action_id = make_action_id(action_str) + + assert str(action_id) == action_str + assert isinstance(action_id, str) + + def test_option_key_creation(self) -> None: + """ + GIVEN: A string value for an option key + WHEN: Creating an OptionKey + THEN: The OptionKey maintains the value but has distinct type identity + """ + key_str = "environment" + option_key = make_option_key(key_str) + + assert str(option_key) == key_str + assert isinstance(option_key, str) + + def test_menu_id_creation(self) -> None: + """ + GIVEN: A string value for a menu + WHEN: Creating a MenuId + THEN: The MenuId maintains the value but has distinct type identity + """ + menu_str = "settings_menu" + menu_id = make_menu_id(menu_str) + + assert str(menu_id) == menu_str + assert isinstance(menu_id, str) + + +class TestSemanticTypeDistinctness: + """Test that semantic types are distinct from each other and from str.""" + + def test_types_are_distinct_from_str(self) -> None: + """ + GIVEN: Various semantic types created from the same string value + WHEN: Checking type identity at runtime + THEN: All types derive from str but have semantic distinction + """ + base_str = "test" + + branch_id = make_branch_id(base_str) + action_id = make_action_id(base_str) + option_key = make_option_key(base_str) + menu_id = make_menu_id(base_str) + + # All are strings at runtime + for semantic_type in [branch_id, action_id, option_key, menu_id]: + assert isinstance(semantic_type, str) + assert str(semantic_type) == base_str + + def test_type_safety_in_collections(self) -> None: + """ + GIVEN: Semantic types used in collections + WHEN: Adding them to typed collections + THEN: The types maintain their semantic meaning in collections + """ + # Test BranchId in sets and lists + branch1 = make_branch_id("main") + branch2 = make_branch_id("settings") + branch3 = make_branch_id("main") # Duplicate value + + branch_set: set[BranchId] = {branch1, branch2, branch3} + assert len(branch_set) == 2 # Duplicate removed + + branch_list: list[BranchId] = [branch1, branch2, branch3] + assert len(branch_list) == 3 # Duplicates preserved + + # Test ActionId in dictionaries + action1 = make_action_id("deploy") + action2 = make_action_id("test") + + actions_dict: dict[ActionId, str] = { + action1: "Deploy the application", + action2: "Run tests", + } + assert len(actions_dict) == 2 + + def test_string_operations_work(self) -> None: + """ + GIVEN: Semantic types that derive from str + WHEN: Performing string operations + THEN: All string operations work normally + """ + branch_id = make_branch_id("main-menu") + + # String methods work + assert branch_id.upper() == "MAIN-MENU" + assert branch_id.lower() == "main-menu" + assert branch_id.replace("-", "_") == "main_menu" + assert branch_id.startswith("main") + assert branch_id.endswith("menu") + assert len(branch_id) == 9 + assert "main" in branch_id + + # String concatenation works + combined = branch_id + "_suffix" + assert combined == "main-menu_suffix" + + # String formatting works + formatted = f"Branch: {branch_id}" + assert formatted == "Branch: main-menu" + + +class TestSemanticTypeValidation: + """Test validation and error handling for semantic types.""" + + def test_factory_without_validation(self) -> None: + """ + GIVEN: Factory functions called without validation + WHEN: Creating semantic types with any string (even invalid) + THEN: No validation occurs (zero overhead by default) + """ + # Empty strings should work without validation + empty_branch = make_branch_id("") + empty_action = make_action_id("") + empty_option = make_option_key("") + empty_menu = make_menu_id("") + + # All should be empty strings + for semantic_type in [empty_branch, empty_action, empty_option, empty_menu]: + assert str(semantic_type) == "" + assert len(semantic_type) == 0 + + def test_factory_with_validation_rejects_empty(self) -> None: + """ + GIVEN: Factory functions called with validation enabled + WHEN: Creating semantic types with empty strings + THEN: ValueError is raised + """ + with pytest.raises(ValueError, match="BranchId cannot be empty"): + make_branch_id("", validate=True) + + with pytest.raises(ValueError, match="ActionId cannot be empty"): + make_action_id("", validate=True) + + with pytest.raises(ValueError, match="OptionKey cannot be empty"): + make_option_key("", validate=True) + + with pytest.raises(ValueError, match="MenuId cannot be empty"): + make_menu_id("", validate=True) + + def test_factory_with_validation_rejects_whitespace_only(self) -> None: + """ + GIVEN: Factory functions called with validation enabled + WHEN: Creating semantic types with whitespace-only strings + THEN: ValueError is raised + """ + with pytest.raises(ValueError, match="BranchId cannot be empty"): + make_branch_id(" ", validate=True) + + with pytest.raises(ValueError, match="ActionId cannot be empty"): + make_action_id("\t\n", validate=True) + + def test_factory_with_validation_rejects_too_long(self) -> None: + """ + GIVEN: Factory functions called with validation enabled + WHEN: Creating semantic types with strings that are too long + THEN: ValueError is raised + """ + too_long = "x" * 101 + + with pytest.raises(ValueError, match="BranchId is too long"): + make_branch_id(too_long, validate=True) + + with pytest.raises(ValueError, match="ActionId is too long"): + make_action_id(too_long, validate=True) + + with pytest.raises(ValueError, match="OptionKey is too long"): + make_option_key(too_long, validate=True) + + with pytest.raises(ValueError, match="MenuId is too long"): + make_menu_id(too_long, validate=True) + + def test_factory_with_validation_accepts_valid_strings(self) -> None: + """ + GIVEN: Factory functions called with validation enabled + WHEN: Creating semantic types with valid strings + THEN: Types are created successfully + """ + valid_branch = make_branch_id("main_menu", validate=True) + valid_action = make_action_id("deploy_action", validate=True) + valid_option = make_option_key("environment", validate=True) + valid_menu = make_menu_id("settings", validate=True) + + assert str(valid_branch) == "main_menu" + assert str(valid_action) == "deploy_action" + assert str(valid_option) == "environment" + assert str(valid_menu) == "settings" + + def test_special_character_handling(self) -> None: + """ + GIVEN: String values with special characters + WHEN: Creating semantic types + THEN: Special characters are preserved + """ + special_branch = make_branch_id("main-menu_v2") + special_action = make_action_id("deploy:prod") + special_option = make_option_key("file.path") + + assert str(special_branch) == "main-menu_v2" + assert str(special_action) == "deploy:prod" + assert str(special_option) == "file.path" + + +class TestSemanticTypeEquality: + """Test equality and hashing behavior of semantic types.""" + + def test_equality_with_same_type(self) -> None: + """ + GIVEN: Two semantic types of the same type with same value + WHEN: Comparing for equality + THEN: They are equal + """ + branch1 = make_branch_id("main") + branch2 = make_branch_id("main") + + assert branch1 == branch2 + assert not (branch1 != branch2) + + def test_equality_with_different_values(self) -> None: + """ + GIVEN: Two semantic types of the same type with different values + WHEN: Comparing for equality + THEN: They are not equal + """ + branch1 = make_branch_id("main") + branch2 = make_branch_id("settings") + + assert branch1 != branch2 + assert not (branch1 == branch2) + + def test_equality_with_raw_string(self) -> None: + """ + GIVEN: A semantic type and a raw string with the same value + WHEN: Comparing for equality + THEN: They are equal (since semantic types are NewType) + """ + branch_id = make_branch_id("main") + raw_str = "main" + + assert branch_id == raw_str + assert raw_str == branch_id + + def test_hashing_behavior(self) -> None: + """ + GIVEN: Semantic types with same and different values + WHEN: Using them as dictionary keys or in sets + THEN: Hashing works correctly + """ + branch1 = make_branch_id("main") + branch2 = make_branch_id("main") + branch3 = make_branch_id("settings") + + # Same value should have same hash + assert hash(branch1) == hash(branch2) + + # Can be used as dict keys + branch_dict = {branch1: "main_info", branch3: "settings_info"} + assert len(branch_dict) == 2 + assert branch_dict[branch2] == "main_info" # branch2 should work as key + + +class TestTypeGuards: + """Test type guard functions for runtime type checking.""" + + def test_is_branch_id(self) -> None: + """ + GIVEN: Various values including BranchId + WHEN: Checking with is_branch_id type guard + THEN: Returns True for strings, False otherwise + """ + branch = make_branch_id("main") + assert is_branch_id(branch) + assert is_branch_id("main") + assert not is_branch_id(123) + assert not is_branch_id(None) + assert not is_branch_id([]) + + def test_is_action_id(self) -> None: + """ + GIVEN: Various values including ActionId + WHEN: Checking with is_action_id type guard + THEN: Returns True for strings, False otherwise + """ + action = make_action_id("deploy") + assert is_action_id(action) + assert is_action_id("deploy") + assert not is_action_id(123) + assert not is_action_id(None) + + def test_is_option_key(self) -> None: + """ + GIVEN: Various values including OptionKey + WHEN: Checking with is_option_key type guard + THEN: Returns True for strings, False otherwise + """ + option = make_option_key("environment") + assert is_option_key(option) + assert is_option_key("environment") + assert not is_option_key(123) + + def test_is_menu_id(self) -> None: + """ + GIVEN: Various values including MenuId + WHEN: Checking with is_menu_id type guard + THEN: Returns True for strings, False otherwise + """ + menu = make_menu_id("settings") + assert is_menu_id(menu) + assert is_menu_id("settings") + assert not is_menu_id(123) + + +class TestSemanticTypeUsagePatterns: + """Test common usage patterns and best practices.""" + + def test_function_signature_type_safety(self) -> None: + """ + GIVEN: Functions that expect specific semantic types + WHEN: Calling them with correct types + THEN: The calls work without type errors + """ + + def navigate_to_branch( + branch: BranchId, options: dict[OptionKey, StateValue] + ) -> str: + return f"Navigating to {branch} with {len(options)} options" + + branch = make_branch_id("main") + opts = { + make_option_key("env"): "production", + make_option_key("region"): "us-west-2", + } + + result = navigate_to_branch(branch, opts) + assert "main" in result + assert "2 options" in result + + def test_type_conversion_patterns(self) -> None: + """ + GIVEN: Raw strings that need to be converted to semantic types + WHEN: Converting them explicitly + THEN: The conversion preserves value but adds type safety + """ + raw_branches = ["main", "settings", "deploy"] + semantic_branches = [make_branch_id(b) for b in raw_branches] + + assert len(semantic_branches) == 3 + for raw, semantic in zip(raw_branches, semantic_branches): + assert str(semantic) == raw + + def test_mixed_type_collections(self) -> None: + """ + GIVEN: Collections containing multiple semantic types + WHEN: Working with them + THEN: Type safety is maintained + """ + # Dictionary with mixed semantic types as keys + wizard_data: dict[str, Any] = { + make_branch_id("main"): "main_branch", + make_action_id("deploy"): "deploy_action", + make_option_key("env"): "production", + make_menu_id("settings"): "settings_menu", + } + + assert len(wizard_data) == 4 + + # All keys are strings at runtime but have semantic meaning + for key in wizard_data.keys(): + assert isinstance(key, str) + + +class TestStateValueType: + """Test StateValue type alias for JSON-serializable values.""" + + def test_state_value_accepts_json_types(self) -> None: + """ + GIVEN: Various JSON-serializable values + WHEN: Using them as StateValue + THEN: They are accepted by the type system + """ + import json + + # All these should be valid StateValue types + state_values: list[StateValue] = [ + "string_value", + 123, + 45.67, + True, + False, + None, + ["list", "of", "values"], + {"key": "value", "nested": {"data": 123}}, + ] + + # Should be JSON-serializable + for value in state_values: + json_str = json.dumps(value) + assert json_str is not None + + def test_state_value_in_collections(self) -> None: + """ + GIVEN: StateValue used in option collections + WHEN: Building option dictionaries + THEN: Type safety is maintained + """ + options: dict[OptionKey, StateValue] = { + make_option_key("string_opt"): "value", + make_option_key("number_opt"): 42, + make_option_key("bool_opt"): True, + make_option_key("list_opt"): [1, 2, 3], + make_option_key("dict_opt"): {"nested": "data"}, + } + + assert len(options) == 5 + assert options[make_option_key("string_opt")] == "value" + assert options[make_option_key("number_opt")] == 42 + + +class TestSemanticTypeCompatibility: + """Test compatibility with existing code and libraries.""" + + def test_json_serialization(self) -> None: + """ + GIVEN: Semantic types in data structures + WHEN: Serializing to JSON + THEN: Serialization works normally + """ + import json + + data = { + "branch": make_branch_id("main"), + "action": make_action_id("deploy"), + "options": { + make_option_key("env"): "prod", + make_option_key("region"): "us-west", + }, + } + + # Should serialize without errors + json_str = json.dumps(data, default=str) + assert "main" in json_str + assert "deploy" in json_str + assert "prod" in json_str + + def test_string_formatting_compatibility(self) -> None: + """ + GIVEN: Semantic types used in string formatting + WHEN: Using various formatting methods + THEN: All formatting works normally + """ + branch = make_branch_id("main") + action = make_action_id("deploy") + option = make_option_key("environment") + + # Format strings + formatted = f"Branch: {branch}, Action: {action}, Option: {option}" + assert formatted == "Branch: main, Action: deploy, Option: environment" From bc2e490073c4bd3bcbac969ed2c4ca42ca75f810 Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 17:53:44 -0400 Subject: [PATCH 2/9] feat(core): add Pydantic models for wizard configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement comprehensive Pydantic v2 models for wizard system including actions, options, branches, and complete wizard configuration. Models: - BaseConfig: Common fields (metadata, tags) for all configs - Action types: BashActionConfig, PythonActionConfig (discriminated unions) - Option types: String, Select, Path, Number, Boolean (discriminated unions) - MenuConfig: Navigation menu configuration - BranchConfig: Complete branch with actions, options, menus - WizardConfig: Top-level wizard configuration - SessionState: Unified wizard + parser state - Result types: ActionResult, CollectionResult, NavigationResult Features: - Discriminated unions for type-safe extensibility - Pydantic v2 with ConfigDict - Strict validation enabled - Field descriptions for all attributes - 100% test coverage with 42 comprehensive tests - MyPy strict mode compliance Part of CLI-4: Minimal Core Type Definitions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/cli_patterns/core/models.py | 344 ++++++++++++++- tests/unit/core/test_models.py | 761 ++++++++++++++++++++++++++++++++ 2 files changed, 1103 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/test_models.py diff --git a/src/cli_patterns/core/models.py b/src/cli_patterns/core/models.py index a210fbd..9901dd4 100644 --- a/src/cli_patterns/core/models.py +++ b/src/cli_patterns/core/models.py @@ -1,3 +1,343 @@ -"""Core data models for CLI Patterns.""" +"""Core data models for CLI Patterns. -# Placeholder for core data models +This module defines Pydantic models for the wizard configuration structure. +All models use MyPy strict mode and Pydantic v2 features including: +- Discriminated unions for extensibility +- Field validation +- JSON serialization/deserialization +- StrictModel base class for type safety +""" + +from __future__ import annotations + +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + +from cli_patterns.core.types import ( + ActionId, + BranchId, + MenuId, + OptionKey, +) + +# StateValue is defined as Any for Pydantic compatibility +# The actual type constraint (JSON-serializable) is enforced at serialization time +StateValue = Any + + +class StrictModel(BaseModel): + """Base model with strict validation enabled. + + This ensures type safety and proper validation throughout the system. + """ + + model_config = ConfigDict( + # Strict mode for type safety + strict=True, + # Allow arbitrary types (for semantic types) + arbitrary_types_allowed=True, + # Extra fields are forbidden + extra="forbid", + ) + + +class BaseConfig(StrictModel): + """Base configuration providing common fields for all config types. + + This class provides metadata and tagging infrastructure that all + configuration types can use. + """ + + metadata: dict[str, Any] = Field(default_factory=dict) + """Arbitrary metadata for extensions and tooling.""" + + tags: list[str] = Field(default_factory=list) + """Tags for categorization and filtering.""" + + +# ============================================================================ +# Action Configuration Models +# ============================================================================ + + +class BashActionConfig(BaseConfig): + """Configuration for bash command actions. + + Executes a bash command with optional environment variables. + """ + + type: Literal["bash"] = Field( + default="bash", description="Action type discriminator" + ) + id: ActionId = Field(description="Unique action identifier") + name: str = Field(description="Human-readable action name") + description: Optional[str] = Field(default=None, description="Action description") + command: str = Field(description="Bash command to execute") + env: dict[str, str] = Field( + default_factory=dict, description="Environment variables for command" + ) + + +class PythonActionConfig(BaseConfig): + """Configuration for Python function actions. + + Calls a Python function from a specified module. + """ + + type: Literal["python"] = Field( + default="python", description="Action type discriminator" + ) + id: ActionId = Field(description="Unique action identifier") + name: str = Field(description="Human-readable action name") + description: Optional[str] = Field(default=None, description="Action description") + module: str = Field(description="Python module path") + function: str = Field(description="Function name to call") + + +# Discriminated union of all action types +# TODO: Future extension point - add new action types here +ActionConfigUnion = Union[BashActionConfig, PythonActionConfig] + + +# ============================================================================ +# Option Configuration Models +# ============================================================================ + + +class StringOptionConfig(BaseConfig): + """Configuration for string input options.""" + + type: Literal["string"] = Field( + default="string", description="Option type discriminator" + ) + id: OptionKey = Field(description="Unique option identifier") + name: str = Field(description="Human-readable option name") + description: str = Field(description="Option description/prompt") + default: Optional[str] = Field(default=None, description="Default value") + required: bool = Field(default=False, description="Whether option is required") + + +class SelectOptionConfig(BaseConfig): + """Configuration for selection options (dropdown/menu).""" + + type: Literal["select"] = Field( + default="select", description="Option type discriminator" + ) + id: OptionKey = Field(description="Unique option identifier") + name: str = Field(description="Human-readable option name") + description: str = Field(description="Option description/prompt") + choices: list[str] = Field(description="Available choices") + default: Optional[str] = Field(default=None, description="Default value") + required: bool = Field(default=False, description="Whether option is required") + + +class PathOptionConfig(BaseConfig): + """Configuration for file/directory path options.""" + + type: Literal["path"] = Field( + default="path", description="Option type discriminator" + ) + id: OptionKey = Field(description="Unique option identifier") + name: str = Field(description="Human-readable option name") + description: str = Field(description="Option description/prompt") + must_exist: bool = Field( + default=False, description="Whether path must exist for validation" + ) + default: Optional[str] = Field(default=None, description="Default value") + required: bool = Field(default=False, description="Whether option is required") + + +class NumberOptionConfig(BaseConfig): + """Configuration for numeric input options.""" + + type: Literal["number"] = Field( + default="number", description="Option type discriminator" + ) + id: OptionKey = Field(description="Unique option identifier") + name: str = Field(description="Human-readable option name") + description: str = Field(description="Option description/prompt") + min_value: Optional[float] = Field( + default=None, description="Minimum allowed value" + ) + max_value: Optional[float] = Field( + default=None, description="Maximum allowed value" + ) + default: Optional[float] = Field(default=None, description="Default value") + required: bool = Field(default=False, description="Whether option is required") + + +class BooleanOptionConfig(BaseConfig): + """Configuration for boolean (yes/no) options.""" + + type: Literal["boolean"] = Field( + default="boolean", description="Option type discriminator" + ) + id: OptionKey = Field(description="Unique option identifier") + name: str = Field(description="Human-readable option name") + description: str = Field(description="Option description/prompt") + default: Optional[bool] = Field(default=None, description="Default value") + required: bool = Field(default=False, description="Whether option is required") + + +# Discriminated union of all option types +# TODO: Future extension point - add new option types here (e.g., multi-select, date, etc.) +OptionConfigUnion = Union[ + StringOptionConfig, + SelectOptionConfig, + PathOptionConfig, + NumberOptionConfig, + BooleanOptionConfig, +] + + +# ============================================================================ +# Menu and Navigation Configuration +# ============================================================================ + + +class MenuConfig(StrictModel): + """Configuration for navigation menu items. + + Menus allow tree-based navigation between branches. + """ + + id: MenuId = Field(description="Unique menu identifier") + label: str = Field(description="Menu item label displayed to user") + target: BranchId = Field(description="Target branch to navigate to") + description: Optional[str] = Field( + default=None, description="Optional menu description" + ) + + +# ============================================================================ +# Branch Configuration +# ============================================================================ + + +class BranchConfig(BaseConfig): + """Configuration for a wizard branch. + + A branch represents a screen/step in the wizard with actions, options, + and navigation menus. + """ + + id: BranchId = Field(description="Unique branch identifier") + title: str = Field(description="Branch title displayed to user") + description: Optional[str] = Field(default=None, description="Branch description") + actions: list[ActionConfigUnion] = Field( + default_factory=list, description="Actions available in this branch" + ) + options: list[OptionConfigUnion] = Field( + default_factory=list, description="Options to collect in this branch" + ) + menus: list[MenuConfig] = Field( + default_factory=list, description="Navigation menus in this branch" + ) + + +# ============================================================================ +# Wizard Configuration +# ============================================================================ + + +class WizardConfig(BaseConfig): + """Complete wizard configuration. + + This is the top-level configuration that defines an entire wizard, + including all branches and the entry point. + """ + + name: str = Field(description="Wizard name (identifier)") + version: str = Field(description="Wizard version (semver recommended)") + description: Optional[str] = Field(default=None, description="Wizard description") + entry_branch: BranchId = Field( + description="Initial branch to display when wizard starts" + ) + branches: list[BranchConfig] = Field(description="All branches in the wizard tree") + + # TODO: Add validator to ensure entry_branch exists in branches + # This would be done with @model_validator in Pydantic v2 + + +# ============================================================================ +# Session State +# ============================================================================ + + +class SessionState(StrictModel): + """Unified session state for wizard and parser. + + This model combines both wizard state (navigation, options) and + parser state (mode, history) into a single unified state. + """ + + # Wizard state + current_branch: Optional[BranchId] = Field( + default=None, description="Currently active branch" + ) + navigation_history: list[BranchId] = Field( + default_factory=list, description="Branch navigation history for 'back' command" + ) + option_values: dict[OptionKey, StateValue] = Field( + default_factory=dict, description="Collected option values" + ) + + # Shared state + variables: dict[str, StateValue] = Field( + default_factory=dict, + description="Variables for interpolation (e.g., ${var} in commands)", + ) + + # Parser state + parse_mode: str = Field(default="interactive", description="Current parsing mode") + command_history: list[str] = Field( + default_factory=list, description="Command history for readline/recall" + ) + + +# ============================================================================ +# Result Types +# ============================================================================ + + +class ActionResult(StrictModel): + """Result from executing an action. + + Contains success status, output, and error information. + """ + + action_id: ActionId = Field(description="ID of executed action") + success: bool = Field(description="Whether action succeeded") + output: str = Field(default="", description="Action output (stdout)") + exit_code: int = Field(default=0, description="Exit code (for bash actions)") + error: Optional[str] = Field(default=None, description="Error message if failed") + + +class CollectionResult(StrictModel): + """Result from collecting an option value. + + Contains the collected value or error information. + """ + + option_key: OptionKey = Field(description="Key of option being collected") + success: bool = Field(description="Whether collection succeeded") + value: Optional[StateValue] = Field( + default=None, description="Collected value if successful" + ) + error: Optional[str] = Field( + default=None, description="Error message if collection failed" + ) + + +class NavigationResult(StrictModel): + """Result from a navigation operation. + + Contains target branch and success/error information. + """ + + success: bool = Field(description="Whether navigation succeeded") + target: BranchId = Field(description="Target branch") + error: Optional[str] = Field( + default=None, description="Error message if navigation failed" + ) diff --git a/tests/unit/core/test_models.py b/tests/unit/core/test_models.py new file mode 100644 index 0000000..861f839 --- /dev/null +++ b/tests/unit/core/test_models.py @@ -0,0 +1,761 @@ +"""Tests for core data models. + +This module tests the Pydantic models that define the wizard configuration structure, +including actions, options, branches, and the complete wizard configuration. +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +# Import the models we're testing (these will fail initially) +try: + from cli_patterns.core.models import ( + ActionResult, + BaseConfig, + BashActionConfig, + BooleanOptionConfig, + BranchConfig, + CollectionResult, + MenuConfig, + NavigationResult, + NumberOptionConfig, + PathOptionConfig, + PythonActionConfig, + SelectOptionConfig, + SessionState, + StringOptionConfig, + WizardConfig, + ) + from cli_patterns.core.types import ( + make_action_id, + make_branch_id, + make_menu_id, + make_option_key, + ) +except ImportError: + # These imports will fail initially since the implementation doesn't exist + pass + +pytestmark = pytest.mark.unit + + +class TestBaseConfig: + """Test the BaseConfig model that provides common fields.""" + + def test_base_config_with_defaults(self) -> None: + """ + GIVEN: No metadata or tags provided + WHEN: Creating a BaseConfig + THEN: Default values are used + """ + config = BaseConfig() + assert config.metadata == {} + assert config.tags == [] + + def test_base_config_with_metadata(self) -> None: + """ + GIVEN: Custom metadata + WHEN: Creating a BaseConfig + THEN: Metadata is stored correctly + """ + metadata = {"author": "test", "version": "1.0"} + config = BaseConfig(metadata=metadata) + assert config.metadata == metadata + + def test_base_config_with_tags(self) -> None: + """ + GIVEN: Custom tags + WHEN: Creating a BaseConfig + THEN: Tags are stored correctly + """ + tags = ["production", "important"] + config = BaseConfig(tags=tags) + assert config.tags == tags + + +class TestActionConfigs: + """Test action configuration models.""" + + def test_bash_action_config_minimal(self) -> None: + """ + GIVEN: Minimal bash action configuration + WHEN: Creating a BashActionConfig + THEN: Configuration is created with required fields + """ + config = BashActionConfig( + type="bash", + id=make_action_id("deploy"), + name="Deploy Application", + command="kubectl apply -f deploy.yaml", + ) + assert config.type == "bash" + assert config.id == make_action_id("deploy") + assert config.name == "Deploy Application" + assert config.command == "kubectl apply -f deploy.yaml" + assert config.env == {} + assert config.metadata == {} + assert config.tags == [] + + def test_bash_action_config_with_env(self) -> None: + """ + GIVEN: Bash action with environment variables + WHEN: Creating a BashActionConfig + THEN: Environment variables are stored + """ + config = BashActionConfig( + type="bash", + id=make_action_id("deploy"), + name="Deploy", + command="deploy.sh", + env={"ENV": "production", "REGION": "us-west-2"}, + ) + assert config.env == {"ENV": "production", "REGION": "us-west-2"} + + def test_python_action_config_minimal(self) -> None: + """ + GIVEN: Minimal python action configuration + WHEN: Creating a PythonActionConfig + THEN: Configuration is created with required fields + """ + config = PythonActionConfig( + type="python", + id=make_action_id("process"), + name="Process Data", + module="myapp.tasks", + function="process_data", + ) + assert config.type == "python" + assert config.id == make_action_id("process") + assert config.name == "Process Data" + assert config.module == "myapp.tasks" + assert config.function == "process_data" + + def test_action_discriminated_union(self) -> None: + """ + GIVEN: Different action types + WHEN: Using discriminated unions + THEN: Pydantic discriminates based on 'type' field + """ + bash_data = { + "type": "bash", + "id": "deploy", + "name": "Deploy", + "command": "deploy.sh", + } + python_data = { + "type": "python", + "id": "process", + "name": "Process", + "module": "app", + "function": "run", + } + + bash_config = BashActionConfig(**bash_data) + python_config = PythonActionConfig(**python_data) + + assert bash_config.type == "bash" + assert python_config.type == "python" + + +class TestOptionConfigs: + """Test option configuration models.""" + + def test_string_option_config(self) -> None: + """ + GIVEN: String option configuration + WHEN: Creating a StringOptionConfig + THEN: Configuration is created correctly + """ + config = StringOptionConfig( + type="string", + id=make_option_key("username"), + name="Username", + description="Enter your username", + default="admin", + ) + assert config.type == "string" + assert config.id == make_option_key("username") + assert config.name == "Username" + assert config.description == "Enter your username" + assert config.default == "admin" + assert config.required is False + + def test_select_option_config(self) -> None: + """ + GIVEN: Select option with choices + WHEN: Creating a SelectOptionConfig + THEN: Choices are stored correctly + """ + config = SelectOptionConfig( + type="select", + id=make_option_key("environment"), + name="Environment", + description="Select environment", + choices=["dev", "staging", "production"], + default="dev", + ) + assert config.type == "select" + assert config.choices == ["dev", "staging", "production"] + assert config.default == "dev" + + def test_path_option_config(self) -> None: + """ + GIVEN: Path option configuration + WHEN: Creating a PathOptionConfig + THEN: Must_exist flag works correctly + """ + config = PathOptionConfig( + type="path", + id=make_option_key("config_file"), + name="Config File", + description="Path to config file", + must_exist=True, + default="./config.yaml", + ) + assert config.type == "path" + assert config.must_exist is True + assert config.default == "./config.yaml" + + def test_number_option_config(self) -> None: + """ + GIVEN: Number option with constraints + WHEN: Creating a NumberOptionConfig + THEN: Constraints are stored correctly + """ + config = NumberOptionConfig( + type="number", + id=make_option_key("port"), + name="Port", + description="Server port", + min_value=1024, + max_value=65535, + default=8080, + ) + assert config.type == "number" + assert config.min_value == 1024 + assert config.max_value == 65535 + assert config.default == 8080 + + def test_boolean_option_config(self) -> None: + """ + GIVEN: Boolean option configuration + WHEN: Creating a BooleanOptionConfig + THEN: Configuration is created correctly + """ + config = BooleanOptionConfig( + type="boolean", + id=make_option_key("verbose"), + name="Verbose", + description="Enable verbose logging", + default=False, + ) + assert config.type == "boolean" + assert config.default is False + + def test_required_option(self) -> None: + """ + GIVEN: Required option without default + WHEN: Creating an option config + THEN: Required flag is set appropriately + """ + config = StringOptionConfig( + type="string", + id=make_option_key("api_key"), + name="API Key", + description="Required API key", + required=True, + ) + assert config.required is True + assert config.default is None + + +class TestMenuConfig: + """Test menu configuration for navigation.""" + + def test_menu_config_creation(self) -> None: + """ + GIVEN: Menu configuration data + WHEN: Creating a MenuConfig + THEN: Configuration is created correctly + """ + config = MenuConfig( + id=make_menu_id("settings_menu"), + label="Settings", + target=make_branch_id("settings_branch"), + ) + assert config.id == make_menu_id("settings_menu") + assert config.label == "Settings" + assert config.target == make_branch_id("settings_branch") + + def test_menu_config_with_description(self) -> None: + """ + GIVEN: Menu with optional description + WHEN: Creating a MenuConfig + THEN: Description is stored + """ + config = MenuConfig( + id=make_menu_id("advanced"), + label="Advanced Settings", + target=make_branch_id("advanced_branch"), + description="Configure advanced options", + ) + assert config.description == "Configure advanced options" + + +class TestBranchConfig: + """Test branch configuration models.""" + + def test_branch_config_minimal(self) -> None: + """ + GIVEN: Minimal branch configuration + WHEN: Creating a BranchConfig + THEN: Configuration is created with defaults + """ + config = BranchConfig( + id=make_branch_id("main"), + title="Main Menu", + ) + assert config.id == make_branch_id("main") + assert config.title == "Main Menu" + assert config.description is None + assert config.actions == [] + assert config.options == [] + assert config.menus == [] + + def test_branch_config_with_actions(self) -> None: + """ + GIVEN: Branch with actions + WHEN: Creating a BranchConfig + THEN: Actions are stored correctly + """ + action = BashActionConfig( + type="bash", + id=make_action_id("deploy"), + name="Deploy", + command="deploy.sh", + ) + config = BranchConfig( + id=make_branch_id("deploy_branch"), + title="Deploy Menu", + actions=[action], + ) + assert len(config.actions) == 1 + assert config.actions[0].id == make_action_id("deploy") + + def test_branch_config_with_options(self) -> None: + """ + GIVEN: Branch with options + WHEN: Creating a BranchConfig + THEN: Options are stored correctly + """ + option = StringOptionConfig( + type="string", + id=make_option_key("username"), + name="Username", + description="Enter username", + ) + config = BranchConfig( + id=make_branch_id("config_branch"), + title="Configuration", + options=[option], + ) + assert len(config.options) == 1 + assert config.options[0].id == make_option_key("username") + + def test_branch_config_with_menus(self) -> None: + """ + GIVEN: Branch with navigation menus + WHEN: Creating a BranchConfig + THEN: Menus are stored correctly + """ + menu = MenuConfig( + id=make_menu_id("settings"), + label="Settings", + target=make_branch_id("settings_branch"), + ) + config = BranchConfig( + id=make_branch_id("main"), + title="Main Menu", + menus=[menu], + ) + assert len(config.menus) == 1 + assert config.menus[0].id == make_menu_id("settings") + + def test_branch_config_complete(self) -> None: + """ + GIVEN: Branch with all components + WHEN: Creating a complete BranchConfig + THEN: All components are stored correctly + """ + action = BashActionConfig( + type="bash", + id=make_action_id("deploy"), + name="Deploy", + command="deploy.sh", + ) + option = StringOptionConfig( + type="string", + id=make_option_key("env"), + name="Environment", + description="Target environment", + ) + menu = MenuConfig( + id=make_menu_id("settings"), + label="Settings", + target=make_branch_id("settings"), + ) + + config = BranchConfig( + id=make_branch_id("main"), + title="Main Menu", + description="Main application menu", + actions=[action], + options=[option], + menus=[menu], + metadata={"version": "1.0"}, + tags=["main", "entry"], + ) + + assert config.id == make_branch_id("main") + assert config.title == "Main Menu" + assert config.description == "Main application menu" + assert len(config.actions) == 1 + assert len(config.options) == 1 + assert len(config.menus) == 1 + assert config.metadata == {"version": "1.0"} + assert config.tags == ["main", "entry"] + + +class TestWizardConfig: + """Test complete wizard configuration.""" + + def test_wizard_config_minimal(self) -> None: + """ + GIVEN: Minimal wizard configuration + WHEN: Creating a WizardConfig + THEN: Configuration is created with required fields + """ + branch = BranchConfig( + id=make_branch_id("main"), + title="Main Menu", + ) + config = WizardConfig( + name="test-wizard", + version="1.0.0", + entry_branch=make_branch_id("main"), + branches=[branch], + ) + assert config.name == "test-wizard" + assert config.version == "1.0.0" + assert config.entry_branch == make_branch_id("main") + assert len(config.branches) == 1 + + def test_wizard_config_with_description(self) -> None: + """ + GIVEN: Wizard with description + WHEN: Creating a WizardConfig + THEN: Description is stored + """ + branch = BranchConfig(id=make_branch_id("main"), title="Main") + config = WizardConfig( + name="test-wizard", + version="1.0.0", + description="A test wizard", + entry_branch=make_branch_id("main"), + branches=[branch], + ) + assert config.description == "A test wizard" + + def test_wizard_config_validates_entry_branch_exists(self) -> None: + """ + GIVEN: Wizard with entry_branch that doesn't exist in branches + WHEN: Creating a WizardConfig + THEN: Validation should succeed (validation is runtime, not construction) + """ + branch = BranchConfig(id=make_branch_id("main"), title="Main") + # Entry branch points to non-existent branch - this is allowed at construction + config = WizardConfig( + name="test-wizard", + version="1.0.0", + entry_branch=make_branch_id("nonexistent"), + branches=[branch], + ) + assert config.entry_branch == make_branch_id("nonexistent") + + def test_wizard_config_multiple_branches(self) -> None: + """ + GIVEN: Wizard with multiple branches + WHEN: Creating a WizardConfig + THEN: All branches are stored + """ + main_branch = BranchConfig(id=make_branch_id("main"), title="Main") + settings_branch = BranchConfig(id=make_branch_id("settings"), title="Settings") + deploy_branch = BranchConfig(id=make_branch_id("deploy"), title="Deploy") + + config = WizardConfig( + name="multi-branch-wizard", + version="1.0.0", + entry_branch=make_branch_id("main"), + branches=[main_branch, settings_branch, deploy_branch], + ) + assert len(config.branches) == 3 + + +class TestSessionState: + """Test session state model.""" + + def test_session_state_defaults(self) -> None: + """ + GIVEN: No initial state provided + WHEN: Creating a SessionState + THEN: Default values are used + """ + state = SessionState() + assert state.current_branch is None + assert state.navigation_history == [] + assert state.option_values == {} + assert state.variables == {} + assert state.parse_mode == "interactive" + assert state.command_history == [] + + def test_session_state_with_current_branch(self) -> None: + """ + GIVEN: Initial current branch + WHEN: Creating a SessionState + THEN: Current branch is set + """ + state = SessionState(current_branch=make_branch_id("main")) + assert state.current_branch == make_branch_id("main") + + def test_session_state_with_navigation_history(self) -> None: + """ + GIVEN: Navigation history + WHEN: Creating a SessionState + THEN: History is stored + """ + history = [make_branch_id("main"), make_branch_id("settings")] + state = SessionState(navigation_history=history) + assert state.navigation_history == history + + def test_session_state_with_option_values(self) -> None: + """ + GIVEN: Option values + WHEN: Creating a SessionState + THEN: Values are stored + """ + options = { + make_option_key("username"): "admin", + make_option_key("port"): 8080, + } + state = SessionState(option_values=options) + assert state.option_values == options + + def test_session_state_with_variables(self) -> None: + """ + GIVEN: Variables for interpolation + WHEN: Creating a SessionState + THEN: Variables are stored + """ + variables = {"env": "production", "region": "us-west-2"} + state = SessionState(variables=variables) + assert state.variables == variables + + def test_session_state_with_parse_mode(self) -> None: + """ + GIVEN: Custom parse mode + WHEN: Creating a SessionState + THEN: Parse mode is set + """ + state = SessionState(parse_mode="shell") + assert state.parse_mode == "shell" + + def test_session_state_with_command_history(self) -> None: + """ + GIVEN: Command history + WHEN: Creating a SessionState + THEN: History is stored + """ + history = ["deploy", "status", "help"] + state = SessionState(command_history=history) + assert state.command_history == history + + def test_session_state_complete(self) -> None: + """ + GIVEN: Complete session state + WHEN: Creating a SessionState + THEN: All fields are stored correctly + """ + state = SessionState( + current_branch=make_branch_id("main"), + navigation_history=[make_branch_id("main")], + option_values={make_option_key("env"): "prod"}, + variables={"region": "us-west"}, + parse_mode="interactive", + command_history=["help"], + ) + assert state.current_branch == make_branch_id("main") + assert len(state.navigation_history) == 1 + assert len(state.option_values) == 1 + assert len(state.variables) == 1 + assert state.parse_mode == "interactive" + assert len(state.command_history) == 1 + + +class TestResultTypes: + """Test result types returned by protocols.""" + + def test_action_result_success(self) -> None: + """ + GIVEN: Successful action execution + WHEN: Creating an ActionResult + THEN: Success status is recorded + """ + result = ActionResult( + action_id=make_action_id("deploy"), + success=True, + output="Deployment successful", + ) + assert result.action_id == make_action_id("deploy") + assert result.success is True + assert result.output == "Deployment successful" + assert result.exit_code == 0 + + def test_action_result_failure(self) -> None: + """ + GIVEN: Failed action execution + WHEN: Creating an ActionResult + THEN: Failure status and error are recorded + """ + result = ActionResult( + action_id=make_action_id("deploy"), + success=False, + output="Deployment failed", + exit_code=1, + error="Connection timeout", + ) + assert result.success is False + assert result.exit_code == 1 + assert result.error == "Connection timeout" + + def test_collection_result_success(self) -> None: + """ + GIVEN: Successful option collection + WHEN: Creating a CollectionResult + THEN: Collected value is stored + """ + result = CollectionResult( + option_key=make_option_key("username"), + success=True, + value="admin", + ) + assert result.option_key == make_option_key("username") + assert result.success is True + assert result.value == "admin" + assert result.error is None + + def test_collection_result_failure(self) -> None: + """ + GIVEN: Failed option collection + WHEN: Creating a CollectionResult + THEN: Error is recorded + """ + result = CollectionResult( + option_key=make_option_key("port"), + success=False, + value=None, + error="Invalid port number", + ) + assert result.success is False + assert result.value is None + assert result.error == "Invalid port number" + + def test_navigation_result_success(self) -> None: + """ + GIVEN: Successful navigation + WHEN: Creating a NavigationResult + THEN: Target branch is recorded + """ + result = NavigationResult( + success=True, + target=make_branch_id("settings"), + ) + assert result.success is True + assert result.target == make_branch_id("settings") + assert result.error is None + + def test_navigation_result_failure(self) -> None: + """ + GIVEN: Failed navigation + WHEN: Creating a NavigationResult + THEN: Error is recorded + """ + result = NavigationResult( + success=False, + target=make_branch_id("invalid"), + error="Branch not found", + ) + assert result.success is False + assert result.error == "Branch not found" + + +class TestPydanticValidation: + """Test Pydantic validation features.""" + + def test_required_fields_validation(self) -> None: + """ + GIVEN: Missing required fields + WHEN: Creating a model + THEN: ValidationError is raised + """ + with pytest.raises(ValidationError): + BashActionConfig(type="bash", name="Deploy") # Missing id and command + + def test_type_field_validation(self) -> None: + """ + GIVEN: Invalid type field + WHEN: Creating an action config + THEN: ValidationError is raised + """ + with pytest.raises(ValidationError): + BashActionConfig( + type="invalid", # Should be "bash" + id=make_action_id("deploy"), + name="Deploy", + command="deploy.sh", + ) + + def test_json_serialization(self) -> None: + """ + GIVEN: A valid model + WHEN: Serializing to JSON + THEN: JSON is correctly formatted + """ + config = BashActionConfig( + type="bash", + id=make_action_id("deploy"), + name="Deploy", + command="deploy.sh", + ) + json_data = config.model_dump() + assert json_data["type"] == "bash" + assert json_data["id"] == "deploy" + assert json_data["name"] == "Deploy" + assert json_data["command"] == "deploy.sh" + + def test_json_deserialization(self) -> None: + """ + GIVEN: JSON data + WHEN: Deserializing to model + THEN: Model is correctly created + """ + json_data = { + "type": "bash", + "id": "deploy", + "name": "Deploy", + "command": "deploy.sh", + } + config = BashActionConfig(**json_data) + assert config.id == make_action_id("deploy") + assert config.name == "Deploy" From a0b4790262f552845ca6ad52c80dba7ddcae8ddb Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 17:53:51 -0400 Subject: [PATCH 3/9] feat(core): add runtime-checkable protocols for wizard engine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement core protocols defining interfaces for action execution, option collection, and navigation control. Protocols: - ActionExecutor: Execute actions (bash, python, etc.) - OptionCollector: Collect option values from users - NavigationController: Handle wizard navigation Features: - All protocols are @runtime_checkable for isinstance() checks - Enables dependency injection and multiple implementations - Type-safe interfaces with Protocol typing - Comprehensive documentation with usage examples - 100% test coverage with 15 tests including integration patterns - MyPy strict mode compliance Part of CLI-4: Minimal Core Type Definitions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/cli_patterns/core/protocols.py | 122 ++++++- tests/unit/core/test_protocols.py | 495 +++++++++++++++++++++++++++++ 2 files changed, 615 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/test_protocols.py diff --git a/src/cli_patterns/core/protocols.py b/src/cli_patterns/core/protocols.py index 8a98289..8b78dc6 100644 --- a/src/cli_patterns/core/protocols.py +++ b/src/cli_patterns/core/protocols.py @@ -1,3 +1,121 @@ -"""Protocol definitions for CLI Patterns.""" +"""Protocol definitions for CLI Patterns. -# Placeholder for protocol definitions +This module defines the core protocols (interfaces) that implementation classes +must satisfy. Protocols enable: +- Dependency injection +- Multiple implementations +- Type-safe interfaces +- Runtime checking (with @runtime_checkable) + +All protocols are runtime-checkable, meaning isinstance() checks work at runtime. +""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from cli_patterns.core.models import ( + ActionConfigUnion, + ActionResult, + CollectionResult, + NavigationResult, + OptionConfigUnion, + SessionState, +) +from cli_patterns.core.types import BranchId + + +@runtime_checkable +class ActionExecutor(Protocol): + """Protocol for executing actions. + + Implementations of this protocol handle the execution of actions + (bash commands, Python functions, etc.) and return results. + + Example: + class BashExecutor: + def execute(self, action: ActionConfigUnion, state: SessionState) -> ActionResult: + if isinstance(action, BashActionConfig): + # Execute bash command + result = subprocess.run(action.command, ...) + return ActionResult(...) + """ + + def execute(self, action: ActionConfigUnion, state: SessionState) -> ActionResult: + """Execute an action and return the result. + + Args: + action: The action configuration to execute + state: Current session state + + Returns: + ActionResult containing success status, output, and errors + """ + ... + + +@runtime_checkable +class OptionCollector(Protocol): + """Protocol for collecting option values from users. + + Implementations of this protocol handle the interactive collection + of option values (strings, selections, paths, etc.) and return results. + + Example: + class InteractiveCollector: + def collect(self, option: OptionConfigUnion, state: SessionState) -> CollectionResult: + if isinstance(option, StringOptionConfig): + # Prompt user for string input + value = input(f"{option.description}: ") + return CollectionResult(...) + """ + + def collect( + self, option: OptionConfigUnion, state: SessionState + ) -> CollectionResult: + """Collect an option value from the user. + + Args: + option: The option configuration to collect + state: Current session state + + Returns: + CollectionResult containing the collected value or error + """ + ... + + +@runtime_checkable +class NavigationController(Protocol): + """Protocol for controlling wizard navigation. + + Implementations of this protocol handle navigation between branches + in the wizard tree, including history management. + + Example: + class TreeNavigator: + def navigate(self, target: BranchId, state: SessionState) -> NavigationResult: + # Update state with new branch + state.navigation_history.append(state.current_branch) + state.current_branch = target + return NavigationResult(...) + """ + + def navigate(self, target: BranchId, state: SessionState) -> NavigationResult: + """Navigate to a target branch. + + Args: + target: The branch ID to navigate to + state: Current session state (will be modified) + + Returns: + NavigationResult containing success status and target + """ + ... + + +# TODO: Future protocol extension points +# - ValidationProtocol: For custom option validation +# - InterpolationProtocol: For variable interpolation in commands +# - PersistenceProtocol: For session state persistence +# - ThemeProtocol: For custom theming (may already exist in ui.design) diff --git a/tests/unit/core/test_protocols.py b/tests/unit/core/test_protocols.py new file mode 100644 index 0000000..362d86c --- /dev/null +++ b/tests/unit/core/test_protocols.py @@ -0,0 +1,495 @@ +"""Tests for core protocol definitions. + +This module tests the protocol definitions that define the interfaces for +action execution, option collection, and navigation control. +""" + +from __future__ import annotations + +from typing import Protocol + +import pytest + +# Import the protocols we're testing (these will fail initially) +try: + from cli_patterns.core.models import ( + ActionConfigUnion, + ActionResult, + BashActionConfig, + CollectionResult, + NavigationResult, + OptionConfigUnion, + SessionState, + StringOptionConfig, + ) + from cli_patterns.core.protocols import ( + ActionExecutor, + NavigationController, + OptionCollector, + ) + from cli_patterns.core.types import ( + BranchId, + make_action_id, + make_branch_id, + make_option_key, + ) +except ImportError: + # These imports will fail initially since the implementation doesn't exist + pass + +pytestmark = pytest.mark.unit + + +class TestActionExecutorProtocol: + """Test the ActionExecutor protocol definition and compliance.""" + + def test_protocol_is_runtime_checkable(self) -> None: + """ + GIVEN: The ActionExecutor protocol + WHEN: Checking if it's runtime checkable + THEN: It should be a Protocol with runtime_checkable decorator + """ + assert issubclass(ActionExecutor, Protocol) + # Check that we can use isinstance with it (runtime_checkable) + assert hasattr(ActionExecutor, "_is_runtime_protocol") + + def test_concrete_implementation_satisfies_protocol(self) -> None: + """ + GIVEN: A concrete class implementing ActionExecutor + WHEN: Checking protocol compliance + THEN: The implementation satisfies the protocol + """ + + class ConcreteExecutor: + """Concrete implementation of ActionExecutor.""" + + def execute( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + """Execute an action.""" + if isinstance(action, BashActionConfig): + return ActionResult( + action_id=action.id, + success=True, + output="Command executed", + ) + return ActionResult( + action_id=action.id, + success=False, + error="Unsupported action type", + ) + + # Should be able to create instance + executor = ConcreteExecutor() + + # Should satisfy protocol at runtime + assert isinstance(executor, ActionExecutor) + + def test_protocol_execute_method_signature(self) -> None: + """ + GIVEN: ActionExecutor protocol + WHEN: Inspecting the execute method + THEN: Method signature matches expected interface + """ + + class TestExecutor: + """Test executor for signature verification.""" + + def execute( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + """Execute method with correct signature.""" + return ActionResult(action_id=action.id, success=True, output="test") + + executor = TestExecutor() + assert isinstance(executor, ActionExecutor) + + # Create test data + action = BashActionConfig( + type="bash", + id=make_action_id("test_action"), + name="Test Action", + command="echo test", + ) + state = SessionState() + + # Execute should work + result = executor.execute(action, state) + assert result.success is True + + def test_missing_execute_method_fails_protocol(self) -> None: + """ + GIVEN: A class without execute method + WHEN: Checking protocol compliance + THEN: It should not satisfy the protocol + """ + + class NotAnExecutor: + """Class that doesn't implement ActionExecutor.""" + + def run( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + """Wrong method name.""" + return ActionResult(action_id=action.id, success=True, output="") + + not_executor = NotAnExecutor() + assert not isinstance(not_executor, ActionExecutor) + + +class TestOptionCollectorProtocol: + """Test the OptionCollector protocol definition and compliance.""" + + def test_protocol_is_runtime_checkable(self) -> None: + """ + GIVEN: The OptionCollector protocol + WHEN: Checking if it's runtime checkable + THEN: It should be a Protocol with runtime_checkable decorator + """ + assert issubclass(OptionCollector, Protocol) + assert hasattr(OptionCollector, "_is_runtime_protocol") + + def test_concrete_implementation_satisfies_protocol(self) -> None: + """ + GIVEN: A concrete class implementing OptionCollector + WHEN: Checking protocol compliance + THEN: The implementation satisfies the protocol + """ + + class ConcreteCollector: + """Concrete implementation of OptionCollector.""" + + def collect( + self, option: OptionConfigUnion, state: SessionState + ) -> CollectionResult: + """Collect an option value.""" + return CollectionResult( + option_key=option.id, + success=True, + value=option.default if option.default else "default_value", + ) + + collector = ConcreteCollector() + assert isinstance(collector, OptionCollector) + + def test_protocol_collect_method_signature(self) -> None: + """ + GIVEN: OptionCollector protocol + WHEN: Inspecting the collect method + THEN: Method signature matches expected interface + """ + + class TestCollector: + """Test collector for signature verification.""" + + def collect( + self, option: OptionConfigUnion, state: SessionState + ) -> CollectionResult: + """Collect method with correct signature.""" + return CollectionResult( + option_key=option.id, success=True, value="test_value" + ) + + collector = TestCollector() + assert isinstance(collector, OptionCollector) + + # Create test data + option = StringOptionConfig( + type="string", + id=make_option_key("test_option"), + name="Test Option", + description="A test option", + ) + state = SessionState() + + # Collect should work + result = collector.collect(option, state) + assert result.success is True + + def test_missing_collect_method_fails_protocol(self) -> None: + """ + GIVEN: A class without collect method + WHEN: Checking protocol compliance + THEN: It should not satisfy the protocol + """ + + class NotACollector: + """Class that doesn't implement OptionCollector.""" + + def gather( + self, option: OptionConfigUnion, state: SessionState + ) -> CollectionResult: + """Wrong method name.""" + return CollectionResult( + option_key=option.id, success=True, value="value" + ) + + not_collector = NotACollector() + assert not isinstance(not_collector, OptionCollector) + + +class TestNavigationControllerProtocol: + """Test the NavigationController protocol definition and compliance.""" + + def test_protocol_is_runtime_checkable(self) -> None: + """ + GIVEN: The NavigationController protocol + WHEN: Checking if it's runtime checkable + THEN: It should be a Protocol with runtime_checkable decorator + """ + assert issubclass(NavigationController, Protocol) + assert hasattr(NavigationController, "_is_runtime_protocol") + + def test_concrete_implementation_satisfies_protocol(self) -> None: + """ + GIVEN: A concrete class implementing NavigationController + WHEN: Checking protocol compliance + THEN: The implementation satisfies the protocol + """ + + class ConcreteNavigator: + """Concrete implementation of NavigationController.""" + + def navigate( + self, target: BranchId, state: SessionState + ) -> NavigationResult: + """Navigate to a branch.""" + return NavigationResult(success=True, target=target) + + navigator = ConcreteNavigator() + assert isinstance(navigator, NavigationController) + + def test_protocol_navigate_method_signature(self) -> None: + """ + GIVEN: NavigationController protocol + WHEN: Inspecting the navigate method + THEN: Method signature matches expected interface + """ + + class TestNavigator: + """Test navigator for signature verification.""" + + def navigate( + self, target: BranchId, state: SessionState + ) -> NavigationResult: + """Navigate method with correct signature.""" + return NavigationResult(success=True, target=target) + + navigator = TestNavigator() + assert isinstance(navigator, NavigationController) + + # Create test data + target = make_branch_id("target_branch") + state = SessionState() + + # Navigate should work + result = navigator.navigate(target, state) + assert result.success is True + assert result.target == target + + def test_missing_navigate_method_fails_protocol(self) -> None: + """ + GIVEN: A class without navigate method + WHEN: Checking protocol compliance + THEN: It should not satisfy the protocol + """ + + class NotANavigator: + """Class that doesn't implement NavigationController.""" + + def go_to(self, target: BranchId, state: SessionState) -> NavigationResult: + """Wrong method name.""" + return NavigationResult(success=True, target=target) + + not_navigator = NotANavigator() + assert not isinstance(not_navigator, NavigationController) + + +class TestProtocolIntegration: + """Test protocol integration and usage patterns.""" + + def test_protocols_can_be_used_as_type_hints(self) -> None: + """ + GIVEN: Protocol types + WHEN: Using them as type hints + THEN: Type hints work correctly + """ + + def execute_action( + executor: ActionExecutor, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + """Function accepting protocol type.""" + return executor.execute(action, state) + + def collect_option( + collector: OptionCollector, option: OptionConfigUnion, state: SessionState + ) -> CollectionResult: + """Function accepting protocol type.""" + return collector.collect(option, state) + + def navigate_to( + navigator: NavigationController, target: BranchId, state: SessionState + ) -> NavigationResult: + """Function accepting protocol type.""" + return navigator.navigate(target, state) + + # Create concrete implementations + class TestExecutor: + def execute( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + return ActionResult( + action_id=action.id, success=True, output="executed" + ) + + class TestCollector: + def collect( + self, option: OptionConfigUnion, state: SessionState + ) -> CollectionResult: + return CollectionResult( + option_key=option.id, success=True, value="collected" + ) + + class TestNavigator: + def navigate( + self, target: BranchId, state: SessionState + ) -> NavigationResult: + return NavigationResult(success=True, target=target) + + # Use the functions with concrete implementations + action = BashActionConfig( + type="bash", + id=make_action_id("test"), + name="Test", + command="echo test", + ) + option = StringOptionConfig( + type="string", + id=make_option_key("test"), + name="Test", + description="Test", + ) + target = make_branch_id("test") + state = SessionState() + + action_result = execute_action(TestExecutor(), action, state) + assert action_result.success is True + + collection_result = collect_option(TestCollector(), option, state) + assert collection_result.success is True + + nav_result = navigate_to(TestNavigator(), target, state) + assert nav_result.success is True + + def test_protocols_enable_dependency_injection(self) -> None: + """ + GIVEN: Protocols defining interfaces + WHEN: Using them for dependency injection + THEN: Different implementations can be swapped + """ + + class WizardEngine: + """Engine that depends on protocols.""" + + def __init__( + self, + executor: ActionExecutor, + collector: OptionCollector, + navigator: NavigationController, + ) -> None: + self.executor = executor + self.collector = collector + self.navigator = navigator + + def run_action( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + """Run action using injected executor.""" + return self.executor.execute(action, state) + + # Create mock implementations + class MockExecutor: + def execute( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + return ActionResult(action_id=action.id, success=True, output="mocked") + + class MockCollector: + def collect( + self, option: OptionConfigUnion, state: SessionState + ) -> CollectionResult: + return CollectionResult( + option_key=option.id, success=True, value="mocked" + ) + + class MockNavigator: + def navigate( + self, target: BranchId, state: SessionState + ) -> NavigationResult: + return NavigationResult(success=True, target=target) + + # Inject mock implementations + engine = WizardEngine(MockExecutor(), MockCollector(), MockNavigator()) + + # Use the engine + action = BashActionConfig( + type="bash", + id=make_action_id("test"), + name="Test", + command="echo test", + ) + state = SessionState() + result = engine.run_action(action, state) + + assert result.success is True + assert result.output == "mocked" + + def test_protocols_support_multiple_implementations(self) -> None: + """ + GIVEN: A protocol definition + WHEN: Creating multiple implementations + THEN: All implementations satisfy the protocol + """ + + # Implementation 1: Simple executor + class SimpleExecutor: + def execute( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + return ActionResult(action_id=action.id, success=True, output="simple") + + # Implementation 2: Logging executor + class LoggingExecutor: + def __init__(self) -> None: + self.log: list[str] = [] + + def execute( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + self.log.append(f"Executing {action.id}") + return ActionResult(action_id=action.id, success=True, output="logged") + + # Implementation 3: Async-like executor + class AsyncExecutor: + async def execute_async( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + # Simulate async work + return ActionResult(action_id=action.id, success=True, output="async") + + def execute( + self, action: ActionConfigUnion, state: SessionState + ) -> ActionResult: + # Synchronous wrapper + return ActionResult( + action_id=action.id, success=True, output="async_sync" + ) + + # All should satisfy the protocol + simple = SimpleExecutor() + logging = LoggingExecutor() + async_exec = AsyncExecutor() + + assert isinstance(simple, ActionExecutor) + assert isinstance(logging, ActionExecutor) + assert isinstance(async_exec, ActionExecutor) From fd8f9e4b8b87b71a0a964a60e98ec48e1834b896 Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 17:53:59 -0400 Subject: [PATCH 4/9] feat(core): export complete type system from core module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add comprehensive exports for all types, models, and protocols from the core module, making them easily accessible to other parts of the framework. Exports: - Semantic types and factory functions - All configuration models (actions, options, branches, wizard) - All protocols (ActionExecutor, OptionCollector, NavigationController) - Result types - Type guards and collection aliases Part of CLI-4: Minimal Core Type Definitions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/cli_patterns/core/__init__.py | 117 +++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 1 deletion(-) diff --git a/src/cli_patterns/core/__init__.py b/src/cli_patterns/core/__init__.py index aee9417..c8b9bed 100644 --- a/src/cli_patterns/core/__init__.py +++ b/src/cli_patterns/core/__init__.py @@ -1 +1,116 @@ -"""Core types and protocols for CLI Patterns.""" +"""Core types and protocols for CLI Patterns. + +This module provides the foundational type system for the CLI Patterns framework: + +Types: + - Semantic types: BranchId, ActionId, OptionKey, MenuId + - State value types: StateValue (JSON-serializable) + - Type guards and factory functions + +Models: + - Configuration models: WizardConfig, BranchConfig, MenuConfig + - Action models: BashActionConfig, PythonActionConfig + - Option models: StringOptionConfig, SelectOptionConfig, etc. + - State models: SessionState + - Result models: ActionResult, CollectionResult, NavigationResult + +Protocols: + - ActionExecutor: For executing actions + - OptionCollector: For collecting option values + - NavigationController: For navigation control +""" + +# Semantic types and utilities +# Configuration models +from cli_patterns.core.models import ( + ActionConfigUnion, + ActionResult, + BaseConfig, + BashActionConfig, + BooleanOptionConfig, + BranchConfig, + CollectionResult, + MenuConfig, + NavigationResult, + NumberOptionConfig, + OptionConfigUnion, + PathOptionConfig, + PythonActionConfig, + SelectOptionConfig, + SessionState, + StringOptionConfig, + WizardConfig, +) + +# Protocols +from cli_patterns.core.protocols import ( + ActionExecutor, + NavigationController, + OptionCollector, +) +from cli_patterns.core.types import ( + ActionId, + ActionList, + ActionSet, + BranchId, + BranchList, + BranchSet, + MenuId, + MenuList, + OptionDict, + OptionKey, + StateValue, + is_action_id, + is_branch_id, + is_menu_id, + is_option_key, + make_action_id, + make_branch_id, + make_menu_id, + make_option_key, +) + +__all__ = [ + # Types + "ActionId", + "ActionList", + "ActionSet", + "BranchId", + "BranchList", + "BranchSet", + "MenuId", + "MenuList", + "OptionDict", + "OptionKey", + "StateValue", + "is_action_id", + "is_branch_id", + "is_menu_id", + "is_option_key", + "make_action_id", + "make_branch_id", + "make_menu_id", + "make_option_key", + # Models + "ActionConfigUnion", + "ActionResult", + "BaseConfig", + "BashActionConfig", + "BooleanOptionConfig", + "BranchConfig", + "CollectionResult", + "MenuConfig", + "NavigationResult", + "NumberOptionConfig", + "OptionConfigUnion", + "PathOptionConfig", + "PythonActionConfig", + "SelectOptionConfig", + "SessionState", + "StringOptionConfig", + "WizardConfig", + # Protocols + "ActionExecutor", + "NavigationController", + "OptionCollector", +] From 7319abae646333c537ccbf60eb0d557cc8078ff0 Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 22:32:14 -0400 Subject: [PATCH 5/9] feat(core): add validators for DoS protection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement validation functions to prevent DoS attacks via deeply nested or excessively large data structures in StateValue fields. Validators: - validate_json_depth(): Prevent stack overflow (max 50 levels) - validate_collection_size(): Prevent memory exhaustion (max 1000 items) - validate_state_value(): Combined validation for StateValue Features: - Recursive depth checking with early termination - Total element counting in nested structures - Clear error messages with limits - Configurable limits via parameters - 27 comprehensive tests covering edge cases Security Benefits: - Prevents stack overflow during JSON serialization - Prevents memory exhaustion from large collections - Prevents CPU exhaustion during parsing - 100% test coverage Part of security hardening (Priority 2: DoS Protection) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/cli_patterns/core/validators.py | 158 ++++++++++++++++++ tests/unit/core/test_validators.py | 239 ++++++++++++++++++++++++++++ 2 files changed, 397 insertions(+) create mode 100644 src/cli_patterns/core/validators.py create mode 100644 tests/unit/core/test_validators.py diff --git a/src/cli_patterns/core/validators.py b/src/cli_patterns/core/validators.py new file mode 100644 index 0000000..53e097b --- /dev/null +++ b/src/cli_patterns/core/validators.py @@ -0,0 +1,158 @@ +"""Validation utilities for CLI Patterns core types. + +This module provides security-focused validators to prevent DoS attacks +and ensure data integrity: + +- JSON depth validation (prevents stack overflow) +- Collection size validation (prevents memory exhaustion) +- StateValue validation (combined depth + size checks) +""" + +from __future__ import annotations + +from typing import Any + +# Configuration constants +MAX_JSON_DEPTH = 50 +"""Maximum nesting depth for JSON-serializable values. + +This prevents stack overflow during serialization and CPU exhaustion +during parsing. Default: 50 levels. +""" + +MAX_COLLECTION_SIZE = 1000 +"""Maximum total number of items in collections (lists, dicts). + +This prevents memory exhaustion from excessively large data structures. +Default: 1000 items (counting nested items recursively). +""" + + +class ValidationError(Exception): + """Raised when validation fails. + + This exception is raised by validators when input doesn't meet + security or integrity requirements. + """ + + pass + + +def validate_json_depth(value: Any, max_depth: int = MAX_JSON_DEPTH) -> None: + """Validate that JSON value doesn't exceed maximum nesting depth. + + This prevents DoS attacks via deeply nested structures that cause: + - Stack overflow during serialization + - Excessive memory consumption + - CPU exhaustion during parsing + + Args: + value: Value to validate (must be JSON-serializable) + max_depth: Maximum allowed nesting depth (default: 50) + + Raises: + ValidationError: If nesting exceeds max_depth + + Example: + >>> validate_json_depth({"a": {"b": {"c": 1}}}) # OK + >>> validate_json_depth(create_deeply_nested(100)) # Raises ValidationError + """ + + def check_depth(obj: Any, current_depth: int = 0) -> int: + """Recursively check nesting depth.""" + if current_depth > max_depth: + raise ValidationError( + f"JSON nesting too deep: {current_depth} levels " + f"(maximum: {max_depth})" + ) + + if isinstance(obj, dict): + if not obj: # Empty dict is depth 0 + return current_depth + return max(check_depth(v, current_depth + 1) for v in obj.values()) + elif isinstance(obj, list): + if not obj: # Empty list is depth 0 + return current_depth + return max(check_depth(item, current_depth + 1) for item in obj) + else: + # Primitive value + return current_depth + + check_depth(value) + + +def validate_collection_size(value: Any, max_size: int = MAX_COLLECTION_SIZE) -> None: + """Validate that collection doesn't exceed maximum size. + + This prevents DoS attacks via large collections that cause memory exhaustion. + Counts all items recursively in nested structures. + + Args: + value: Collection to validate (dict or list) + max_size: Maximum allowed total size (default: 1000) + + Raises: + ValidationError: If collection exceeds max_size + + Example: + >>> validate_collection_size([1, 2, 3]) # OK + >>> validate_collection_size([1] * 10000) # Raises ValidationError + """ + + def check_size(obj: Any) -> int: + """Recursively count total elements.""" + count = 0 + + if isinstance(obj, dict): + count += len(obj) + if count > max_size: + raise ValidationError( + f"Collection too large: {count} items (maximum: {max_size})" + ) + for v in obj.values(): + count += check_size(v) + if count > max_size: + raise ValidationError( + f"Collection too large: {count} items (maximum: {max_size})" + ) + elif isinstance(obj, list): + count += len(obj) + if count > max_size: + raise ValidationError( + f"Collection too large: {count} items (maximum: {max_size})" + ) + for item in obj: + count += check_size(item) + if count > max_size: + raise ValidationError( + f"Collection too large: {count} items (maximum: {max_size})" + ) + + return count + + check_size(value) + + +def validate_state_value(value: Any) -> None: + """Validate StateValue meets all safety requirements. + + This is the main validation function for StateValue types. + It combines depth and size checks to ensure data safety. + + Checks: + - Nesting depth within limits (default: 50 levels) + - Collection size within limits (default: 1000 items) + - Type is JSON-serializable (implicit - errors on non-JSON types) + + Args: + value: StateValue to validate + + Raises: + ValidationError: If validation fails + + Example: + >>> validate_state_value({"user": {"name": "test", "age": 30}}) # OK + >>> validate_state_value(create_huge_dict()) # Raises ValidationError + """ + validate_json_depth(value) + validate_collection_size(value) diff --git a/tests/unit/core/test_validators.py b/tests/unit/core/test_validators.py new file mode 100644 index 0000000..c1767dd --- /dev/null +++ b/tests/unit/core/test_validators.py @@ -0,0 +1,239 @@ +"""Tests for core validators. + +This module tests the validation functions that prevent DoS attacks +and ensure data integrity. +""" + +from __future__ import annotations + +import pytest + +from cli_patterns.core.validators import ( + MAX_COLLECTION_SIZE, + MAX_JSON_DEPTH, + ValidationError, + validate_collection_size, + validate_json_depth, + validate_state_value, +) + +pytestmark = pytest.mark.unit + + +class TestDepthValidation: + """Test JSON depth validation.""" + + def test_accepts_shallow_dict(self) -> None: + """Should accept dict within depth limit.""" + data = {"a": {"b": {"c": 1}}} + validate_json_depth(data) # Should not raise + + def test_accepts_shallow_list(self) -> None: + """Should accept list within depth limit.""" + data = [[[[1]]]] + validate_json_depth(data) # Should not raise + + def test_accepts_empty_dict(self) -> None: + """Should accept empty dict.""" + validate_json_depth({}) + + def test_accepts_empty_list(self) -> None: + """Should accept empty list.""" + validate_json_depth([]) + + def test_accepts_primitives(self) -> None: + """Should accept primitive values.""" + validate_json_depth("string") + validate_json_depth(123) + validate_json_depth(45.67) + validate_json_depth(True) + validate_json_depth(None) + + def test_rejects_deeply_nested_dict(self) -> None: + """Should reject dict exceeding depth limit.""" + # Create deeply nested dict + data: dict[str, any] = {"value": 1} + for _ in range(MAX_JSON_DEPTH + 1): + data = {"nested": data} + + with pytest.raises(ValidationError, match="nesting too deep"): + validate_json_depth(data) + + def test_rejects_deeply_nested_list(self) -> None: + """Should reject list exceeding depth limit.""" + data: list[any] = [1] + for _ in range(MAX_JSON_DEPTH + 1): + data = [data] + + with pytest.raises(ValidationError, match="nesting too deep"): + validate_json_depth(data) + + def test_rejects_mixed_nested_structure(self) -> None: + """Should reject mixed dict/list exceeding depth.""" + data: any = [{"nested": [{"deep": 1}]}] + for _ in range(MAX_JSON_DEPTH): + data = [data] + + with pytest.raises(ValidationError, match="nesting too deep"): + validate_json_depth(data) + + def test_custom_depth_limit(self) -> None: + """Should respect custom depth limit.""" + data = {"a": {"b": {"c": 1}}} + + validate_json_depth(data, max_depth=10) # OK + with pytest.raises(ValidationError): + validate_json_depth(data, max_depth=2) # Too deep + + def test_depth_counts_correctly(self) -> None: + """Should count depth correctly for various structures.""" + # Depth 0: primitives + validate_json_depth(1, max_depth=0) + + # Depth 1: single-level dict/list + validate_json_depth({"a": 1}, max_depth=1) + validate_json_depth([1, 2], max_depth=1) + + # Depth 2: nested + validate_json_depth({"a": {"b": 1}}, max_depth=2) + validate_json_depth([[1]], max_depth=2) + + +class TestSizeValidation: + """Test collection size validation.""" + + def test_accepts_small_dict(self) -> None: + """Should accept dict within size limit.""" + data = {f"key{i}": i for i in range(100)} + validate_collection_size(data) # Should not raise + + def test_accepts_small_list(self) -> None: + """Should accept list within size limit.""" + data = list(range(100)) + validate_collection_size(data) + + def test_accepts_empty_collections(self) -> None: + """Should accept empty collections.""" + validate_collection_size({}) + validate_collection_size([]) + + def test_accepts_primitives(self) -> None: + """Should accept primitive values.""" + validate_collection_size("string") + validate_collection_size(123) + validate_collection_size(True) + validate_collection_size(None) + + def test_rejects_large_dict(self) -> None: + """Should reject dict exceeding size limit.""" + data = {f"key{i}": i for i in range(MAX_COLLECTION_SIZE + 1)} + + with pytest.raises(ValidationError, match="too large"): + validate_collection_size(data) + + def test_rejects_large_list(self) -> None: + """Should reject list exceeding size limit.""" + data = list(range(MAX_COLLECTION_SIZE + 1)) + + with pytest.raises(ValidationError, match="too large"): + validate_collection_size(data) + + def test_counts_nested_elements(self) -> None: + """Should count elements in nested structures.""" + # Create nested structure with many elements + data = {f"key{i}": list(range(100)) for i in range(20)} + # Total: 20 keys + 20*100 list items = 2020 elements + + with pytest.raises(ValidationError, match="too large"): + validate_collection_size(data, max_size=1000) + + def test_counts_deeply_nested(self) -> None: + """Should count all elements in deeply nested structures.""" + data = { + "level1": { + "level2": {"level3": [1, 2, 3, 4, 5]}, + "level2b": [1, 2, 3], + } + } + # Total: 3 dicts + 8 list items = 11 elements + validate_collection_size(data, max_size=15) + + def test_custom_size_limit(self) -> None: + """Should respect custom size limit.""" + data = list(range(50)) + + validate_collection_size(data, max_size=100) # OK + with pytest.raises(ValidationError): + validate_collection_size(data, max_size=40) # Too large + + +class TestStateValueValidation: + """Test combined state value validation.""" + + def test_accepts_valid_state_value(self) -> None: + """Should accept valid state value.""" + data = {"user": {"name": "test", "age": 30, "tags": ["admin", "user"]}} + validate_state_value(data) + + def test_rejects_too_deep(self) -> None: + """Should reject value that's too deep.""" + data: dict[str, any] = {"value": 1} + for _ in range(MAX_JSON_DEPTH + 1): + data = {"nested": data} + + with pytest.raises(ValidationError, match="nesting too deep"): + validate_state_value(data) + + def test_rejects_too_large(self) -> None: + """Should reject value that's too large.""" + data = list(range(MAX_COLLECTION_SIZE + 1)) + + with pytest.raises(ValidationError, match="too large"): + validate_state_value(data) + + def test_validates_complex_structures(self) -> None: + """Should validate complex real-world structures.""" + # Simulate a realistic configuration + config = { + "database": {"host": "localhost", "port": 5432, "name": "mydb"}, + "services": [ + {"name": "api", "port": 8000, "workers": 4}, + {"name": "worker", "port": 8001, "workers": 2}, + ], + "features": {"auth": True, "cache": True, "debug": False}, + } + validate_state_value(config) # Should pass + + +class TestEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_exactly_at_depth_limit(self) -> None: + """Should accept structure exactly at depth limit.""" + data: any = 1 + for _ in range(MAX_JSON_DEPTH): + data = {"nested": data} + validate_json_depth(data) # Should pass + + def test_exactly_at_size_limit(self) -> None: + """Should accept collection exactly at size limit.""" + data = list(range(MAX_COLLECTION_SIZE)) + validate_collection_size(data) # Should pass + + def test_unicode_strings(self) -> None: + """Should handle unicode strings.""" + data = {"emoji": "🚀", "chinese": "你好", "arabic": "مرحبا"} + validate_state_value(data) + + def test_mixed_types(self) -> None: + """Should handle mixed types in collections.""" + data = { + "string": "value", + "number": 42, + "float": 3.14, + "bool": True, + "null": None, + "list": [1, "two", 3.0], + "dict": {"nested": "data"}, + } + validate_state_value(data) From 2463e4a54b2cc6ed4c3b6d997faa3e74a19f8c2f Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 22:32:30 -0400 Subject: [PATCH 6/9] feat(core): add comprehensive security hardening to models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement multiple layers of security protection in core models to prevent command injection, DoS attacks, and resource exhaustion. Security Enhancements: 1. Command Injection Prevention (BashActionConfig): - Added allow_shell_features flag (default: False) - Validates commands to block shell metacharacters - Rejects: pipes, redirects, command substitution, variable expansion - Clear error messages guide users to explicit opt-in - 13 tests for injection patterns 2. DoS Protection (SessionState): - Validates option_values and variables for depth/size - Maximum 1000 options/variables - Integration with validators module - 8 tests for DoS scenarios 3. Collection Size Limits: - BranchConfig: 100 actions, 50 options, 20 menus - WizardConfig: 100 branches - Prevents memory exhaustion from config files - 6 tests for collection limits 4. Entry Branch Validation (WizardConfig): - Ensures entry_branch exists in branches list - Helpful error messages show available branches - 3 tests for validation scenarios Test Coverage: - 30 security-focused tests in test_security.py - All existing tests updated and passing - 100% coverage of new security code Breaking Changes: - Commands with shell features now require allow_shell_features=True - Wizard configs with invalid entry_branch now fail validation - Large collections/deep nesting now rejected Migration: - Set allow_shell_features=True for commands needing pipes/redirects - Ensure entry_branch matches a branch ID - Review any configs with >100 branches or >50 options Part of security hardening (Priorities 1, 2, 3) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/cli_patterns/core/models.py | 173 +++++++++++++- tests/unit/core/test_models.py | 18 +- tests/unit/core/test_security.py | 377 +++++++++++++++++++++++++++++++ 3 files changed, 556 insertions(+), 12 deletions(-) create mode 100644 tests/unit/core/test_security.py diff --git a/src/cli_patterns/core/models.py b/src/cli_patterns/core/models.py index 9901dd4..228f89e 100644 --- a/src/cli_patterns/core/models.py +++ b/src/cli_patterns/core/models.py @@ -10,9 +10,10 @@ from __future__ import annotations +import re from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from cli_patterns.core.types import ( ActionId, @@ -20,6 +21,7 @@ MenuId, OptionKey, ) +from cli_patterns.core.validators import ValidationError, validate_state_value # StateValue is defined as Any for Pydantic compatibility # The actual type constraint (JSON-serializable) is enforced at serialization time @@ -65,6 +67,11 @@ class BashActionConfig(BaseConfig): """Configuration for bash command actions. Executes a bash command with optional environment variables. + + Security: + By default, shell features (pipes, redirects, command substitution) are + disabled to prevent command injection attacks. Set allow_shell_features=True + only for trusted commands that require shell features. """ type: Literal["bash"] = Field( @@ -77,6 +84,47 @@ class BashActionConfig(BaseConfig): env: dict[str, str] = Field( default_factory=dict, description="Environment variables for command" ) + allow_shell_features: bool = Field( + default=False, + description=( + "Allow shell features (pipes, redirects, command substitution). " + "SECURITY RISK: Only enable for trusted commands. When False, " + "command is executed without shell interpretation to prevent injection." + ), + ) + + @model_validator(mode="after") + def validate_command_safety(self) -> BashActionConfig: + """Validate command doesn't contain dangerous patterns. + + This validator blocks shell injection attempts when allow_shell_features=False. + + Returns: + Validated model + + Raises: + ValueError: If command contains dangerous shell metacharacters + """ + if not self.allow_shell_features: + # Dangerous shell metacharacters + dangerous_patterns = [ + (r"[;&|]", "command chaining (;, &, |)"), + (r"`", "command substitution (backticks)"), + (r"\$\(", "command substitution ($())"), + (r"[<>]", "redirection (<, >)"), + (r"\$\{", "variable expansion (${})"), + (r"^\s*\w+\s*=", "variable assignment"), + ] + + for pattern, description in dangerous_patterns: + if re.search(pattern, self.command): + raise ValueError( + f"Command contains {description}. " + f"Set allow_shell_features=True to enable shell features " + f"(SECURITY RISK: only do this for trusted commands)." + ) + + return self class PythonActionConfig(BaseConfig): @@ -220,6 +268,11 @@ class BranchConfig(BaseConfig): A branch represents a screen/step in the wizard with actions, options, and navigation menus. + + Limits: + - Actions: 100 maximum + - Options: 50 maximum + - Menus: 20 maximum """ id: BranchId = Field(description="Unique branch identifier") @@ -235,6 +288,34 @@ class BranchConfig(BaseConfig): default_factory=list, description="Navigation menus in this branch" ) + @field_validator("actions") + @classmethod + def validate_actions_size( + cls, v: list[ActionConfigUnion] + ) -> list[ActionConfigUnion]: + """Validate number of actions is reasonable.""" + if len(v) > 100: + raise ValueError("Too many actions in branch (maximum: 100)") + return v + + @field_validator("options") + @classmethod + def validate_options_size( + cls, v: list[OptionConfigUnion] + ) -> list[OptionConfigUnion]: + """Validate number of options is reasonable.""" + if len(v) > 50: + raise ValueError("Too many options in branch (maximum: 50)") + return v + + @field_validator("menus") + @classmethod + def validate_menus_size(cls, v: list[MenuConfig]) -> list[MenuConfig]: + """Validate number of menus is reasonable.""" + if len(v) > 20: + raise ValueError("Too many menus in branch (maximum: 20)") + return v + # ============================================================================ # Wizard Configuration @@ -246,6 +327,9 @@ class WizardConfig(BaseConfig): This is the top-level configuration that defines an entire wizard, including all branches and the entry point. + + Limits: + - Branches: 100 maximum """ name: str = Field(description="Wizard name (identifier)") @@ -256,8 +340,24 @@ class WizardConfig(BaseConfig): ) branches: list[BranchConfig] = Field(description="All branches in the wizard tree") - # TODO: Add validator to ensure entry_branch exists in branches - # This would be done with @model_validator in Pydantic v2 + @field_validator("branches") + @classmethod + def validate_branches_size(cls, v: list[BranchConfig]) -> list[BranchConfig]: + """Validate number of branches is reasonable.""" + if len(v) > 100: + raise ValueError("Too many branches in wizard (maximum: 100)") + return v + + @model_validator(mode="after") + def validate_entry_branch_exists(self) -> WizardConfig: + """Validate that entry_branch exists in branches list.""" + branch_ids = {b.id for b in self.branches} + if self.entry_branch not in branch_ids: + raise ValueError( + f"entry_branch '{self.entry_branch}' not found in branches. " + f"Available branches: {sorted(branch_ids)}" + ) + return self # ============================================================================ @@ -270,6 +370,11 @@ class SessionState(StrictModel): This model combines both wizard state (navigation, options) and parser state (mode, history) into a single unified state. + + Security: + All StateValue fields (option_values, variables) are validated for: + - Maximum nesting depth (50 levels) + - Maximum collection size (1000 items) """ # Wizard state @@ -295,6 +400,68 @@ class SessionState(StrictModel): default_factory=list, description="Command history for readline/recall" ) + @field_validator("option_values") + @classmethod + def validate_option_values( + cls, v: dict[OptionKey, StateValue] + ) -> dict[OptionKey, StateValue]: + """Validate all option values meet safety requirements. + + Checks each value for: + - Maximum nesting depth (50 levels) + - Maximum collection size (1000 items) + + Args: + v: Option values dict to validate + + Returns: + Validated dict + + Raises: + ValueError: If any value violates safety limits + """ + # Check total number of options + if len(v) > 1000: + raise ValueError("Too many options (maximum: 1000)") + + # Validate each value + for key, value in v.items(): + try: + validate_state_value(value) + except ValidationError as e: + raise ValueError(f"Invalid value for option '{key}': {e}") from e + + return v + + @field_validator("variables") + @classmethod + def validate_variables(cls, v: dict[str, StateValue]) -> dict[str, StateValue]: + """Validate all variables meet safety requirements. + + Checks each value for: + - Maximum nesting depth (50 levels) + - Maximum collection size (1000 items) + + Args: + v: Variables dict to validate + + Returns: + Validated dict + + Raises: + ValueError: If any value violates safety limits + """ + if len(v) > 1000: + raise ValueError("Too many variables (maximum: 1000)") + + for key, value in v.items(): + try: + validate_state_value(value) + except ValidationError as e: + raise ValueError(f"Invalid value for variable '{key}': {e}") from e + + return v + # ============================================================================ # Result Types diff --git a/tests/unit/core/test_models.py b/tests/unit/core/test_models.py index 861f839..12764ac 100644 --- a/tests/unit/core/test_models.py +++ b/tests/unit/core/test_models.py @@ -472,17 +472,17 @@ def test_wizard_config_validates_entry_branch_exists(self) -> None: """ GIVEN: Wizard with entry_branch that doesn't exist in branches WHEN: Creating a WizardConfig - THEN: Validation should succeed (validation is runtime, not construction) + THEN: Validation error is raised """ branch = BranchConfig(id=make_branch_id("main"), title="Main") - # Entry branch points to non-existent branch - this is allowed at construction - config = WizardConfig( - name="test-wizard", - version="1.0.0", - entry_branch=make_branch_id("nonexistent"), - branches=[branch], - ) - assert config.entry_branch == make_branch_id("nonexistent") + # Entry branch points to non-existent branch - should raise validation error + with pytest.raises(ValidationError, match="entry_branch.*not found"): + WizardConfig( + name="test-wizard", + version="1.0.0", + entry_branch=make_branch_id("nonexistent"), + branches=[branch], + ) def test_wizard_config_multiple_branches(self) -> None: """ diff --git a/tests/unit/core/test_security.py b/tests/unit/core/test_security.py new file mode 100644 index 0000000..8bd19ed --- /dev/null +++ b/tests/unit/core/test_security.py @@ -0,0 +1,377 @@ +"""Security tests for core models. + +This module tests command injection prevention, DoS protection, +and collection size limits. +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from cli_patterns.core.models import ( + BashActionConfig, + BranchConfig, + SessionState, + StringOptionConfig, + WizardConfig, +) +from cli_patterns.core.types import make_action_id, make_branch_id, make_option_key + +pytestmark = pytest.mark.unit + + +class TestCommandInjectionPrevention: + """Test command injection prevention in BashActionConfig.""" + + def test_rejects_command_chaining_semicolon(self) -> None: + """Should reject commands with semicolon chaining.""" + with pytest.raises(ValidationError, match="command chaining"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo hello; rm -rf /", + allow_shell_features=False, + ) + + def test_rejects_command_chaining_ampersand(self) -> None: + """Should reject commands with ampersand chaining.""" + with pytest.raises(ValidationError, match="command chaining"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo hello & rm -rf /", + allow_shell_features=False, + ) + + def test_rejects_command_chaining_pipe(self) -> None: + """Should reject commands with pipe.""" + with pytest.raises(ValidationError, match="command chaining"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="cat file | grep secret", + allow_shell_features=False, + ) + + def test_rejects_command_substitution_dollar_paren(self) -> None: + """Should reject commands with $() substitution.""" + with pytest.raises(ValidationError, match="command substitution"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo $(whoami)", + allow_shell_features=False, + ) + + def test_rejects_command_substitution_backtick(self) -> None: + """Should reject commands with backtick substitution.""" + with pytest.raises(ValidationError, match="command substitution"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo `whoami`", + allow_shell_features=False, + ) + + def test_rejects_output_redirection(self) -> None: + """Should reject commands with output redirection.""" + with pytest.raises(ValidationError, match="redirection"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo secret > /tmp/leak", + allow_shell_features=False, + ) + + def test_rejects_input_redirection(self) -> None: + """Should reject commands with input redirection.""" + with pytest.raises(ValidationError, match="redirection"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="cat < /etc/passwd", + allow_shell_features=False, + ) + + def test_rejects_variable_expansion(self) -> None: + """Should reject commands with variable expansion.""" + with pytest.raises(ValidationError, match="variable expansion"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo ${PATH}", + allow_shell_features=False, + ) + + def test_rejects_variable_assignment(self) -> None: + """Should reject commands with variable assignment.""" + with pytest.raises(ValidationError, match="variable assignment"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="PATH=/evil/path kubectl apply", + allow_shell_features=False, + ) + + def test_allows_safe_command(self) -> None: + """Should allow safe commands without shell features.""" + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="kubectl apply -f deploy.yaml", + allow_shell_features=False, + ) + assert config.command == "kubectl apply -f deploy.yaml" + assert config.allow_shell_features is False + + def test_allows_command_with_arguments(self) -> None: + """Should allow commands with normal arguments.""" + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="docker run --rm -it ubuntu:latest /bin/bash", + allow_shell_features=False, + ) + assert "docker run" in config.command + + def test_allows_dangerous_command_with_flag(self) -> None: + """Should allow dangerous commands when explicitly enabled.""" + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="cat file | grep secret", + allow_shell_features=True, # Explicit opt-in + ) + assert config.command == "cat file | grep secret" + assert config.allow_shell_features is True + + def test_allows_all_shell_features_when_enabled(self) -> None: + """Should allow all shell features when flag is True.""" + commands = [ + "echo hello; echo world", + "cat file | grep pattern", + "echo $(date)", + "echo `whoami`", + "cat > output.txt", + "cmd < input.txt", + "echo ${VAR}", + "PATH=/new/path cmd", + ] + + for cmd in commands: + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command=cmd, + allow_shell_features=True, + ) + assert config.command == cmd + + +class TestDoSProtection: + """Test DoS protection via depth and size validation.""" + + def test_rejects_deeply_nested_option_value(self) -> None: + """Should reject deeply nested structures in option values.""" + # Create deeply nested dict (> 50 levels) + deep_value: dict[str, any] = {"value": 1} + for _ in range(55): + deep_value = {"nested": deep_value} + + with pytest.raises(ValidationError, match="nesting too deep"): + SessionState(option_values={make_option_key("test"): deep_value}) + + def test_rejects_large_option_value(self) -> None: + """Should reject excessively large option values.""" + # Create list with > 1000 items + large_value = list(range(1500)) + + with pytest.raises(ValidationError, match="too large"): + SessionState(option_values={make_option_key("test"): large_value}) + + def test_rejects_deeply_nested_variable(self) -> None: + """Should reject deeply nested structures in variables.""" + deep_value: dict[str, any] = {"value": 1} + for _ in range(55): + deep_value = {"nested": deep_value} + + with pytest.raises(ValidationError, match="nesting too deep"): + SessionState(variables={"test": deep_value}) + + def test_rejects_large_variable(self) -> None: + """Should reject excessively large variables.""" + large_value = list(range(1500)) + + with pytest.raises(ValidationError, match="too large"): + SessionState(variables={"test": large_value}) + + def test_rejects_too_many_options(self) -> None: + """Should reject too many options.""" + options = {make_option_key(f"opt{i}"): i for i in range(1001)} + + with pytest.raises(ValidationError, match="Too many options"): + SessionState(option_values=options) + + def test_rejects_too_many_variables(self) -> None: + """Should reject too many variables.""" + variables = {f"var{i}": i for i in range(1001)} + + with pytest.raises(ValidationError, match="Too many variables"): + SessionState(variables=variables) + + def test_accepts_valid_nested_value(self) -> None: + """Should accept reasonably nested values.""" + valid_value = {"level1": {"level2": {"level3": {"level4": {"level5": "data"}}}}} + state = SessionState(option_values={make_option_key("test"): valid_value}) + assert state.option_values[make_option_key("test")] == valid_value + + def test_accepts_valid_large_value(self) -> None: + """Should accept moderately large values.""" + valid_value = list(range(500)) + state = SessionState(option_values={make_option_key("test"): valid_value}) + assert len(state.option_values[make_option_key("test")]) == 500 + + +class TestCollectionLimits: + """Test collection size limits in configuration models.""" + + def test_rejects_too_many_actions(self) -> None: + """Should reject branch with too many actions.""" + with pytest.raises(ValidationError, match="Too many actions"): + BranchConfig( + id=make_branch_id("test"), + title="Test", + actions=[ + BashActionConfig( + id=make_action_id(f"action{i}"), + name=f"Action {i}", + command="echo test", + ) + for i in range(101) # Over limit + ], + ) + + def test_rejects_too_many_options(self) -> None: + """Should reject branch with too many options.""" + with pytest.raises(ValidationError, match="Too many options"): + BranchConfig( + id=make_branch_id("test"), + title="Test", + options=[ + StringOptionConfig( + id=make_option_key(f"opt{i}"), + name=f"Option {i}", + description="Test", + ) + for i in range(51) # Over limit + ], + ) + + def test_rejects_too_many_menus(self) -> None: + """Should reject branch with too many menus.""" + from cli_patterns.core.models import MenuConfig + + with pytest.raises(ValidationError, match="Too many menus"): + BranchConfig( + id=make_branch_id("test"), + title="Test", + menus=[ + MenuConfig( + id=make_action_id(f"menu{i}"), + label=f"Menu {i}", + target=make_branch_id("target"), + ) + for i in range(21) # Over limit + ], + ) + + def test_rejects_too_many_branches(self) -> None: + """Should reject wizard with too many branches.""" + with pytest.raises(ValidationError, match="Too many branches"): + WizardConfig( + name="test", + version="1.0.0", + entry_branch=make_branch_id("branch0"), + branches=[ + BranchConfig(id=make_branch_id(f"branch{i}"), title=f"Branch {i}") + for i in range(101) # Over limit + ], + ) + + def test_accepts_maximum_actions(self) -> None: + """Should accept exactly 100 actions.""" + config = BranchConfig( + id=make_branch_id("test"), + title="Test", + actions=[ + BashActionConfig( + id=make_action_id(f"action{i}"), + name=f"Action {i}", + command="echo test", + ) + for i in range(100) # Exactly at limit + ], + ) + assert len(config.actions) == 100 + + def test_accepts_maximum_branches(self) -> None: + """Should accept exactly 100 branches.""" + config = WizardConfig( + name="test", + version="1.0.0", + entry_branch=make_branch_id("branch0"), + branches=[ + BranchConfig(id=make_branch_id(f"branch{i}"), title=f"Branch {i}") + for i in range(100) # Exactly at limit + ], + ) + assert len(config.branches) == 100 + + +class TestWizardValidation: + """Test wizard-specific validation.""" + + def test_rejects_nonexistent_entry_branch(self) -> None: + """Should reject wizard with entry_branch not in branches.""" + with pytest.raises(ValidationError, match="entry_branch.*not found"): + WizardConfig( + name="test", + version="1.0.0", + entry_branch=make_branch_id("nonexistent"), + branches=[ + BranchConfig(id=make_branch_id("main"), title="Main"), + BranchConfig(id=make_branch_id("settings"), title="Settings"), + ], + ) + + def test_accepts_valid_entry_branch(self) -> None: + """Should accept wizard with valid entry_branch.""" + config = WizardConfig( + name="test", + version="1.0.0", + entry_branch=make_branch_id("main"), + branches=[ + BranchConfig(id=make_branch_id("main"), title="Main"), + BranchConfig(id=make_branch_id("settings"), title="Settings"), + ], + ) + assert config.entry_branch == make_branch_id("main") + + def test_error_message_shows_available_branches(self) -> None: + """Should show available branches in error message.""" + try: + WizardConfig( + name="test", + version="1.0.0", + entry_branch=make_branch_id("invalid"), + branches=[ + BranchConfig(id=make_branch_id("main"), title="Main"), + BranchConfig(id=make_branch_id("settings"), title="Settings"), + ], + ) + pytest.fail("Should have raised ValidationError") + except ValidationError as e: + error_str = str(e) + assert "Available branches" in error_str + assert "main" in error_str or "settings" in error_str From 1340819538c253884f958af8a440c20e1bfc033b Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 23:34:04 -0400 Subject: [PATCH 7/9] feat(security): add command injection prevention to SubprocessExecutor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Priority 1 (CRITICAL) security enhancement: command injection prevention. Changes: - Modified SubprocessExecutor to use create_subprocess_exec() by default (safe mode) - Added allow_shell_features parameter for explicit opt-in to shell features - Commands are parsed with shlex and executed without shell interpretation by default - Added security warning logging when shell features are enabled - Invalid shell syntax is caught and reported gracefully Security Impact: - Command injection attacks are now completely prevented by default - Shell metacharacters (pipes, redirects, command substitution) are treated as literal - Only explicitly trusted commands with allow_shell_features=True use shell Breaking Change: - Commands now execute without shell by default - Shell features require explicit allow_shell_features=True parameter Tests Added: - 15 command injection unit tests - 13 security integration tests - Updated 8 existing subprocess executor tests All 782 tests passing. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../execution/subprocess_executor.py | 76 ++++++-- tests/integration/test_subprocess_security.py | 151 ++++++++++++++++ tests/unit/core/test_command_injection.py | 166 ++++++++++++++++++ .../execution/test_subprocess_executor.py | 30 ++-- 4 files changed, 394 insertions(+), 29 deletions(-) create mode 100644 tests/integration/test_subprocess_security.py create mode 100644 tests/unit/core/test_command_injection.py diff --git a/src/cli_patterns/execution/subprocess_executor.py b/src/cli_patterns/execution/subprocess_executor.py index 7f6738a..d812229 100644 --- a/src/cli_patterns/execution/subprocess_executor.py +++ b/src/cli_patterns/execution/subprocess_executor.py @@ -13,7 +13,9 @@ from __future__ import annotations import asyncio +import logging import os +import shlex from typing import Optional, Union from rich.console import Console @@ -23,6 +25,8 @@ from ..ui.design.registry import theme_registry from ..ui.design.tokens import StatusToken +logger = logging.getLogger(__name__) + class CommandResult: """Result of a command execution.""" @@ -83,6 +87,7 @@ async def run( timeout: Optional[float] = None, cwd: Optional[str] = None, env: Optional[dict[str, str]] = None, + allow_shell_features: bool = False, ) -> CommandResult: """Execute a command asynchronously with themed output streaming. @@ -91,22 +96,40 @@ async def run( timeout: Command timeout in seconds (uses default if None) cwd: Working directory for the command env: Environment variables for the command + allow_shell_features: Allow shell features (pipes, redirects, etc.). + SECURITY WARNING: Only enable for trusted commands. When False, + command is executed without shell to prevent injection attacks. Returns: CommandResult with exit code and captured output """ timeout = timeout or self.default_timeout - # Show running status - if self.stream_output: - running_style = theme_registry.resolve(StatusToken.RUNNING) - self.console.print(Text(f"Running: {command}", style=running_style)) - - # Prepare command + # Prepare command list for display and execution if isinstance(command, list): + command_list = command command_str = " ".join(command) else: command_str = command + # Parse string into list for safe execution + try: + command_list = shlex.split(command_str) + except ValueError as e: + # Invalid shell syntax + stderr_msg = f"Invalid command syntax: {e}" + if self.stream_output: + error_style = theme_registry.resolve(StatusToken.ERROR) + self.console.print(Text(stderr_msg, style=error_style)) + return CommandResult( + exit_code=-1, + stdout="", + stderr=stderr_msg, + ) + + # Show running status + if self.stream_output: + running_style = theme_registry.resolve(StatusToken.RUNNING) + self.console.print(Text(f"Running: {command_str}", style=running_style)) # Merge environment variables process_env = os.environ.copy() @@ -122,14 +145,39 @@ async def run( process = None # Initialize process variable try: - # Create subprocess - process = await asyncio.create_subprocess_shell( - command_str, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - cwd=cwd, - env=process_env, - ) + # Create subprocess - use shell only if explicitly allowed + if allow_shell_features: + # SECURITY WARNING: Shell features enabled + logger.warning( + f"Executing command with shell features enabled: {command_str}" + ) + process = await asyncio.create_subprocess_shell( + command_str, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=process_env, + ) + else: + # Safe execution without shell (prevents injection) + if not command_list: + stderr_msg = "Empty command" + if self.stream_output: + error_style = theme_registry.resolve(StatusToken.ERROR) + self.console.print(Text(stderr_msg, style=error_style)) + return CommandResult( + exit_code=-1, + stdout="", + stderr=stderr_msg, + ) + + process = await asyncio.create_subprocess_exec( + *command_list, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=process_env, + ) # Create tasks for reading streams stdout_task = asyncio.create_task( diff --git a/tests/integration/test_subprocess_security.py b/tests/integration/test_subprocess_security.py new file mode 100644 index 0000000..7a2bbb2 --- /dev/null +++ b/tests/integration/test_subprocess_security.py @@ -0,0 +1,151 @@ +"""Integration tests for subprocess security features. + +This module tests the security features of SubprocessExecutor, including +command injection prevention through the allow_shell_features flag. +""" + +import pytest + +from cli_patterns.execution.subprocess_executor import SubprocessExecutor + + +@pytest.mark.asyncio +@pytest.mark.integration +class TestSubprocessSecurity: + """Test security features of SubprocessExecutor.""" + + async def test_safe_command_without_shell(self) -> None: + """Should execute safe command without shell features.""" + executor = SubprocessExecutor(stream_output=False) + result = await executor.run("echo hello", allow_shell_features=False) + + assert result.success + assert "hello" in result.stdout + + async def test_blocks_command_injection_without_shell(self) -> None: + """Should prevent command injection when shell features disabled.""" + executor = SubprocessExecutor(stream_output=False) + + # This should fail because semicolon will be treated as literal argument + # The command "echo" will receive "test;whoami" as a single argument + result = await executor.run("echo test;whoami", allow_shell_features=False) + + # Should succeed (echo accepts the argument) + assert result.success + # But semicolon should be in the output as a literal character + assert "test;whoami" in result.stdout or ";" in result.stdout + + async def test_allows_shell_features_when_enabled(self) -> None: + """Should allow shell features when explicitly enabled.""" + executor = SubprocessExecutor(stream_output=False) + + # This should work with shell features enabled + result = await executor.run( + "echo hello && echo world", allow_shell_features=True + ) + + assert result.success + assert "hello" in result.stdout + assert "world" in result.stdout + + async def test_pipe_fails_without_shell(self) -> None: + """Should not support pipes when shell features disabled.""" + executor = SubprocessExecutor(stream_output=False) + + # Pipe should not work without shell + result = await executor.run("echo test | cat", allow_shell_features=False) + + # The command will fail because "|" will be treated as a literal argument + # and echo doesn't accept that as a valid option + assert not result.success or "|" in result.stdout + + async def test_pipe_works_with_shell(self) -> None: + """Should support pipes when shell features enabled.""" + executor = SubprocessExecutor(stream_output=False) + + # Pipe should work with shell enabled + result = await executor.run("echo test | cat", allow_shell_features=True) + + assert result.success + assert "test" in result.stdout + + async def test_command_substitution_fails_without_shell(self) -> None: + """Should not execute command substitution without shell.""" + executor = SubprocessExecutor(stream_output=False) + + # Command substitution should not work + result = await executor.run("echo $(whoami)", allow_shell_features=False) + + # Should treat $() as literal text + assert not result.success or "$" in result.stdout or "(" in result.stdout + + async def test_command_substitution_works_with_shell(self) -> None: + """Should execute command substitution with shell features.""" + executor = SubprocessExecutor(stream_output=False) + + # Command substitution should work with shell + result = await executor.run("echo $(echo test)", allow_shell_features=True) + + assert result.success + assert "test" in result.stdout + + async def test_redirection_fails_without_shell(self) -> None: + """Should not support redirection without shell features.""" + executor = SubprocessExecutor(stream_output=False) + + # Redirection should not work + result = await executor.run( + "echo test > /tmp/test_output", allow_shell_features=False + ) + + # Should treat > as literal argument + assert not result.success or ">" in result.stdout + + async def test_handles_commands_with_arguments(self) -> None: + """Should handle commands with normal arguments safely.""" + executor = SubprocessExecutor(stream_output=False) + + result = await executor.run("echo -n hello world", allow_shell_features=False) + + assert result.success + assert "hello world" in result.stdout or "helloworld" in result.stdout + + async def test_handles_quoted_arguments(self) -> None: + """Should handle quoted arguments correctly.""" + executor = SubprocessExecutor(stream_output=False) + + result = await executor.run('echo "hello world"', allow_shell_features=False) + + assert result.success + assert "hello world" in result.stdout + + async def test_default_is_safe_mode(self) -> None: + """Should default to safe mode (shell features disabled).""" + executor = SubprocessExecutor(stream_output=False) + + # Without specifying allow_shell_features, should default to False + result = await executor.run("echo test") + + assert result.success + assert "test" in result.stdout + + async def test_invalid_command_syntax(self) -> None: + """Should handle invalid shell syntax gracefully.""" + executor = SubprocessExecutor(stream_output=False) + + # Unmatched quotes should fail in shlex.split + result = await executor.run('echo "unterminated', allow_shell_features=False) + + assert not result.success + assert "Invalid command syntax" in result.stderr + + async def test_command_not_found(self) -> None: + """Should handle command not found gracefully.""" + executor = SubprocessExecutor(stream_output=False) + + result = await executor.run( + "nonexistent_command_xyz", allow_shell_features=False + ) + + assert not result.success + assert result.exit_code != 0 diff --git a/tests/unit/core/test_command_injection.py b/tests/unit/core/test_command_injection.py new file mode 100644 index 0000000..deddc4f --- /dev/null +++ b/tests/unit/core/test_command_injection.py @@ -0,0 +1,166 @@ +"""Tests for command injection prevention measures. + +This module tests security measures that prevent command injection attacks +through shell metacharacter exploitation. +""" + +import pytest + +from cli_patterns.core.models import BashActionConfig +from cli_patterns.core.types import make_action_id + + +class TestCommandInjectionPrevention: + """Test command injection prevention measures.""" + + def test_rejects_command_chaining_semicolon(self) -> None: + """Should reject commands with semicolon chaining.""" + with pytest.raises(ValueError, match="command chaining"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo hello; rm -rf /", + allow_shell_features=False, + ) + + def test_rejects_command_chaining_ampersand(self) -> None: + """Should reject commands with ampersand chaining.""" + with pytest.raises(ValueError, match="command chaining"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo hello & rm -rf /", + allow_shell_features=False, + ) + + def test_rejects_command_chaining_pipe(self) -> None: + """Should reject commands with pipe chaining.""" + with pytest.raises(ValueError, match="command chaining"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="cat file | grep secret", + allow_shell_features=False, + ) + + def test_rejects_command_substitution_dollar_paren(self) -> None: + """Should reject commands with $() command substitution.""" + with pytest.raises(ValueError, match="command substitution"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo $(whoami)", + allow_shell_features=False, + ) + + def test_rejects_command_substitution_backticks(self) -> None: + """Should reject commands with backtick command substitution.""" + with pytest.raises(ValueError, match="command substitution"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo `whoami`", + allow_shell_features=False, + ) + + def test_rejects_output_redirection(self) -> None: + """Should reject commands with output redirection.""" + with pytest.raises(ValueError, match="redirection"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo secret > /tmp/stolen", + allow_shell_features=False, + ) + + def test_rejects_input_redirection(self) -> None: + """Should reject commands with input redirection.""" + with pytest.raises(ValueError, match="redirection"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="cat < /etc/passwd", + allow_shell_features=False, + ) + + def test_rejects_variable_expansion(self) -> None: + """Should reject commands with ${} variable expansion.""" + with pytest.raises(ValueError, match="variable expansion"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo ${HOME}", + allow_shell_features=False, + ) + + def test_rejects_variable_assignment(self) -> None: + """Should reject commands with variable assignment.""" + with pytest.raises(ValueError, match="variable assignment"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="PATH=/evil/path ls", + allow_shell_features=False, + ) + + def test_allows_safe_command_without_shell_features(self) -> None: + """Should allow safe commands without shell features.""" + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="kubectl apply -f deploy.yaml", + allow_shell_features=False, + ) + assert config.command == "kubectl apply -f deploy.yaml" + assert config.allow_shell_features is False + + def test_allows_command_with_arguments(self) -> None: + """Should allow commands with normal arguments.""" + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="docker run --rm -v /data:/data myimage", + allow_shell_features=False, + ) + assert config.command == "docker run --rm -v /data:/data myimage" + + def test_allows_dangerous_command_with_explicit_flag(self) -> None: + """Should allow dangerous commands when explicitly enabled.""" + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="cat file | grep secret", + allow_shell_features=True, # Explicit opt-in + ) + assert config.command == "cat file | grep secret" + assert config.allow_shell_features is True + + def test_allows_all_shell_features_when_enabled(self) -> None: + """Should allow all shell features when flag is True.""" + # This should NOT raise even with all dangerous patterns + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="cat ${FILE} | grep pattern > output.txt && echo done", + allow_shell_features=True, + ) + assert config.allow_shell_features is True + + def test_default_allow_shell_features_is_false(self) -> None: + """Should default to allow_shell_features=False for security.""" + config = BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo hello", + ) + assert config.allow_shell_features is False + + def test_error_message_suggests_fix(self) -> None: + """Should provide helpful error message with fix suggestion.""" + with pytest.raises(ValueError, match="Set allow_shell_features=True to enable"): + BashActionConfig( + id=make_action_id("test"), + name="Test", + command="echo test | cat", + allow_shell_features=False, + ) diff --git a/tests/unit/execution/test_subprocess_executor.py b/tests/unit/execution/test_subprocess_executor.py index 2cbc513..0eaf10f 100644 --- a/tests/unit/execution/test_subprocess_executor.py +++ b/tests/unit/execution/test_subprocess_executor.py @@ -70,7 +70,7 @@ def executor(self, console): @pytest.mark.asyncio async def test_successful_command(self, executor, console): """Test successful command execution.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process mock_process = AsyncMock() mock_process.returncode = 0 @@ -95,7 +95,7 @@ async def test_successful_command(self, executor, console): @pytest.mark.asyncio async def test_failed_command(self, executor, console): """Test failed command execution.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process mock_process = AsyncMock() mock_process.returncode = 1 @@ -116,7 +116,7 @@ async def test_failed_command(self, executor, console): @pytest.mark.asyncio async def test_command_not_found(self, executor, console): """Test command not found error.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: mock_create.side_effect = FileNotFoundError("Command not found") result = await executor.run("nonexistent-command") @@ -129,7 +129,7 @@ async def test_command_not_found(self, executor, console): @pytest.mark.asyncio async def test_permission_denied(self, executor, console): """Test permission denied error.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: mock_create.side_effect = PermissionError("Permission denied") result = await executor.run("/root/protected") @@ -143,7 +143,7 @@ async def test_permission_denied(self, executor, console): @pytest.mark.slow async def test_timeout(self, executor, console): """Test command timeout.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process that times out mock_process = AsyncMock() mock_process.returncode = None @@ -166,7 +166,7 @@ async def test_timeout(self, executor, console): @pytest.mark.asyncio async def test_keyboard_interrupt(self, executor, console): """Test keyboard interrupt handling.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: mock_create.side_effect = KeyboardInterrupt() result = await executor.run("long-running-command") @@ -178,7 +178,7 @@ async def test_keyboard_interrupt(self, executor, console): @pytest.mark.asyncio async def test_list_command(self, executor, console): """Test command as list of arguments.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process mock_process = AsyncMock() mock_process.returncode = 0 @@ -193,14 +193,14 @@ async def test_list_command(self, executor, console): assert result.success mock_create.assert_called_once() - # Check that list was joined into string - call_args = mock_create.call_args[0][0] - assert call_args == "ls -la /tmp" + # Check that list was passed correctly to exec + call_args = mock_create.call_args[0] + assert call_args == ("ls", "-la", "/tmp") @pytest.mark.asyncio async def test_custom_env(self, executor, console): """Test command with custom environment variables.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process mock_process = AsyncMock() mock_process.returncode = 0 @@ -212,7 +212,7 @@ async def test_custom_env(self, executor, console): mock_create.return_value = mock_process custom_env = {"MY_VAR": "VALUE"} - result = await executor.run("echo $MY_VAR", env=custom_env) + result = await executor.run("echo test", env=custom_env) assert result.success mock_create.assert_called_once() @@ -224,7 +224,7 @@ async def test_custom_env(self, executor, console): @pytest.mark.asyncio async def test_custom_cwd(self, executor, console): """Test command with custom working directory.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process mock_process = AsyncMock() mock_process.returncode = 0 @@ -248,7 +248,7 @@ async def test_no_streaming(self, console): """Test executor without output streaming.""" executor = SubprocessExecutor(console=console, stream_output=False) - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process mock_process = AsyncMock() mock_process.returncode = 0 @@ -269,7 +269,7 @@ async def test_no_streaming(self, console): @pytest.mark.asyncio async def test_binary_output_handling(self, executor, console): """Test handling of binary output that can't be decoded.""" - with patch("asyncio.create_subprocess_shell") as mock_create: + with patch("asyncio.create_subprocess_exec") as mock_create: # Mock process with binary output mock_process = AsyncMock() mock_process.returncode = 0 From 03e909b33f0207af96df8047b7bfccd1c0128de2 Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 23:34:15 -0400 Subject: [PATCH 8/9] test(core): add collection size limit tests for DoS protection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Priority 3 (MEDIUM) security testing: collection size limits. Tests Added: - Branch actions limit: max 100 actions - Branch options limit: max 50 options - Branch menus limit: max 20 menus - Wizard branches limit: max 100 branches - SessionState option_values limit: max 1000 items - SessionState variables limit: max 1000 items Coverage: - 12 new collection limit tests - Tests both boundary conditions (at limit) and violations (over limit) These tests verify the collection size validators already implemented in models.py prevent DoS attacks via memory exhaustion. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/unit/core/test_models.py | 158 +++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) diff --git a/tests/unit/core/test_models.py b/tests/unit/core/test_models.py index 12764ac..cdd868c 100644 --- a/tests/unit/core/test_models.py +++ b/tests/unit/core/test_models.py @@ -759,3 +759,161 @@ def test_json_deserialization(self) -> None: config = BashActionConfig(**json_data) assert config.id == make_action_id("deploy") assert config.name == "Deploy" + + +class TestCollectionLimits: + """Test collection size limits for DoS protection.""" + + def test_rejects_too_many_actions_in_branch(self) -> None: + """Should reject branch with too many actions (>100).""" + with pytest.raises(ValidationError, match="Too many actions"): + BranchConfig( + id=make_branch_id("test"), + title="Test", + actions=[ + BashActionConfig( + id=make_action_id(f"action{i}"), + name=f"Action {i}", + command="echo test", + ) + for i in range(101) # Over limit + ], + ) + + def test_accepts_max_actions_in_branch(self) -> None: + """Should accept branch with exactly 100 actions.""" + config = BranchConfig( + id=make_branch_id("test"), + title="Test", + actions=[ + BashActionConfig( + id=make_action_id(f"action{i}"), + name=f"Action {i}", + command="echo test", + ) + for i in range(100) # At limit + ], + ) + assert len(config.actions) == 100 + + def test_rejects_too_many_options_in_branch(self) -> None: + """Should reject branch with too many options (>50).""" + with pytest.raises(ValidationError, match="Too many options"): + BranchConfig( + id=make_branch_id("test"), + title="Test", + options=[ + StringOptionConfig( + id=make_option_key(f"option{i}"), + name=f"Option {i}", + description="Test option", + ) + for i in range(51) # Over limit + ], + ) + + def test_accepts_max_options_in_branch(self) -> None: + """Should accept branch with exactly 50 options.""" + config = BranchConfig( + id=make_branch_id("test"), + title="Test", + options=[ + StringOptionConfig( + id=make_option_key(f"option{i}"), + name=f"Option {i}", + description="Test option", + ) + for i in range(50) # At limit + ], + ) + assert len(config.options) == 50 + + def test_rejects_too_many_menus_in_branch(self) -> None: + """Should reject branch with too many menus (>20).""" + with pytest.raises(ValidationError, match="Too many menus"): + BranchConfig( + id=make_branch_id("test"), + title="Test", + menus=[ + MenuConfig( + id=make_menu_id(f"menu{i}"), + label=f"Menu {i}", + target=make_branch_id("target"), + ) + for i in range(21) # Over limit + ], + ) + + def test_accepts_max_menus_in_branch(self) -> None: + """Should accept branch with exactly 20 menus.""" + config = BranchConfig( + id=make_branch_id("test"), + title="Test", + menus=[ + MenuConfig( + id=make_menu_id(f"menu{i}"), + label=f"Menu {i}", + target=make_branch_id("target"), + ) + for i in range(20) # At limit + ], + ) + assert len(config.menus) == 20 + + def test_rejects_too_many_branches_in_wizard(self) -> None: + """Should reject wizard with too many branches (>100).""" + with pytest.raises(ValidationError, match="Too many branches"): + WizardConfig( + name="test", + version="1.0.0", + entry_branch=make_branch_id("branch0"), + branches=[ + BranchConfig( + id=make_branch_id(f"branch{i}"), + title=f"Branch {i}", + ) + for i in range(101) # Over limit + ], + ) + + def test_accepts_max_branches_in_wizard(self) -> None: + """Should accept wizard with exactly 100 branches.""" + config = WizardConfig( + name="test", + version="1.0.0", + entry_branch=make_branch_id("branch0"), + branches=[ + BranchConfig( + id=make_branch_id(f"branch{i}"), + title=f"Branch {i}", + ) + for i in range(100) # At limit + ], + ) + assert len(config.branches) == 100 + + def test_rejects_too_many_option_values_in_session(self) -> None: + """Should reject session with too many option values (>1000).""" + with pytest.raises(ValidationError, match="Too many options"): + SessionState( + option_values={ + make_option_key(f"option{i}"): "value" for i in range(1001) + } + ) + + def test_accepts_max_option_values_in_session(self) -> None: + """Should accept session with exactly 1000 option values.""" + state = SessionState( + option_values={make_option_key(f"option{i}"): "value" for i in range(1000)} + ) + assert len(state.option_values) == 1000 + + def test_rejects_too_many_variables_in_session(self) -> None: + """Should reject session with too many variables (>1000).""" + with pytest.raises(ValidationError, match="Too many variables"): + SessionState(variables={f"var{i}": "value" for i in range(1001)}) + + def test_accepts_max_variables_in_session(self) -> None: + """Should accept session with exactly 1000 variables.""" + state = SessionState(variables={f"var{i}": "value" for i in range(1000)}) + assert len(state.variables) == 1000 From fa72fc6a9993cb0860717e3b098d9909c8f853f5 Mon Sep 17 00:00:00 2001 From: Doug Date: Tue, 30 Sep 2025 23:34:27 -0400 Subject: [PATCH 9/9] feat(core): add production validation mode with security config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Priority 4 (LOW) security enhancement: production validation mode. Changes: - New config.py module with SecurityConfig TypedDict - Environment variable support for security settings - Updated factory functions to respect global validation config - Factory functions now accept Optional[bool] for validation: - None: use global config (default) - True: force validation - False: skip validation Environment Variables: - CLI_PATTERNS_ENABLE_VALIDATION: Enable strict validation (default: false) - CLI_PATTERNS_MAX_JSON_DEPTH: Max nesting depth (default: 50) - CLI_PATTERNS_MAX_COLLECTION_SIZE: Max collection size (default: 1000) - CLI_PATTERNS_ALLOW_SHELL: Allow shell features globally (default: false) Benefits: - Zero-overhead validation in development (default: off) - Easy enablement for production via environment variable - Consistent validation behavior across all factory functions - Configurable DoS protection limits Usage: ```bash # Production deployment export CLI_PATTERNS_ENABLE_VALIDATION=true ``` All 782 tests passing. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/cli_patterns/core/config.py | 123 ++++++++++++++++++++++++++++++++ src/cli_patterns/core/types.py | 39 +++++++--- 2 files changed, 153 insertions(+), 9 deletions(-) create mode 100644 src/cli_patterns/core/config.py diff --git a/src/cli_patterns/core/config.py b/src/cli_patterns/core/config.py new file mode 100644 index 0000000..3b952f4 --- /dev/null +++ b/src/cli_patterns/core/config.py @@ -0,0 +1,123 @@ +"""Configuration for CLI Patterns core behavior. + +This module provides security and runtime configuration through environment +variables and TypedDict configurations. +""" + +from __future__ import annotations + +import os +from typing import TypedDict + + +class SecurityConfig(TypedDict): + """Security configuration settings. + + These settings control security features like validation strictness, + DoS protection limits, and shell feature permissions. + """ + + enable_validation: bool + """Enable strict validation for all factory functions. + + When True, factory functions perform validation on inputs. + Default: False (for performance in development). + Recommended for production: True + """ + + max_json_depth: int + """Maximum nesting depth for JSON values. + + Prevents DoS attacks via deeply nested structures. + Default: 50 levels + """ + + max_collection_size: int + """Maximum size for collections. + + Prevents memory exhaustion from large data structures. + Default: 1000 items + """ + + allow_shell_features: bool + """Allow shell features by default (INSECURE). + + When True, shell features (pipes, redirects, etc.) are allowed by default. + Default: False (secure) + WARNING: Setting this to True is a security risk. Always use per-action + configuration instead. + """ + + +def get_security_config() -> SecurityConfig: + """Get security configuration from environment variables. + + Environment Variables: + CLI_PATTERNS_ENABLE_VALIDATION: Enable strict validation (default: false) + Set to 'true' to enable validation in factory functions. + + CLI_PATTERNS_MAX_JSON_DEPTH: Max JSON nesting depth (default: 50) + Controls maximum depth for nested data structures. + + CLI_PATTERNS_MAX_COLLECTION_SIZE: Max collection size (default: 1000) + Controls maximum number of items in collections. + + CLI_PATTERNS_ALLOW_SHELL: Allow shell features globally (default: false) + WARNING: Setting to 'true' is insecure. Use per-action configuration. + + Returns: + SecurityConfig with settings from environment or defaults + + Example: + >>> os.environ['CLI_PATTERNS_ENABLE_VALIDATION'] = 'true' + >>> config = get_security_config() + >>> config['enable_validation'] + True + """ + return SecurityConfig( + enable_validation=os.getenv("CLI_PATTERNS_ENABLE_VALIDATION", "false").lower() + == "true", + max_json_depth=int(os.getenv("CLI_PATTERNS_MAX_JSON_DEPTH", "50")), + max_collection_size=int(os.getenv("CLI_PATTERNS_MAX_COLLECTION_SIZE", "1000")), + allow_shell_features=os.getenv("CLI_PATTERNS_ALLOW_SHELL", "false").lower() + == "true", + ) + + +# Global config instance (cached) +_security_config: SecurityConfig | None = None + + +def get_config() -> SecurityConfig: + """Get global security config (cached). + + This function caches the configuration on first call to avoid + repeated environment variable lookups. + + Returns: + Cached SecurityConfig instance + + Example: + >>> config = get_config() + >>> if config['enable_validation']: + ... # Perform validation + ... pass + """ + global _security_config + if _security_config is None: + _security_config = get_security_config() + return _security_config + + +def reset_config() -> None: + """Reset cached configuration. + + This is primarily useful for testing when you need to reload + configuration from environment variables. + + Example: + >>> reset_config() + >>> # Config will be reloaded on next get_config() call + """ + global _security_config + _security_config = None diff --git a/src/cli_patterns/core/types.py b/src/cli_patterns/core/types.py index 1c3c748..5aa338d 100644 --- a/src/cli_patterns/core/types.py +++ b/src/cli_patterns/core/types.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Any, NewType, Union +from typing import Any, NewType, Optional, Union from typing_extensions import TypeGuard @@ -63,12 +63,12 @@ # Factory functions for creating semantic types -def make_branch_id(value: str, validate: bool = False) -> BranchId: +def make_branch_id(value: str, validate: Optional[bool] = None) -> BranchId: """Create a BranchId from a string value. Args: value: String value to convert to BranchId - validate: If True, validate the input (default: False for zero overhead) + validate: If True, validate input. If None, use global config. If False, skip. Returns: BranchId with semantic type safety @@ -76,6 +76,12 @@ def make_branch_id(value: str, validate: bool = False) -> BranchId: Raises: ValueError: If validate=True and value is invalid """ + if validate is None: + # Import here to avoid circular dependency + from cli_patterns.core.config import get_config + + validate = get_config()["enable_validation"] + if validate: if not value or not value.strip(): raise ValueError("BranchId cannot be empty") @@ -84,12 +90,12 @@ def make_branch_id(value: str, validate: bool = False) -> BranchId: return BranchId(value) -def make_action_id(value: str, validate: bool = False) -> ActionId: +def make_action_id(value: str, validate: Optional[bool] = None) -> ActionId: """Create an ActionId from a string value. Args: value: String value to convert to ActionId - validate: If True, validate the input (default: False for zero overhead) + validate: If True, validate input. If None, use global config. If False, skip. Returns: ActionId with semantic type safety @@ -97,6 +103,11 @@ def make_action_id(value: str, validate: bool = False) -> ActionId: Raises: ValueError: If validate=True and value is invalid """ + if validate is None: + from cli_patterns.core.config import get_config + + validate = get_config()["enable_validation"] + if validate: if not value or not value.strip(): raise ValueError("ActionId cannot be empty") @@ -105,12 +116,12 @@ def make_action_id(value: str, validate: bool = False) -> ActionId: return ActionId(value) -def make_option_key(value: str, validate: bool = False) -> OptionKey: +def make_option_key(value: str, validate: Optional[bool] = None) -> OptionKey: """Create an OptionKey from a string value. Args: value: String value to convert to OptionKey - validate: If True, validate the input (default: False for zero overhead) + validate: If True, validate input. If None, use global config. If False, skip. Returns: OptionKey with semantic type safety @@ -118,6 +129,11 @@ def make_option_key(value: str, validate: bool = False) -> OptionKey: Raises: ValueError: If validate=True and value is invalid """ + if validate is None: + from cli_patterns.core.config import get_config + + validate = get_config()["enable_validation"] + if validate: if not value or not value.strip(): raise ValueError("OptionKey cannot be empty") @@ -126,12 +142,12 @@ def make_option_key(value: str, validate: bool = False) -> OptionKey: return OptionKey(value) -def make_menu_id(value: str, validate: bool = False) -> MenuId: +def make_menu_id(value: str, validate: Optional[bool] = None) -> MenuId: """Create a MenuId from a string value. Args: value: String value to convert to MenuId - validate: If True, validate the input (default: False for zero overhead) + validate: If True, validate input. If None, use global config. If False, skip. Returns: MenuId with semantic type safety @@ -139,6 +155,11 @@ def make_menu_id(value: str, validate: bool = False) -> MenuId: Raises: ValueError: If validate=True and value is invalid """ + if validate is None: + from cli_patterns.core.config import get_config + + validate = get_config()["enable_validation"] + if validate: if not value or not value.strip(): raise ValueError("MenuId cannot be empty")