diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 475f230..e86451b 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,22 +1,80 @@ -name: Publish Python ๐Ÿ distribution to PyPI +name: CI/CD Pipeline on: push: branches: - main - master + pull_request: + branches: + - main + - master release: types: [published] jobs: + lint: + name: Lint & Format Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install linting tools + run: | + python -m pip install --upgrade pip + pip install ruff mypy + + - name: Check formatting with Ruff + run: ruff format --check . + + - name: Lint with Ruff + run: ruff check . + + - name: Type check with mypy + run: mypy runapi --ignore-missing-imports + continue-on-error: true # Don't fail on type errors initially + + test: + name: Test Python ${{ matrix.python-version }} + runs-on: ubuntu-latest + needs: lint + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run tests + run: | + python tests/test_runapi.py + check-version: runs-on: ubuntu-latest + needs: test + if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master') outputs: is_new_version: ${{ steps.check.outputs.is_new_version }} version: ${{ steps.check.outputs.version }} steps: - uses: actions/checkout@v4 - + - name: Set up Python uses: actions/setup-python@v5 with: @@ -38,9 +96,9 @@ jobs: project = tomllib.load(f) local_version = project["project"]["version"] package_name = project["project"]["name"] - + print(f"Checking version {local_version} for package {package_name}...") - + # Output version for later jobs with open(os.environ['GITHUB_OUTPUT'], 'a') as fh: print(f"version={local_version}", file=fh) @@ -65,7 +123,7 @@ jobs: else: print(f"Error checking PyPI: {e}") sys.exit(1) - + except Exception as e: print(f"Error: {e}") sys.exit(1) @@ -73,22 +131,22 @@ jobs: build: needs: check-version if: needs.check-version.outputs.is_new_version == 'true' - name: Build distribution ๐Ÿ“ฆ + name: Build distribution runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - + - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.11" - + - name: Install build dependencies run: | python -m pip install --upgrade pip pip install build - + - name: Build a binary wheel and a source tarball run: python -m build @@ -99,7 +157,7 @@ jobs: path: dist/ publish-to-pypi: - name: Publish Python ๐Ÿ distribution to PyPI + name: Publish to PyPI needs: build runs-on: ubuntu-latest permissions: @@ -112,7 +170,7 @@ jobs: name: python-package-distributions path: dist/ - - name: Publish distribution ๐Ÿ“ฆ to PyPI + - name: Publish distribution to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} @@ -120,14 +178,14 @@ jobs: verbose: true create-release: - name: Create GitHub Release ๐Ÿท๏ธ + name: Create GitHub Release needs: [publish-to-pypi, check-version] runs-on: ubuntu-latest permissions: contents: write steps: - uses: actions/checkout@v4 - + - name: Create Release uses: softprops/action-gh-release@v2 with: diff --git a/.gitignore b/.gitignore index b3fd459..40bf579 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,106 @@ +# Byte-compiled / optimized / DLL files __pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ dist/ -venv/ \ No newline at end of file +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDEs and editors +.idea/ +.vscode/ +*.swp +*.swo +*~ +.project +.pydevproject +.settings/ +*.sublime-project +*.sublime-workspace + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre +.pyre/ + +# Logs +*.log +logs/ + +# Local development +.DS_Store +Thumbs.db + +# Temporary files +tmp/ +temp/ +*.tmp +*.bak + +# Secrets (never commit these) +*.pem +*.key +secrets.json +credentials.json diff --git a/README.md b/README.md index 7004416..c5ab76d 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,9 @@ A Next.js-inspired file-based routing framework built on FastAPI for Python back - ๐Ÿ”ง **CLI tools** - Command-line interface for project management - ๐Ÿ“ **Auto-documentation** - Automatic API documentation via FastAPI - ๐ŸŽฏ **Type hints** - Full typing support with Pydantic integration +- ๐Ÿ“ฆ **Schema layer** - Auto-discovered Pydantic models with base classes +- ๐Ÿ—„๏ธ **Repository pattern** - Data access abstraction with in-memory and SQLAlchemy support +- ๐Ÿงฉ **Service layer** - Business logic separation with CRUD services ## Installation @@ -41,6 +44,10 @@ pip install runapi ## Table of Contents - [Quick Start](#quick-start) +- [Project Architecture](#project-architecture) +- [Schemas](#schemas) +- [Repositories](#repositories) +- [Services](#services) - [Configuration](#configuration) - [Authentication](#authentication) - [Middleware](#middleware) @@ -72,8 +79,14 @@ my-api/ โ”‚ โ”œโ”€โ”€ users.py # GET, POST /api/users โ”‚ โ””โ”€โ”€ users/ โ”‚ โ””โ”€โ”€ [id].py # GET, PUT, DELETE /api/users/{id} +โ”œโ”€โ”€ schemas/ # Pydantic models (auto-discovered) +โ”‚ โ””โ”€โ”€ user.py # UserCreate, UserResponse, etc. +โ”œโ”€โ”€ repositories/ # Data access layer +โ”‚ โ””โ”€โ”€ user.py # UserRepository +โ”œโ”€โ”€ services/ # Business logic layer +โ”‚ โ””โ”€โ”€ user.py # UserService โ”œโ”€โ”€ static/ # Static files -โ”œโ”€โ”€ uploads/ # File uploads directory +โ”œโ”€โ”€ uploads/ # File uploads directory โ”œโ”€โ”€ main.py # Application entry point โ”œโ”€โ”€ .env # Configuration file โ””โ”€โ”€ README.md @@ -137,6 +150,314 @@ Once your server is running, you can access: - **Alternative Documentation**: `http://localhost:8000/redoc` (ReDoc) - **OpenAPI JSON Schema**: `http://localhost:8000/openapi.json` +## Project Architecture + +runapi follows a clean architecture pattern separating concerns into layers: + +``` +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Routes (routes/) โ”‚ +โ”‚ Thin HTTP handlers - file-based โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Services (services/) โ”‚ +โ”‚ Business logic, validation, orchestration โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Repositories (repositories/) โ”‚ +โ”‚ Data access abstraction (CRUD) โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Schemas (schemas/) โ”‚ +โ”‚ Pydantic models for validation & serialization โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +This separation provides: +- **Testability**: Each layer can be tested independently +- **Maintainability**: Clear boundaries between concerns +- **Flexibility**: Swap implementations without affecting other layers + +## Schemas + +Schemas define your data models using Pydantic. They are **auto-discovered** from the `schemas/` directory. + +### Generate a Schema + +```bash +runapi generate schema user +``` + +This creates `schemas/user.py` with boilerplate classes. + +### Schema Base Classes + +```python +from runapi import BaseSchema, IDMixin, TimestampMixin +from pydantic import Field +from typing import Optional + +# Base schema with ORM support and validation +class UserBase(BaseSchema): + email: str = Field(..., description="User email") + name: str = Field(..., min_length=1, max_length=100) + +# Schema for creating (no ID, no timestamps) +class UserCreate(UserBase): + password: str = Field(..., min_length=8) + +# Schema for updating (all fields optional) +class UserUpdate(BaseSchema): + email: Optional[str] = None + name: Optional[str] = None + +# Schema for responses (includes ID and timestamps) +class UserResponse(UserBase, IDMixin, TimestampMixin): + pass +``` + +### Built-in Schema Utilities + +```python +from runapi import ( + BaseSchema, # Base with ORM mode, validation + IDMixin, # Adds 'id: int' field + TimestampMixin, # Adds 'created_at', 'updated_at' + MessageResponse, # Simple {"message": str, "success": bool} + PaginatedResponse, # Generic paginated list wrapper + PaginationParams, # Query params with offset/limit +) + +# Pagination example +from runapi import PaginatedResponse, PaginationParams + +async def get_users(params: PaginationParams): + users = await user_service.get_all( + skip=params.offset, + limit=params.limit + ) + total = await user_service.count() + return PaginatedResponse.create( + items=users, + total=total, + page=params.page, + page_size=params.page_size + ) +``` + +### Using Schemas in Routes + +```python +# routes/api/users.py +from runapi import JSONResponse, Request +from schemas.user import UserCreate, UserResponse + +async def post(request: Request): + body = await request.json() + user_data = UserCreate(**body) # Validates input + # ... create user logic + return JSONResponse(UserResponse(**user).model_dump()) +``` + +## Repositories + +Repositories abstract data access, making it easy to swap storage backends. + +### Generate a Repository + +```bash +runapi generate repository user +``` + +### In-Memory Repository (Prototyping) + +```python +from runapi import InMemoryRepository + +class UserRepository(InMemoryRepository): + """In-memory storage - great for development/testing.""" + + async def find_by_email(self, email: str): + return await self.get_by(email=email) + + async def find_active_users(self): + return await self.get_many_by(is_active=True) + +# Usage +repo = UserRepository() +user = await repo.create({"name": "John", "email": "john@example.com"}) +users = await repo.get_all(skip=0, limit=10) +await repo.update(1, {"name": "Johnny"}) +await repo.delete(1) +``` + +### Typed Repository (with Pydantic models) + +```python +from runapi import TypedInMemoryRepository +from schemas.user import UserResponse + +class UserRepository(TypedInMemoryRepository[UserResponse]): + def __init__(self): + super().__init__(UserResponse) + + async def find_by_email(self, email: str): + return await self.get_by(email=email) + +# Returns UserResponse instances, not dicts +user = await repo.create({"name": "John", "email": "john@example.com"}) +assert isinstance(user, UserResponse) +``` + +### SQLAlchemy Repository (Production) + +```python +from runapi import SQLAlchemyRepository, SQLALCHEMY_AVAILABLE +from sqlalchemy.ext.asyncio import AsyncSession + +if SQLALCHEMY_AVAILABLE: + class UserRepository(SQLAlchemyRepository[UserModel, int]): + def __init__(self, session: AsyncSession): + super().__init__(session, UserModel) + + async def find_by_email(self, email: str): + return await self.get_by(email=email) +``` + +### Repository Methods + +All repositories provide these methods: + +| Method | Description | +|--------|-------------| +| `get(id)` | Get by ID | +| `get_all(skip, limit, **filters)` | Get all with pagination | +| `get_by(**filters)` | Get single matching filters | +| `create(data)` | Create new entity | +| `update(id, data)` | Update existing entity | +| `delete(id)` | Delete entity | +| `count(**filters)` | Count entities | +| `exists(id)` | Check if exists | + +## Services + +Services contain business logic, sitting between routes and repositories. + +### Generate a Service + +```bash +runapi generate service user +``` + +### CRUD Service (Ready-to-use) + +```python +from runapi import CRUDService, InMemoryRepository + +class UserService(CRUDService): + """Inherits all CRUD operations with error handling.""" + + async def register(self, data: dict): + # Business logic: check if email exists + existing = await self.repository.get_by(email=data["email"]) + if existing: + raise ValidationError("Email already registered") + return await self.create(data) + + async def deactivate(self, user_id: int): + return await self.update(user_id, {"is_active": False}) + +# Usage +user_repo = UserRepository() +user_service = UserService(user_repo, entity_name="User") + +# Built-in methods with error handling +user = await user_service.get(1) # Raises NotFoundError if missing +users = await user_service.get_all(skip=0, limit=10) +new_user = await user_service.create({"name": "John"}) +await user_service.update(1, {"name": "Johnny"}) +await user_service.delete(1) # Raises NotFoundError if missing +``` + +### Validated Service (with Schema Validation) + +```python +from runapi import ValidatedService +from schemas.user import UserCreate, UserUpdate + +class UserService(ValidatedService): + create_schema = UserCreate # Auto-validates on create + update_schema = UserUpdate # Auto-validates on update + +# Input is validated against schemas automatically +user = await user_service.create({ + "name": "John", + "email": "john@example.com", + "password": "secure123" +}) +``` + +### Complete Example: Routes + Service + Repository + +```python +# repositories/user.py +from runapi import InMemoryRepository + +class UserRepository(InMemoryRepository): + async def find_by_email(self, email: str): + return await self.get_by(email=email) + +# services/user.py +from runapi import CRUDService, ValidationError + +class UserService(CRUDService): + async def register(self, data: dict): + if await self.repository.get_by(email=data["email"]): + raise ValidationError("Email exists") + return await self.create(data) + +# routes/api/users.py +from runapi import JSONResponse, Request +from repositories.user import UserRepository +from services.user import UserService + +user_repo = UserRepository() +user_service = UserService(user_repo, "User") + +async def get(): + users = await user_service.get_all() + return JSONResponse(users) + +async def post(request: Request): + body = await request.json() + user = await user_service.register(body) + return JSONResponse(user, status_code=201) +``` + +### Service Decorators + +```python +from runapi import validate_input, require_exists, log_operation +from schemas.user import UserCreate + +class UserService(CRUDService): + + @validate_input(UserCreate) + async def create(self, data: dict): + return await self.repository.create(data) + + @require_exists("User") + async def update(self, id: int, data: dict): + return await self.repository.update(id, data) + + @log_operation("delete_user") + async def delete(self, id: int): + return await self.repository.delete(id) +``` + ## Configuration runapi uses environment variables for configuration. Create a `.env` file: @@ -302,18 +623,39 @@ runapi init my-project # Run development server runapi dev +# Run production server (multiple workers) +runapi start --workers 4 + # Generate boilerplate code -runapi generate route users -runapi generate middleware auth -runapi generate main +runapi generate route users # Create route file +runapi generate schema user # Create schema with base classes +runapi generate repository user # Create repository with CRUD +runapi generate service user # Create service with business logic +runapi generate middleware auth # Create custom middleware +runapi generate main # Create main.py entry point -# List all routes -runapi routes +# List resources +runapi routes # List all API routes +runapi schemas # List all schemas # Show project info runapi info ``` +### Generator Examples + +```bash +# Generate a complete user module +runapi generate schema user +runapi generate repository user +runapi generate service user +runapi generate route users --path api + +# Generate nested resources +runapi generate schema product --path api +runapi generate repository product --path api +``` + ## Advanced Usage ### Custom Application Setup @@ -500,6 +842,56 @@ Raises a not found error (404). #### `raise_auth_error(message: str = "Authentication required")` Raises an authentication error (401). +### Schema Classes + +#### `BaseSchema` +Base Pydantic model with ORM mode, validation, and JSON serialization. + +#### `IDMixin` +Mixin adding `id: int` field. + +#### `TimestampMixin` +Mixin adding `created_at` and `updated_at` datetime fields. + +#### `PaginatedResponse[T]` +Generic paginated response wrapper with `items`, `total`, `page`, `page_size`, `pages`. + +#### `PaginationParams` +Query parameters for pagination with `page`, `page_size`, `offset`, `limit` properties. + +### Repository Classes + +#### `BaseRepository[T, ID]` +Abstract base repository with CRUD operations. + +#### `InMemoryRepository` +Dictionary-based in-memory storage for prototyping and testing. + +#### `TypedInMemoryRepository[T]` +Type-safe in-memory repository returning Pydantic model instances. + +#### `SQLAlchemyRepository[T, ID]` +Async SQLAlchemy repository (requires `sqlalchemy[asyncio]`). + +### Service Classes + +#### `CRUDService[T, ID]` +Ready-to-use CRUD service with error handling for `get`, `get_all`, `create`, `update`, `delete`. + +#### `ValidatedService[T, ID]` +CRUD service with automatic Pydantic schema validation. + +### Service Decorators + +#### `@validate_input(schema)` +Validates input data against a Pydantic schema before method execution. + +#### `@require_exists(entity_name)` +Ensures entity exists before method execution, raises `NotFoundError` if not. + +#### `@log_operation(operation_name)` +Logs service operation start, completion, and errors. + ## Route Conventions ### File Naming @@ -612,7 +1004,10 @@ runapi dev ## Roadmap -- [ ] Database integration helpers (SQLAlchemy, MongoDB) +- [x] Schema layer with auto-discovery +- [x] Repository pattern (in-memory, SQLAlchemy) +- [x] Service layer with CRUD operations +- [x] CLI generators for schemas, repositories, services - [ ] Built-in caching mechanisms (Redis, in-memory) - [ ] WebSocket routing support - [ ] Background task queue integration @@ -620,6 +1015,7 @@ runapi dev - [ ] More authentication providers (OAuth, LDAP) - [ ] Performance monitoring and metrics - [ ] GraphQL support +- [ ] MongoDB repository support ## Contributing @@ -669,7 +1065,23 @@ Please include: ## Changelog -### v0.1.2 (Latest) +### v0.1.3 (Latest) +- **New Feature**: Schema layer with auto-discovery from `schemas/` directory +- **New Feature**: `BaseSchema`, `IDMixin`, `TimestampMixin` for consistent model definitions +- **New Feature**: `PaginatedResponse` and `PaginationParams` for pagination support +- **New Feature**: Repository pattern with `BaseRepository`, `InMemoryRepository`, `TypedInMemoryRepository` +- **New Feature**: Optional `SQLAlchemyRepository` for async database support +- **New Feature**: Service layer with `CRUDService`, `ValidatedService` +- **New Feature**: Service decorators: `@validate_input`, `@require_exists`, `@log_operation` +- **New Feature**: `ServiceFactory` and `RepositoryFactory` for dependency injection +- **CLI**: Added `runapi generate schema ` command +- **CLI**: Added `runapi generate repository ` command +- **CLI**: Added `runapi generate service ` command +- **CLI**: Added `runapi schemas` command to list all schemas +- **CLI**: Updated `runapi init` to create schemas/, repositories/, services/ directories +- **Tests**: Added 20 comprehensive tests covering all new features + +### v0.1.2 - **New Feature**: Added `runapi start` command for production deployments (no-reload, multi-worker support) - **Performance**: Optimized startup time by ignoring irrelevant directories during route discovery - **Performance**: Replaced O(N) rate limiting with O(1) Fixed Window Counter algorithm diff --git a/pyproject.toml b/pyproject.toml index 842a163..ecc2f22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "runapi" -version = "0.1.2" +version = "0.1.3" description = "A Next.js-inspired file-based routing framework built on FastAPI" readme = "README.md" authors = [{ name = "Amanpreet Singh", email = "amanpreetsinghjhiwant7@gmail.com" }] @@ -53,8 +53,7 @@ Changelog = "https://github.com/Amanbig/runapi/releases" dev = [ "pytest>=7.0.0", "httpx>=0.24.0", - "black>=23.0.0", - "isort>=5.12.0", + "ruff>=0.1.0", "mypy>=1.0.0", ] test = [ @@ -63,6 +62,54 @@ test = [ "pytest-asyncio>=0.21.0", "pytest-cov>=4.0.0", ] +lint = [ + "ruff>=0.1.0", + "mypy>=1.0.0", +] [project.scripts] runapi = "runapi.cli:main" + +[tool.ruff] +target-version = "py38" +line-length = 100 +exclude = [ + ".git", + ".venv", + "venv", + "__pycache__", + "build", + "dist", + "*.egg-info", +] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by formatter) + "B008", # do not perform function calls in argument defaults + "B905", # zip without explicit strict parameter +] + +[tool.ruff.lint.isort] +known-first-party = ["runapi"] + +[tool.mypy] +python_version = "3.8" +warn_return_any = true +warn_unused_ignores = true +disallow_untyped_defs = false +ignore_missing_imports = true +exclude = [ + "tests/", + "build/", + "dist/", +] diff --git a/runapi/__init__.py b/runapi/__init__.py index 327c95a..3448bc9 100644 --- a/runapi/__init__.py +++ b/runapi/__init__.py @@ -2,101 +2,149 @@ RunApi - A Next.js-inspired file-based routing framework built on FastAPI """ -__version__ = "0.1.1" +__version__ = "0.1.3" __author__ = "Amanpreet Singh" __email__ = "amanpreetsinghjhiwant7@gmail.com" # Core framework -from .core import create_app, create_runapi_app, RunApiApp +# Authentication +from .auth import ( + APIKeyManager, + AuthDependencies, + JWTManager, + PasswordManager, + TokenResponse, + api_key_manager, + create_access_token, + create_refresh_token, + create_token_response, + generate_api_key, + generate_password, + get_current_active_user, + get_current_user, + hash_password, + require_permissions, + require_roles, + verify_password, + verify_token, +) # Configuration from .config import RunApiConfig, get_config, load_config +from .core import RunApiApp, create_app, create_runapi_app # Error handling from .errors import ( - RunApiException, - ValidationError, AuthenticationError, AuthorizationError, - NotFoundError, ConflictError, - RateLimitError, - ServerError, DatabaseError, - ExternalServiceError, - ErrorResponse, ErrorHandler, - setup_error_handlers, - raise_validation_error, + ErrorResponse, + ExternalServiceError, + NotFoundError, + RateLimitError, + RunApiException, + ServerError, + ValidationError, + bad_request, + conflict, + create_error_response, + forbidden, + internal_error, + not_found, raise_auth_error, - raise_permission_error, - raise_not_found, raise_conflict, + raise_not_found, + raise_permission_error, raise_server_error, - create_error_response, - bad_request, + raise_validation_error, + rate_limited, + setup_error_handlers, unauthorized, - forbidden, - not_found, - conflict, unprocessable_entity, - rate_limited, - internal_error, -) - -# Authentication -from .auth import ( - PasswordManager, - JWTManager, - APIKeyManager, - AuthDependencies, - TokenResponse, - hash_password, - verify_password, - create_access_token, - create_refresh_token, - verify_token, - get_current_user, - get_current_active_user, - require_roles, - require_permissions, - generate_api_key, - generate_password, - create_token_response, - api_key_manager, ) # Middleware from .middleware import ( - RunApiMiddleware, - RequestLoggingMiddleware, - RateLimitMiddleware, AuthMiddleware, - SecurityHeadersMiddleware, CompressionMiddleware, CORSMiddleware, - create_rate_limit_middleware, + RateLimitMiddleware, + RequestLoggingMiddleware, + RunApiMiddleware, + SecurityHeadersMiddleware, create_auth_middleware, create_logging_middleware, + create_rate_limit_middleware, create_security_middleware, ) +# Repository +from .repository import ( + SQLALCHEMY_AVAILABLE, + BaseRepository, + InMemoryRepository, + RepositoryFactory, + RepositoryProtocol, + TypedInMemoryRepository, + create_repository, +) + +# Schemas +from .schemas import ( + BaseSchema, + ErrorDetail, + IDMixin, + MessageResponse, + PaginatedResponse, + PaginationParams, + SchemaRegistry, + TimestampMixin, + create_create_model, + create_response_model, + create_update_model, + get_schema, + list_schemas, + load_schemas, +) +from .schemas import ( + ErrorResponse as SchemaErrorResponse, +) + +# Conditional SQLAlchemy import +if SQLALCHEMY_AVAILABLE: + from .repository import SQLAlchemyRepository +else: + SQLAlchemyRepository = None # type: ignore + +# Service # Convenience imports -from fastapi import FastAPI, APIRouter, Depends, HTTPException, Request, Response -from fastapi.responses import JSONResponse, HTMLResponse, FileResponse +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware +from fastapi.responses import FileResponse, HTMLResponse, JSONResponse + +from .service import ( + BaseService, + CRUDService, + ServiceFactory, + ValidatedService, + create_crud_service, + create_service_dependency, + log_operation, + require_exists, + validate_input, +) __all__ = [ # Core "create_app", - "create_runapi_app", + "create_runapi_app", "RunApiApp", - # Configuration "RunApiConfig", "get_config", "load_config", - # Error handling "RunApiException", "ValidationError", @@ -126,10 +174,9 @@ "unprocessable_entity", "rate_limited", "internal_error", - # Authentication "PasswordManager", - "JWTManager", + "JWTManager", "APIKeyManager", "AuthDependencies", "TokenResponse", @@ -146,11 +193,10 @@ "generate_password", "create_token_response", "api_key_manager", - # Middleware "RunApiMiddleware", "RequestLoggingMiddleware", - "RateLimitMiddleware", + "RateLimitMiddleware", "AuthMiddleware", "SecurityHeadersMiddleware", "CompressionMiddleware", @@ -159,10 +205,44 @@ "create_auth_middleware", "create_logging_middleware", "create_security_middleware", - + # Schemas + "BaseSchema", + "TimestampMixin", + "IDMixin", + "MessageResponse", + "PaginatedResponse", + "PaginationParams", + "ErrorDetail", + "SchemaErrorResponse", + "SchemaRegistry", + "load_schemas", + "get_schema", + "list_schemas", + "create_response_model", + "create_create_model", + "create_update_model", + # Repository + "BaseRepository", + "RepositoryProtocol", + "InMemoryRepository", + "TypedInMemoryRepository", + "SQLAlchemyRepository", + "RepositoryFactory", + "create_repository", + "SQLALCHEMY_AVAILABLE", + # Service + "BaseService", + "CRUDService", + "ValidatedService", + "ServiceFactory", + "validate_input", + "require_exists", + "log_operation", + "create_service_dependency", + "create_crud_service", # FastAPI re-exports "FastAPI", - "APIRouter", + "APIRouter", "Depends", "HTTPException", "Request", @@ -171,4 +251,4 @@ "HTMLResponse", "FileResponse", "FastAPICORSMiddleware", -] \ No newline at end of file +] diff --git a/runapi/auth.py b/runapi/auth.py index 0a97c69..c4099ff 100644 --- a/runapi/auth.py +++ b/runapi/auth.py @@ -1,134 +1,134 @@ -import os -import time import hashlib -import secrets -from typing import Optional, Dict, Any, Union -from datetime import datetime, timedelta -import json -import base64 import hmac +import secrets +import time +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional try: from passlib.context import CryptContext except ImportError: CryptContext = None -from fastapi import HTTPException, Depends, Request -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from jose import JWTError, jwt from .config import get_config class PasswordManager: """Password hashing and verification utilities.""" - + def __init__(self, schemes: list = None): if CryptContext is None: - raise ImportError("passlib is required for password hashing. Install with: pip install passlib[bcrypt]") + raise ImportError( + "passlib is required for password hashing. Install with: pip install passlib[bcrypt]" + ) self.schemes = schemes or ["bcrypt"] self.pwd_context = CryptContext(schemes=self.schemes, deprecated="auto") - + def hash_password(self, password: str) -> str: """Hash a password.""" return self.pwd_context.hash(password) - + def verify_password(self, plain_password: str, hashed_password: str) -> bool: """Verify a password against its hash.""" return self.pwd_context.verify(plain_password, hashed_password) - + def generate_random_password(self, length: int = 12) -> str: """Generate a random password.""" return secrets.token_urlsafe(length) -from jose import jwt, JWTError - class JWTManager: """JWT token management utilities using python-jose.""" - + def __init__(self, secret_key: str = None, algorithm: str = "HS256"): self.config = get_config() self.secret_key = secret_key or self.config.secret_key self.algorithm = algorithm or self.config.jwt_algorithm self.access_token_expire = self.config.jwt_expiry self.refresh_token_expire = self.config.jwt_refresh_expiry - + if self.secret_key == "dev-secret-key-change-in-production": raise ValueError("Change the SECRET_KEY in production!") - + def create_token( self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None, - token_type: str = "access" + token_type: str = "access", ) -> str: """Create a JWT token.""" to_encode = data.copy() - + # Set expiration if expires_delta: - expire = datetime.utcnow() + expires_delta + expire = datetime.now(timezone.utc) + expires_delta else: - expire_minutes = self.access_token_expire if token_type == "access" else self.refresh_token_expire - expire = datetime.utcnow() + timedelta(seconds=expire_minutes) - - to_encode.update({ - "exp": expire, - "iat": datetime.utcnow(), - "type": token_type - }) - + expire_minutes = ( + self.access_token_expire if token_type == "access" else self.refresh_token_expire + ) + expire = datetime.now(timezone.utc) + timedelta(seconds=expire_minutes) + + to_encode.update({"exp": expire, "iat": datetime.now(timezone.utc), "type": token_type}) + # Encode token encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm) return encoded_jwt - + def verify_token(self, token: str) -> Optional[Dict[str, Any]]: """Verify and decode a JWT token.""" try: payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm]) - + # Check expiration (handled by jose, but double checking payload) - if 'exp' in payload and payload['exp'] < time.time(): + if "exp" in payload and payload["exp"] < time.time(): return None - + return payload - + except JWTError: return None - - def create_access_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + + def create_access_token( + self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None + ) -> str: """Create an access token.""" return self.create_token(data, expires_delta, "access") - - def create_refresh_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + + def create_refresh_token( + self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None + ) -> str: """Create a refresh token.""" return self.create_token(data, expires_delta, "refresh") - + def refresh_access_token(self, refresh_token: str) -> Optional[str]: """Create new access token from refresh token.""" payload = self.verify_token(refresh_token) if not payload or payload.get("type") != "refresh": return None - + # Remove token-specific fields - user_data = {k: v for k, v in payload.items() if k not in ['exp', 'iat', 'type']} + user_data = {k: v for k, v in payload.items() if k not in ["exp", "iat", "type"]} return self.create_access_token(user_data) class APIKeyManager: """API Key management utilities.""" - + def __init__(self): self.config = get_config() - + def generate_api_key(self, length: int = 32) -> str: """Generate a new API key.""" return secrets.token_urlsafe(length) - + def hash_api_key(self, api_key: str) -> str: """Hash an API key for storage.""" return hashlib.sha256(api_key.encode()).hexdigest() - + def verify_api_key(self, api_key: str, hashed_key: str) -> bool: """Verify an API key against its hash.""" return hmac.compare_digest(self.hash_api_key(api_key), hashed_key) @@ -136,63 +136,64 @@ def verify_api_key(self, api_key: str, hashed_key: str) -> bool: class AuthDependencies: """FastAPI dependency classes for authentication.""" - + def __init__(self, jwt_manager: JWTManager = None): self.jwt_manager = jwt_manager or JWTManager() self.bearer_scheme = HTTPBearer() - + async def get_current_user( - self, - credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()) + self, credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()) ) -> Dict[str, Any]: """Dependency to get current authenticated user.""" token = credentials.credentials payload = self.jwt_manager.verify_token(token) - + if not payload: raise HTTPException( status_code=401, detail="Invalid or expired token", headers={"WWW-Authenticate": "Bearer"}, ) - + return payload - - async def get_current_active_user( - self, - current_user: Dict[str, Any] = Depends(lambda: AuthDependencies().get_current_user) - ) -> Dict[str, Any]: - """Dependency to get current active user.""" - if current_user.get("disabled"): - raise HTTPException(status_code=400, detail="Inactive user") - return current_user - - async def require_roles(self, required_roles: list): + + def get_current_active_user_dependency(self): + """Create a dependency to get current active user. + + Returns a dependency function that can be used with Depends(). + """ + get_user = self.get_current_user + + async def active_user_checker( + credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()), + ) -> Dict[str, Any]: + current_user = await get_user(credentials) + if current_user.get("disabled"): + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + return active_user_checker + + def require_roles(self, required_roles: list): """Create a dependency that requires specific roles.""" - async def role_checker( - current_user: Dict[str, Any] = Depends(self.get_current_user) - ): + + async def role_checker(current_user: Dict[str, Any] = Depends(self.get_current_user)): user_roles = current_user.get("roles", []) if not any(role in user_roles for role in required_roles): - raise HTTPException( - status_code=403, - detail="Insufficient permissions" - ) + raise HTTPException(status_code=403, detail="Insufficient permissions") return current_user + return role_checker - - async def require_permissions(self, required_permissions: list): + + def require_permissions(self, required_permissions: list): """Create a dependency that requires specific permissions.""" - async def permission_checker( - current_user: Dict[str, Any] = Depends(self.get_current_user) - ): + + async def permission_checker(current_user: Dict[str, Any] = Depends(self.get_current_user)): user_permissions = current_user.get("permissions", []) if not all(perm in user_permissions for perm in required_permissions): - raise HTTPException( - status_code=403, - detail="Insufficient permissions" - ) + raise HTTPException(status_code=403, detail="Insufficient permissions") return current_user + return permission_checker @@ -202,6 +203,7 @@ async def permission_checker( api_key_manager = APIKeyManager() auth_deps = None + def _get_password_manager(): global password_manager if password_manager is None: @@ -211,12 +213,14 @@ def _get_password_manager(): pass return password_manager + def _get_jwt_manager(): global jwt_manager if jwt_manager is None: jwt_manager = JWTManager() return jwt_manager + def _get_auth_deps(): global auth_deps if auth_deps is None: @@ -228,7 +232,9 @@ def hash_password(password: str) -> str: """Hash a password using the global password manager.""" manager = _get_password_manager() if manager is None: - raise ImportError("passlib is required for password hashing. Install with: pip install passlib[bcrypt]") + raise ImportError( + "passlib is required for password hashing. Install with: pip install passlib[bcrypt]" + ) return manager.hash_password(password) @@ -236,7 +242,9 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a password using the global password manager.""" manager = _get_password_manager() if manager is None: - raise ImportError("passlib is required for password hashing. Install with: pip install passlib[bcrypt]") + raise ImportError( + "passlib is required for password hashing. Install with: pip install passlib[bcrypt]" + ) return manager.verify_password(plain_password, hashed_password) @@ -262,7 +270,7 @@ def get_current_user(): def get_current_active_user(): """Get the current active user dependency.""" - return _get_auth_deps().get_current_active_user + return _get_auth_deps().get_current_active_user_dependency() def require_roles(roles: list): @@ -285,23 +293,22 @@ def generate_password(length: int = 12) -> str: """Generate a random password.""" manager = _get_password_manager() if manager is None: - raise ImportError("passlib is required for password generation. Install with: pip install passlib[bcrypt]") + raise ImportError( + "passlib is required for password generation. Install with: pip install passlib[bcrypt]" + ) return manager.generate_random_password(length) class TokenResponse: """Standard token response format.""" - + def __init__(self, access_token: str, refresh_token: str = None, token_type: str = "bearer"): self.access_token = access_token self.refresh_token = refresh_token self.token_type = token_type - + def dict(self): - result = { - "access_token": self.access_token, - "token_type": self.token_type - } + result = {"access_token": self.access_token, "token_type": self.token_type} if self.refresh_token: result["refresh_token"] = self.refresh_token return result @@ -311,4 +318,4 @@ def create_token_response(user_data: Dict[str, Any]) -> TokenResponse: """Create a standard token response with access and refresh tokens.""" access_token = create_access_token(user_data) refresh_token = create_refresh_token(user_data) - return TokenResponse(access_token=access_token, refresh_token=refresh_token) \ No newline at end of file + return TokenResponse(access_token=access_token, refresh_token=refresh_token) diff --git a/runapi/cli.py b/runapi/cli.py index d45e811..b0fcb2c 100644 --- a/runapi/cli.py +++ b/runapi/cli.py @@ -1,17 +1,14 @@ # runapi/cli.py +import os +from pathlib import Path + import typer import uvicorn from rich.console import Console -from rich.table import Table from rich.panel import Panel -from rich.text import Text -from pathlib import Path -import os -import shutil -from typing import Optional +from rich.table import Table -from .config import load_config, RunApiConfig -from .core import create_runapi_app +from .config import load_config app = typer.Typer(name="runapi", help="RunApi - Next.js-inspired Python Backend Framework") console = Console() @@ -27,10 +24,10 @@ def dev( ): """Run the RunApi development server.""" console.print(Panel.fit("๐Ÿš€ [bold blue]RunApi Development Server[/bold blue]", style="blue")) - + # Load configuration config = load_config(config_file) - + # Override config with CLI arguments if provided if host: config.host = host @@ -40,14 +37,16 @@ def dev( config.reload = reload if log_level: config.log_level = log_level - + # Check if main.py exists main_path = Path("main.py") if not main_path.exists(): console.print("[red]โŒ Error: main.py not found in current directory") - console.print("[yellow]๐Ÿ’ก Tip: Run 'runapi init' to create a new project or 'runapi generate main' to create main.py") + console.print( + "[yellow]๐Ÿ’ก Tip: Run 'runapi init' to create a new project or 'runapi generate main' to create main.py" + ) raise typer.Exit(code=1) - + # Display server info table = Table(show_header=False, box=None) table.add_row("๐ŸŒ Server:", f"http://{config.host}:{config.port}") @@ -55,28 +54,36 @@ def dev( table.add_row("๐Ÿ“ Log Level:", config.log_level.upper()) table.add_row("โš™๏ธ Config:", config_file if Path(config_file).exists() else "Default") console.print(table) - + # Check for routes directory if Path("routes").exists(): console.print("๐Ÿ“ Routes directory detected") else: console.print("[yellow]โš ๏ธ No routes directory found") - + + # Check for schemas directory + if Path("schemas").exists(): + console.print("๐Ÿ“ฆ Schemas directory detected") + else: + console.print("[dim]๐Ÿ“ฆ No schemas directory (optional)") + console.print() - + try: # Ensure we're in the correct working directory import os import sys + current_dir = os.getcwd() - + # Add current directory to Python path if not already there if current_dir not in sys.path: sys.path.insert(0, current_dir) - + # Verify main.py can be imported before starting server try: import importlib.util + spec = importlib.util.spec_from_file_location("main", "main.py") if spec is None: raise ImportError("Cannot load main.py") @@ -85,9 +92,11 @@ def dev( console.print("โœ… [green]main.py loaded successfully") except Exception as e: console.print(f"[red]โŒ Error importing main.py: {e}") - console.print("[yellow]๐Ÿ’ก Make sure main.py exists and runapi is installed in this environment") - raise typer.Exit(code=1) - + console.print( + "[yellow]๐Ÿ’ก Make sure main.py exists and runapi is installed in this environment" + ) + raise typer.Exit(code=1) from e + # Run uvicorn with the FastAPI app uvicorn.run( "main:app", @@ -101,7 +110,7 @@ def dev( console.print("\n[yellow]๐Ÿ‘‹ Server stopped") except Exception as e: console.print(f"[red]โŒ Server error: {e}") - raise typer.Exit(code=1) + raise typer.Exit(code=1) from e @app.command() @@ -114,10 +123,10 @@ def start( ): """Run the RunApi server in production mode.""" console.print(Panel.fit("๐Ÿš€ [bold green]RunApi Production Server[/bold green]", style="green")) - + # Load configuration config = load_config(config_file) - + # Override config with CLI arguments if host: config.host = host @@ -125,7 +134,7 @@ def start( config.port = port if log_level: config.log_level = log_level - + # Determine workers # If not specified in CLI, check env/config, else default to 1 (or cpu_count in real prod) # RunApiConfig doesn't have 'workers' yet, adding it logic here or just defaulting @@ -135,7 +144,7 @@ def start( if not Path("main.py").exists(): console.print("[red]โŒ Error: main.py not found") raise typer.Exit(code=1) - + # Display server info table = Table(show_header=False, box=None) table.add_row("๐ŸŒ Server:", f"http://{config.host}:{config.port}") @@ -144,13 +153,14 @@ def start( table.add_row("๐Ÿ“ Log Level:", config.log_level.upper()) console.print(table) console.print() - + try: # Puts current dir in path import sys + if os.getcwd() not in sys.path: sys.path.insert(0, os.getcwd()) - + uvicorn.run( "main:app", host=config.host, @@ -163,7 +173,7 @@ def start( console.print("\n[yellow]๐Ÿ‘‹ Server stopped") except Exception as e: console.print(f"[red]โŒ Server error: {e}") - raise typer.Exit(code=1) + raise typer.Exit(code=1) from e @app.command() @@ -173,22 +183,28 @@ def init( ): """Initialize a new RunApi project.""" project_path = Path(name) - + if project_path.exists(): console.print(f"[red]โŒ Directory '{name}' already exists") raise typer.Exit(code=1) - + console.print(f"๐Ÿš€ [bold blue]Creating RunApi project: {name}[/bold blue]") - + # Create project directory project_path.mkdir() - + # Create basic project structure (project_path / "routes").mkdir() (project_path / "routes" / "__init__.py").touch() + (project_path / "schemas").mkdir() + (project_path / "schemas" / "__init__.py").touch() + (project_path / "repositories").mkdir() + (project_path / "repositories" / "__init__.py").touch() + (project_path / "services").mkdir() + (project_path / "services" / "__init__.py").touch() (project_path / "static").mkdir() (project_path / "uploads").mkdir() - + # Create main.py main_content = '''""" RunApi Application Entry Point @@ -208,14 +224,14 @@ def init( if __name__ == "__main__": runapi_app.run() ''' - - (project_path / "main.py").write_text(main_content, encoding='utf-8') - + + (project_path / "main.py").write_text(main_content, encoding="utf-8") + # Create example route routes_api_path = project_path / "routes" / "api" routes_api_path.mkdir() (routes_api_path / "__init__.py").touch() - + example_route = '''""" Example API route GET /api/hello @@ -229,9 +245,9 @@ async def get(): "status": "success" }) ''' - - (routes_api_path / "hello.py").write_text(example_route, encoding='utf-8') - + + (routes_api_path / "hello.py").write_text(example_route, encoding="utf-8") + # Create index route index_route = '''""" Home page route @@ -248,11 +264,11 @@ async def get(): } }) ''' - - (project_path / "routes" / "index.py").write_text(index_route, encoding='utf-8') - + + (project_path / "routes" / "index.py").write_text(index_route, encoding="utf-8") + # Create .env file - env_content = '''# RunApi Configuration + env_content = """# RunApi Configuration DEBUG=true HOST=127.0.0.1 PORT=8000 @@ -274,12 +290,12 @@ async def get(): STATIC_FILES_ENABLED=true STATIC_FILES_PATH=static STATIC_FILES_URL=/static -''' - - (project_path / ".env").write_text(env_content, encoding='utf-8') - +""" + + (project_path / ".env").write_text(env_content, encoding="utf-8") + # Create .gitignore - gitignore_content = '''# Python + gitignore_content = """# Python __pycache__/ *.py[cod] *$py.class @@ -305,12 +321,12 @@ async def get(): # OS .DS_Store Thumbs.db -''' - - (project_path / ".gitignore").write_text(gitignore_content, encoding='utf-8') - +""" + + (project_path / ".gitignore").write_text(gitignore_content, encoding="utf-8") + # Create README - readme_content = f'''# {name} + readme_content = f"""# {name} A RunApi API project. @@ -366,43 +382,51 @@ async def get(): async def post(): return {{"message": "POST request"}} ``` -''' - - (project_path / "README.md").write_text(readme_content, encoding='utf-8') - +""" + + (project_path / "README.md").write_text(readme_content, encoding="utf-8") + console.print("โœ… [green]Project created successfully!") - console.print(f"\n๐Ÿ“ Project structure:") - + console.print("\n๐Ÿ“ Project structure:") + # Show project structure - for root, dirs, files in os.walk(project_path): - level = root.replace(str(project_path), '').count(os.sep) - indent = ' ' * 2 * level + for root, _dirs, files in os.walk(project_path): + level = root.replace(str(project_path), "").count(os.sep) + indent = " " * 2 * level console.print(f"{indent}{os.path.basename(root)}/") - subindent = ' ' * 2 * (level + 1) + subindent = " " * 2 * (level + 1) for file in files: console.print(f"{subindent}{file}") - - console.print(f"\n๐Ÿš€ To get started:") + + console.print("\n๐Ÿš€ To get started:") console.print(f" cd {name}") - console.print(f" runapi dev") + console.print(" runapi dev") @app.command() def generate( - item: str = typer.Argument(..., help="What to generate (route, main, middleware)"), + item: str = typer.Argument( + ..., help="What to generate (route, schema, repository, service, main, middleware)" + ), name: str = typer.Argument(..., help="Name of the item"), path: str = typer.Option("", "--path", "-p", help="Path for the item"), ): """Generate boilerplate code.""" if item == "route": _generate_route(name, path) + elif item == "schema": + _generate_schema(name, path) + elif item == "repository": + _generate_repository(name, path) + elif item == "service": + _generate_service(name, path) elif item == "main": _generate_main() elif item == "middleware": _generate_middleware(name) else: console.print(f"[red]โŒ Unknown generator: {item}") - console.print("Available generators: route, main, middleware") + console.print("Available generators: route, schema, repository, service, main, middleware") raise typer.Exit(code=1) @@ -412,7 +436,7 @@ def _generate_route(name: str, path: str): if not routes_path.exists(): routes_path.mkdir() (routes_path / "__init__.py").touch() - + if path: route_path = routes_path / path route_path.mkdir(parents=True, exist_ok=True) @@ -420,11 +444,11 @@ def _generate_route(name: str, path: str): file_path = route_path / f"{name}.py" else: file_path = routes_path / f"{name}.py" - + if file_path.exists(): console.print(f"[red]โŒ Route already exists: {file_path}") raise typer.Exit(code=1) - + # Generate route template route_template = f'''""" {name.title()} route @@ -444,10 +468,10 @@ async def post(request: Request): """Handle POST request.""" # Get request body body = await request.json() - + return JSONResponse({{ "message": "Data received", - "method": "POST", + "method": "POST", "data": body }}) @@ -465,10 +489,10 @@ async def post(request: Request): # """Handle PATCH request.""" # pass ''' - - file_path.write_text(route_template, encoding='utf-8') + + file_path.write_text(route_template, encoding="utf-8") console.print(f"โœ… [green]Route created: {file_path}") - + # Show URL mapping route_url = "/" + str(file_path.relative_to(routes_path)).replace("\\", "/").replace(".py", "") if route_url.endswith("/index"): @@ -476,13 +500,324 @@ async def post(request: Request): console.print(f"๐ŸŒ URL: {route_url}") +def _generate_schema(name: str, path: str): + """Generate a new schema file.""" + schemas_path = Path("schemas") + if not schemas_path.exists(): + schemas_path.mkdir() + (schemas_path / "__init__.py").touch() + + if path: + schema_path = schemas_path / path + schema_path.mkdir(parents=True, exist_ok=True) + (schema_path / "__init__.py").touch() + file_path = schema_path / f"{name}.py" + else: + file_path = schemas_path / f"{name}.py" + + if file_path.exists(): + console.print(f"[red]โŒ Schema already exists: {file_path}") + raise typer.Exit(code=1) + + # Convert name to PascalCase for class names + class_name = "".join(word.capitalize() for word in name.replace("-", "_").split("_")) + + schema_template = f'''""" +{class_name} schemas +Generated by RunApi CLI +""" +from runapi import BaseSchema, TimestampMixin, IDMixin +from pydantic import Field, EmailStr +from typing import Optional +from datetime import datetime + + +class {class_name}Base(BaseSchema): + """Base schema with shared fields.""" + name: str = Field(..., min_length=1, max_length=100, description="{class_name} name") + # Add more shared fields here + + +class {class_name}Create({class_name}Base): + """Schema for creating a new {name}.""" + pass + # Add create-specific fields here (e.g., password for user) + + +class {class_name}Update(BaseSchema): + """Schema for updating an existing {name}. All fields optional.""" + name: Optional[str] = Field(None, min_length=1, max_length=100) + # Add more updatable fields here + + +class {class_name}Response({class_name}Base, IDMixin, TimestampMixin): + """Schema for {name} responses.""" + pass + # Response includes id and timestamps from mixins + + +class {class_name}List(BaseSchema): + """Schema for listing multiple {name}s.""" + items: list[{class_name}Response] = Field(default_factory=list) + total: int = Field(..., description="Total number of items") + page: int = Field(default=1, ge=1) + page_size: int = Field(default=20, ge=1, le=100) + + +# Example usage in routes: +# +# from schemas.{name} import {class_name}Create, {class_name}Response +# +# async def post(body: {class_name}Create) -> {class_name}Response: +# # Create {name} logic +# return {class_name}Response(id=1, name=body.name, created_at=datetime.now()) +''' + + file_path.write_text(schema_template, encoding="utf-8") + console.print(f"โœ… [green]Schema created: {file_path}") + console.print( + f"๐Ÿ“ฆ Classes: {class_name}Base, {class_name}Create, {class_name}Update, {class_name}Response, {class_name}List" + ) + + +def _generate_repository(name: str, path: str): + """Generate a new repository file.""" + repos_path = Path("repositories") + if not repos_path.exists(): + repos_path.mkdir() + (repos_path / "__init__.py").touch() + + if path: + repo_path = repos_path / path + repo_path.mkdir(parents=True, exist_ok=True) + (repo_path / "__init__.py").touch() + file_path = repo_path / f"{name}.py" + else: + file_path = repos_path / f"{name}.py" + + if file_path.exists(): + console.print(f"[red]โŒ Repository already exists: {file_path}") + raise typer.Exit(code=1) + + # Convert name to PascalCase for class names + class_name = "".join(word.capitalize() for word in name.replace("-", "_").split("_")) + + repo_template = f'''""" +{class_name} repository +Generated by RunApi CLI +""" +from typing import Optional, List, Dict, Any +from runapi import BaseRepository, InMemoryRepository, TypedInMemoryRepository + +# Option 1: Use InMemoryRepository for prototyping +# This is great for development and testing + +class {class_name}Repository(InMemoryRepository): + """ + {class_name} repository for data access. + + For production, replace InMemoryRepository with SQLAlchemyRepository + or implement your own data access logic. + """ + + async def find_by_name(self, name: str) -> Optional[Dict[str, Any]]: + """Find {name} by name.""" + return await self.get_by(name=name) + + async def find_active(self) -> List[Dict[str, Any]]: + """Find all active {name}s.""" + return await self.get_many_by(is_active=True) + + +# Option 2: For typed repositories with Pydantic models +# Uncomment and modify as needed: +# +# from schemas.{name} import {class_name}Response +# +# class Typed{class_name}Repository(TypedInMemoryRepository[{class_name}Response]): +# def __init__(self): +# super().__init__({class_name}Response) +# +# async def find_by_email(self, email: str) -> Optional[{class_name}Response]: +# return await self.get_by(email=email) + + +# Option 3: For SQLAlchemy (if you have a database) +# Uncomment and modify as needed: +# +# from runapi import SQLAlchemyRepository, SQLALCHEMY_AVAILABLE +# from sqlalchemy.ext.asyncio import AsyncSession +# +# if SQLALCHEMY_AVAILABLE: +# from models.{name} import {class_name}Model +# +# class SQLAlchemy{class_name}Repository(SQLAlchemyRepository[{class_name}Model, int]): +# def __init__(self, session: AsyncSession): +# super().__init__(session, {class_name}Model) +# +# async def find_by_email(self, email: str) -> Optional[{class_name}Model]: +# return await self.get_by(email=email) + + +# Example usage in routes: +# +# from repositories.{name} import {class_name}Repository +# +# # Create repository instance (consider using dependency injection) +# {name}_repo = {class_name}Repository() +# +# async def get(): +# items = await {name}_repo.get_all() +# return items +# +# async def post(request: Request): +# body = await request.json() +# item = await {name}_repo.create(body) +# return item +''' + + file_path.write_text(repo_template, encoding="utf-8") + console.print(f"โœ… [green]Repository created: {file_path}") + console.print(f"๐Ÿ“ฆ Class: {class_name}Repository") + + +def _generate_service(name: str, path: str): + """Generate a new service file.""" + services_path = Path("services") + if not services_path.exists(): + services_path.mkdir() + (services_path / "__init__.py").touch() + + if path: + service_path = services_path / path + service_path.mkdir(parents=True, exist_ok=True) + (service_path / "__init__.py").touch() + file_path = service_path / f"{name}.py" + else: + file_path = services_path / f"{name}.py" + + if file_path.exists(): + console.print(f"[red]โŒ Service already exists: {file_path}") + raise typer.Exit(code=1) + + # Convert name to PascalCase for class names + class_name = "".join(word.capitalize() for word in name.replace("-", "_").split("_")) + + service_template = f'''""" +{class_name} service +Generated by RunApi CLI +""" +from typing import Optional, List, Dict, Any +from runapi import NotFoundError, ValidationError + + +class {class_name}Service: + """ + {class_name} service containing business logic. + + Services orchestrate between repositories and handle: + - Business rules and validation + - Complex operations spanning multiple repositories + - Transaction management + """ + + def __init__(self, repository=None): + """ + Initialize service with repository. + + Args: + repository: The {name} repository (inject via dependency) + """ + self.repository = repository + + async def get_all( + self, + skip: int = 0, + limit: int = 100, + **filters + ) -> List[Dict[str, Any]]: + """Get all {name}s with pagination.""" + return await self.repository.get_all(skip=skip, limit=limit, **filters) + + async def get_by_id(self, id: int) -> Dict[str, Any]: + """Get a {name} by ID.""" + item = await self.repository.get(id) + if not item: + raise NotFoundError(f"{class_name} with id {{id}} not found") + return item + + async def create(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Create a new {name}. + + Add business validation logic here. + """ + # Example validation + # if not data.get("name"): + # raise ValidationError("Name is required") + + return await self.repository.create(data) + + async def update(self, id: int, data: Dict[str, Any]) -> Dict[str, Any]: + """Update an existing {name}.""" + # Verify exists + existing = await self.repository.get(id) + if not existing: + raise NotFoundError(f"{class_name} with id {{id}} not found") + + # Add business validation here + # ... + + result = await self.repository.update(id, data) + return result + + async def delete(self, id: int) -> bool: + """Delete a {name}.""" + existing = await self.repository.get(id) + if not existing: + raise NotFoundError(f"{class_name} with id {{id}} not found") + + # Add business rules for deletion here + # Example: Check if item can be deleted + # if existing.get("is_protected"): + # raise ValidationError("Cannot delete protected item") + + return await self.repository.delete(id) + + async def count(self, **filters) -> int: + """Count {name}s matching filters.""" + return await self.repository.count(**filters) + + +# Example usage in routes: +# +# from services.{name} import {class_name}Service +# from repositories.{name} import {class_name}Repository +# +# # Setup (consider using dependency injection) +# {name}_repo = {class_name}Repository() +# {name}_service = {class_name}Service({name}_repo) +# +# async def get(): +# return await {name}_service.get_all() +# +# async def post(request: Request): +# body = await request.json() +# return await {name}_service.create(body) +''' + + file_path.write_text(service_template, encoding="utf-8") + console.print(f"โœ… [green]Service created: {file_path}") + console.print(f"๐Ÿ“ฆ Class: {class_name}Service") + + def _generate_main(): """Generate main.py file.""" main_path = Path("main.py") if main_path.exists(): if not typer.confirm("main.py already exists. Overwrite?"): raise typer.Exit() - + main_content = '''""" RunApi Application Entry Point """ @@ -507,8 +842,8 @@ def _generate_main(): if __name__ == "__main__": runapi_app.run() ''' - - main_path.write_text(main_content, encoding='utf-8') + + main_path.write_text(main_content, encoding="utf-8") console.print("โœ… [green]main.py created successfully!") @@ -517,13 +852,13 @@ def _generate_middleware(name: str): middleware_path = Path("middleware") middleware_path.mkdir(exist_ok=True) (middleware_path / "__init__.py").touch() - + file_path = middleware_path / f"{name}.py" - + if file_path.exists(): console.print(f"[red]โŒ Middleware already exists: {file_path}") raise typer.Exit(code=1) - + middleware_template = f'''""" {name.title()} middleware Generated by RunApi CLI @@ -533,31 +868,31 @@ def _generate_middleware(name: str): class {name.title()}Middleware(RunApiMiddleware): """Custom {name} middleware.""" - + def __init__(self, app, **kwargs): super().__init__(app) # Initialize middleware parameters pass - + async def dispatch(self, request: Request, call_next: Callable) -> Response: """Process request and response.""" # Pre-processing print(f"Processing request: {{request.method}} {{request.url.path}}") - + # Call next middleware/route response = await call_next(request) - + # Post-processing print(f"Response status: {{response.status_code}}") - + return response # Usage in main.py: -# from middleware.{name} import {name.title()}Middleware +# from middleware.{name} import {name.title()}Middleware # runapi_app.add_middleware({name.title()}Middleware) ''' - - file_path.write_text(middleware_template, encoding='utf-8') + + file_path.write_text(middleware_template, encoding="utf-8") console.print(f"โœ… [green]Middleware created: {file_path}") @@ -568,76 +903,143 @@ def routes(): if not routes_path.exists(): console.print("[red]โŒ No routes directory found") raise typer.Exit(code=1) - + console.print("๐Ÿ“‹ [bold blue]Available Routes[/bold blue]\n") - + table = Table(show_header=True, header_style="bold blue") table.add_column("Method") - table.add_column("Path") + table.add_column("Path") table.add_column("File") - + for route_file in routes_path.rglob("*.py"): if route_file.name == "__init__.py": continue - + # Generate URL path relative_path = route_file.relative_to(routes_path) url_path = "/" + str(relative_path).replace("\\", "/").replace(".py", "") - + if url_path.endswith("/index"): url_path = url_path[:-6] or "/" - + # Check for dynamic routes if "[" in url_path and "]" in url_path: # Convert [id] to {id} import re - url_path = re.sub(r'\[([^\]]+)\]', r'{\1}', url_path) - + + url_path = re.sub(r"\[([^\]]+)\]", r"{\1}", url_path) + # Read file to detect HTTP methods # Read file to detect HTTP methods try: content = route_file.read_text() import ast + try: tree = ast.parse(content) methods = [] for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - if node.name in ["get", "post", "put", "delete", "patch", "head", "options"]: + if node.name in [ + "get", + "post", + "put", + "delete", + "patch", + "head", + "options", + ]: methods.append(node.name.upper()) - + # Deduplicate and sort - methods = sorted(list(set(methods))) + methods = sorted(set(methods)) methods_str = ", ".join(methods) if methods else "No methods found" table.add_row(methods_str, url_path, str(relative_path)) except SyntaxError: table.add_row("Error", url_path, "Syntax Error in file") - + except Exception as e: table.add_row("Error", url_path, f"Error reading file: {e}") - + console.print(table) -@app.command() +@app.command() +def schemas(): + """List all available schemas in the project.""" + schemas_path = Path("schemas") + if not schemas_path.exists(): + console.print("[yellow]โš ๏ธ No schemas directory found") + console.print("[dim]๐Ÿ’ก Tip: Run 'runapi generate schema user' to create your first schema") + raise typer.Exit(code=0) + + console.print("๐Ÿ“‹ [bold blue]Available Schemas[/bold blue]\n") + + table = Table(show_header=True, header_style="bold blue") + table.add_column("File") + table.add_column("Classes") + table.add_column("Module Path") + + schema_count = 0 + for schema_file in schemas_path.rglob("*.py"): + if schema_file.name == "__init__.py": + continue + + relative_path = schema_file.relative_to(schemas_path) + module_path = "schemas." + str(relative_path).replace("\\", ".").replace("/", ".").replace( + ".py", "" + ) + + # Read file to detect Pydantic model classes + try: + content = schema_file.read_text() + import ast + + try: + tree = ast.parse(content) + classes = [] + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + # Check if it's likely a Pydantic model (inherits from something) + if node.bases: + classes.append(node.name) + + classes_str = ", ".join(classes) if classes else "No classes found" + table.add_row(str(relative_path), classes_str, module_path) + schema_count += 1 + except SyntaxError: + table.add_row(str(relative_path), "[red]Syntax Error", module_path) + + except Exception as e: + table.add_row(str(relative_path), f"[red]Error: {e}", module_path) + + if schema_count > 0: + console.print(table) + console.print(f"\n[dim]Total: {schema_count} schema file(s)") + else: + console.print("[yellow]No schema files found in schemas/ directory") + console.print("[dim]๐Ÿ’ก Tip: Run 'runapi generate schema user' to create a schema") + + +@app.command() def info(): """Show project information and configuration.""" console.print("โ„น๏ธ [bold blue]RunApi Project Information[/bold blue]\n") - + # Load config config = load_config() - + # Project info info_table = Table(show_header=False, box=None) info_table.add_row("๐Ÿ“ Project Directory:", str(Path.cwd())) info_table.add_row("๐Ÿ RunApi Version:", "0.1.0") - + # Check main.py if Path("main.py").exists(): info_table.add_row("๐Ÿ“„ Entry Point:", "main.py โœ…") else: info_table.add_row("๐Ÿ“„ Entry Point:", "main.py โŒ") - + # Check routes routes_path = Path("routes") if routes_path.exists(): @@ -645,22 +1047,32 @@ def info(): info_table.add_row("๐Ÿ›ฃ๏ธ Routes:", f"{route_count} files") else: info_table.add_row("๐Ÿ›ฃ๏ธ Routes:", "No routes directory") - + + # Check schemas + schemas_path = Path("schemas") + if schemas_path.exists(): + schema_count = len([f for f in schemas_path.rglob("*.py") if f.name != "__init__.py"]) + info_table.add_row("๐Ÿ“ฆ Schemas:", f"{schema_count} files") + else: + info_table.add_row("๐Ÿ“ฆ Schemas:", "No schemas directory") + console.print(info_table) console.print() - + # Configuration config_table = Table(show_header=True, header_style="bold blue", title="Configuration") config_table.add_column("Setting") config_table.add_column("Value") - + config_table.add_row("Debug", "โœ… Enabled" if config.debug else "โŒ Disabled") config_table.add_row("Host", config.host) config_table.add_row("Port", str(config.port)) config_table.add_row("CORS Origins", ", ".join(config.cors_origins)) - config_table.add_row("Rate Limiting", "โœ… Enabled" if config.rate_limit_enabled else "โŒ Disabled") + config_table.add_row( + "Rate Limiting", "โœ… Enabled" if config.rate_limit_enabled else "โŒ Disabled" + ) config_table.add_row("Log Level", config.log_level) - + console.print(config_table) @@ -668,5 +1080,6 @@ def main(): """Main entry point for CLI.""" app() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/runapi/config.py b/runapi/config.py index f5601fa..ef58a82 100644 --- a/runapi/config.py +++ b/runapi/config.py @@ -1,125 +1,127 @@ import os -from typing import List, Optional, Dict, Any from pathlib import Path +from typing import Any, Dict, List, Optional class RunApiConfig: """Configuration management for RunApi framework.""" - + def __init__(self, env_file: Optional[str] = None): self.env_file = env_file or ".env" self._load_env_file() - + # Core settings self.debug: bool = self._get_bool("DEBUG", True) self.host: str = self._get_str("HOST", "127.0.0.1") self.port: int = self._get_int("PORT", 8000) self.reload: bool = self._get_bool("RELOAD", True) - + # Security settings self.secret_key: str = self._get_str("SECRET_KEY", "dev-secret-key-change-in-production") self.allowed_hosts: List[str] = self._get_list("ALLOWED_HOSTS", ["*"]) - + # CORS settings self.cors_origins: List[str] = self._get_list("CORS_ORIGINS", ["*"]) self.cors_credentials: bool = self._get_bool("CORS_CREDENTIALS", True) self.cors_methods: List[str] = self._get_list("CORS_METHODS", ["*"]) self.cors_headers: List[str] = self._get_list("CORS_HEADERS", ["*"]) - + # Database settings self.database_url: Optional[str] = self._get_str("DATABASE_URL") self.database_echo: bool = self._get_bool("DATABASE_ECHO", False) - + # Cache settings self.cache_backend: str = self._get_str("CACHE_BACKEND", "memory") self.redis_url: Optional[str] = self._get_str("REDIS_URL") self.cache_ttl: int = self._get_int("CACHE_TTL", 300) # 5 minutes default - + # Rate limiting self.rate_limit_enabled: bool = self._get_bool("RATE_LIMIT_ENABLED", False) self.rate_limit_calls: int = self._get_int("RATE_LIMIT_CALLS", 100) self.rate_limit_period: int = self._get_int("RATE_LIMIT_PERIOD", 60) # 1 minute - + # Logging self.log_level: str = self._get_str("LOG_LEVEL", "INFO") - self.log_format: str = self._get_str("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s") - + self.log_format: str = self._get_str( + "LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + # Static files self.static_files_enabled: bool = self._get_bool("STATIC_FILES_ENABLED", True) self.static_files_path: str = self._get_str("STATIC_FILES_PATH", "static") self.static_files_url: str = self._get_str("STATIC_FILES_URL", "/static") - + # Upload settings self.max_upload_size: int = self._get_int("MAX_UPLOAD_SIZE", 10 * 1024 * 1024) # 10MB self.upload_path: str = self._get_str("UPLOAD_PATH", "uploads") - + # JWT settings self.jwt_algorithm: str = self._get_str("JWT_ALGORITHM", "HS256") self.jwt_expiry: int = self._get_int("JWT_EXPIRY", 3600) # 1 hour self.jwt_refresh_expiry: int = self._get_int("JWT_REFRESH_EXPIRY", 86400) # 24 hours - + # Custom settings self.custom: Dict[str, Any] = {} - + def _load_env_file(self): """Load environment variables from .env file if it exists.""" env_path = Path(self.env_file) if env_path.exists(): - with open(env_path, 'r', encoding='utf-8') as f: + with open(env_path, encoding="utf-8") as f: for line in f: line = line.strip() - if line and not line.startswith('#') and '=' in line: - key, value = line.split('=', 1) + if line and not line.startswith("#") and "=" in line: + key, value = line.split("=", 1) key = key.strip() value = value.strip().strip('"').strip("'") os.environ[key] = value - + def _get_str(self, key: str, default: Optional[str] = None) -> Optional[str]: """Get string value from environment.""" return os.getenv(key, default) - + def _get_bool(self, key: str, default: bool = False) -> bool: """Get boolean value from environment.""" value = os.getenv(key, str(default)).lower() - return value in ('true', '1', 'yes', 'on') - + return value in ("true", "1", "yes", "on") + def _get_int(self, key: str, default: int = 0) -> int: """Get integer value from environment.""" try: return int(os.getenv(key, str(default))) except ValueError: return default - + def _get_float(self, key: str, default: float = 0.0) -> float: """Get float value from environment.""" try: return float(os.getenv(key, str(default))) except ValueError: return default - + def _get_list(self, key: str, default: List[str] = None) -> List[str]: """Get list value from environment (comma-separated).""" if default is None: default = [] - + value = os.getenv(key) if not value: return default - - return [item.strip() for item in value.split(',') if item.strip()] - + + return [item.strip() for item in value.split(",") if item.strip()] + def get(self, key: str, default: Any = None) -> Any: """Get custom configuration value.""" return self.custom.get(key, default) - + def set(self, key: str, value: Any) -> None: """Set custom configuration value.""" self.custom[key] = value - + def is_development(self) -> bool: """Check if running in development mode.""" return self.debug - + def is_production(self) -> bool: """Check if running in production mode.""" return not self.debug @@ -138,4 +140,4 @@ def load_config(env_file: Optional[str] = None) -> RunApiConfig: """Load configuration with optional custom env file.""" global config config = RunApiConfig(env_file) - return config \ No newline at end of file + return config diff --git a/runapi/core.py b/runapi/core.py index b66ee16..6b05a41 100644 --- a/runapi/core.py +++ b/runapi/core.py @@ -1,47 +1,51 @@ # runapi/core.py -from fastapi import FastAPI, APIRouter -from fastapi.staticfiles import StaticFiles -from pathlib import Path import importlib.util import logging -from typing import List, Optional, Type, Dict, Any +from pathlib import Path +from typing import List, Optional, Type -from .config import get_config, RunApiConfig +from fastapi import APIRouter, FastAPI +from fastapi.staticfiles import StaticFiles + +from .config import RunApiConfig, get_config +from .errors import setup_error_handlers from .middleware import ( - CORSMiddleware, - RequestLoggingMiddleware, - RateLimitMiddleware, AuthMiddleware, - SecurityHeadersMiddleware, CompressionMiddleware, - RunApiMiddleware + RateLimitMiddleware, + RequestLoggingMiddleware, + RunApiMiddleware, + SecurityHeadersMiddleware, ) -from .errors import setup_error_handlers +from .schemas import SchemaRegistry, load_schemas class RunApiApp: """Enhanced RunApi application class with configuration and middleware support.""" - + def __init__(self, config: Optional[RunApiConfig] = None, **fastapi_kwargs): self.config = config or get_config() self.app = self._create_fastapi_app(**fastapi_kwargs) self.middleware_stack: List[Type[RunApiMiddleware]] = [] - + # Setup logging self._setup_logging() - + # Setup default middleware self._setup_default_middleware() - + # Setup error handlers self._setup_error_handlers() - + # Load routes self._load_routes() - + + # Load schemas + self._load_schemas() + # Setup static files self._setup_static_files() - + def _create_fastapi_app(self, **kwargs) -> FastAPI: """Create FastAPI application with configuration.""" # Merge config with kwargs @@ -52,52 +56,52 @@ def _create_fastapi_app(self, **kwargs) -> FastAPI: "version": kwargs.get("version", "1.0.0"), } app_kwargs.update(kwargs) - + return FastAPI(**app_kwargs) - + def _setup_logging(self): """Setup logging configuration.""" logging.basicConfig( - level=getattr(logging, self.config.log_level.upper()), - format=self.config.log_format + level=getattr(logging, self.config.log_level.upper()), format=self.config.log_format ) self.logger = logging.getLogger("runapi") - + def _setup_default_middleware(self): """Setup default middleware based on configuration.""" # CORS middleware if self.config.cors_origins: from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware + self.app.add_middleware( FastAPICORSMiddleware, allow_origins=self.config.cors_origins, allow_credentials=self.config.cors_credentials, allow_methods=self.config.cors_methods, - allow_headers=self.config.cors_headers + allow_headers=self.config.cors_headers, ) - + # Rate limiting middleware if self.config.rate_limit_enabled: self.app.add_middleware( RateLimitMiddleware, calls=self.config.rate_limit_calls, - period=self.config.rate_limit_period + period=self.config.rate_limit_period, ) - + # Security headers middleware self.app.add_middleware(SecurityHeadersMiddleware) - + # Request logging middleware if self.config.debug: self.app.add_middleware(RequestLoggingMiddleware, logger=self.logger) - + # Compression middleware self.app.add_middleware(CompressionMiddleware) - + def _setup_error_handlers(self): """Setup error handlers for the application.""" setup_error_handlers(self.app, self.logger, self.config.debug) - + def _setup_static_files(self): """Setup static file serving.""" if self.config.static_files_enabled: @@ -106,19 +110,25 @@ def _setup_static_files(self): self.app.mount( self.config.static_files_url, StaticFiles(directory=str(static_path)), - name="static" + name="static", ) - + + def _load_schemas(self): + """Load schemas from project's schemas/ folder.""" + schemas_path = Path("schemas") + if schemas_path.exists(): + loaded = load_schemas(schemas_path, self.logger) + self.logger.debug(f"Loaded {len(loaded)} schema modules") + def _load_routes(self): """Load routes from project's routes/ folder.""" routes_path = Path("routes") if routes_path.exists(): self._load_routes_recursive(routes_path) - + def _load_routes_recursive(self, routes_dir: Path, prefix: str = ""): """Recursively load routes from directory structure.""" - router = APIRouter(prefix=prefix) - + for item in routes_dir.iterdir(): if item.is_dir(): # Skip hidden directories and __pycache__ @@ -130,35 +140,42 @@ def _load_routes_recursive(self, routes_dir: Path, prefix: str = ""): self._load_routes_recursive(item, new_prefix) elif item.suffix == ".py" and item.name != "__init__.py": self._load_route_file(item, prefix) - + def _load_route_file(self, route_file: Path, prefix: str = ""): """Load a single route file.""" try: route_name = route_file.stem - module_name = f"routes.{prefix.replace('/', '.')}.{route_name}".strip(".") - + prefix_part = prefix.replace("/", ".").strip(".") + module_name = ( + f"routes.{prefix_part}.{route_name}" if prefix_part else f"routes.{route_name}" + ) + spec = importlib.util.spec_from_file_location(module_name, route_file) + if spec is None or spec.loader is None: + self.logger.warning(f"Could not load spec for route {route_file}") + return + module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Extract router or create one route_router = getattr(module, "router", APIRouter()) - + # Map HTTP methods to functions for method in ["get", "post", "put", "delete", "patch", "head", "options", "trace"]: if hasattr(module, method): path = self._get_route_path(route_name) getattr(route_router, method)(path)(getattr(module, method)) - + # Include the router with proper prefix final_prefix = prefix if prefix else "" self.app.include_router(route_router, prefix=final_prefix) - + self.logger.debug(f"Loaded route: {route_file} with prefix: {final_prefix}") - + except Exception as e: self.logger.error(f"Failed to load route {route_file}: {e}") - + def _get_route_path(self, route_name: str) -> str: """Convert route name to FastAPI path.""" if route_name == "index": @@ -173,41 +190,54 @@ def _get_route_path(self, route_name: str) -> str: return f"/{{{param_name}:path}}" else: return f"/{route_name}" - + def add_middleware(self, middleware_class: Type[RunApiMiddleware], **kwargs): """Add custom middleware to the application.""" self.app.add_middleware(middleware_class, **kwargs) self.middleware_stack.append(middleware_class) self.logger.debug(f"Added middleware: {middleware_class.__name__}") - - def add_auth_middleware(self, protected_paths: List[str] = None, excluded_paths: List[str] = None): + + def add_auth_middleware( + self, protected_paths: List[str] = None, excluded_paths: List[str] = None + ): """Add JWT authentication middleware.""" - if not self.config.secret_key or self.config.secret_key == "dev-secret-key-change-in-production": + if ( + not self.config.secret_key + or self.config.secret_key == "dev-secret-key-change-in-production" + ): self.logger.warning("Using default secret key. Change SECRET_KEY in production!") - + self.add_middleware( AuthMiddleware, secret_key=self.config.secret_key, protected_paths=protected_paths, - excluded_paths=excluded_paths + excluded_paths=excluded_paths, ) - + def get_app(self) -> FastAPI: """Get the underlying FastAPI application.""" return self.app - + + def get_schema(self, name: str): + """Get a registered schema by name.""" + return SchemaRegistry.get(name) + + def list_schemas(self) -> List[str]: + """List all registered schema names.""" + return list(SchemaRegistry.get_all().keys()) + def run(self, host: str = None, port: int = None, **uvicorn_kwargs): """Run the application with uvicorn.""" import uvicorn - + run_kwargs = { "host": host or self.config.host, "port": port or self.config.port, "reload": self.config.reload, "log_level": self.config.log_level.lower(), - **uvicorn_kwargs + **uvicorn_kwargs, } - + self.logger.info(f"Starting RunApi server on {run_kwargs['host']}:{run_kwargs['port']}") uvicorn.run(self.app, **run_kwargs) @@ -220,4 +250,4 @@ def create_app(config: Optional[RunApiConfig] = None, **kwargs) -> FastAPI: def create_runapi_app(config: Optional[RunApiConfig] = None, **kwargs) -> RunApiApp: """Create a RunApi application instance.""" - return RunApiApp(config=config, **kwargs) \ No newline at end of file + return RunApiApp(config=config, **kwargs) diff --git a/runapi/errors.py b/runapi/errors.py index 8861d3c..0ad329d 100644 --- a/runapi/errors.py +++ b/runapi/errors.py @@ -1,24 +1,25 @@ """ Error handling system for RunApi framework """ + +import logging import traceback -from typing import Dict, Any, Optional, Union +from typing import Any, Dict, Optional + from fastapi import HTTPException, Request from fastapi.responses import JSONResponse -from fastapi.exception_handlers import http_exception_handler from starlette.exceptions import HTTPException as StarletteHTTPException -import logging class RunApiException(Exception): """Base exception class for RunApi framework.""" - + def __init__( self, message: str, status_code: int = 500, details: Optional[Dict[str, Any]] = None, - error_code: Optional[str] = None + error_code: Optional[str] = None, ): self.message = message self.status_code = status_code @@ -29,188 +30,184 @@ def __init__( class ValidationError(RunApiException): """Raised when request validation fails.""" - + def __init__(self, message: str = "Validation failed", details: Dict[str, Any] = None): super().__init__(message, 400, details, "VALIDATION_ERROR") class AuthenticationError(RunApiException): """Raised when authentication fails.""" - + def __init__(self, message: str = "Authentication required", details: Dict[str, Any] = None): super().__init__(message, 401, details, "AUTHENTICATION_ERROR") class AuthorizationError(RunApiException): """Raised when authorization fails.""" - + def __init__(self, message: str = "Insufficient permissions", details: Dict[str, Any] = None): super().__init__(message, 403, details, "AUTHORIZATION_ERROR") class NotFoundError(RunApiException): """Raised when a resource is not found.""" - + def __init__(self, message: str = "Resource not found", details: Dict[str, Any] = None): super().__init__(message, 404, details, "NOT_FOUND_ERROR") class ConflictError(RunApiException): """Raised when there's a conflict with the current state.""" - - def __init__(self, message: str = "Conflict with current state", details: Dict[str, Any] = None): + + def __init__( + self, message: str = "Conflict with current state", details: Dict[str, Any] = None + ): super().__init__(message, 409, details, "CONFLICT_ERROR") class RateLimitError(RunApiException): """Raised when rate limit is exceeded.""" - + def __init__(self, message: str = "Rate limit exceeded", details: Dict[str, Any] = None): super().__init__(message, 429, details, "RATE_LIMIT_ERROR") class ServerError(RunApiException): """Raised when an internal server error occurs.""" - + def __init__(self, message: str = "Internal server error", details: Dict[str, Any] = None): super().__init__(message, 500, details, "SERVER_ERROR") class DatabaseError(RunApiException): """Raised when database operations fail.""" - + def __init__(self, message: str = "Database operation failed", details: Dict[str, Any] = None): super().__init__(message, 500, details, "DATABASE_ERROR") class ExternalServiceError(RunApiException): """Raised when external service calls fail.""" - + def __init__(self, message: str = "External service error", details: Dict[str, Any] = None): super().__init__(message, 502, details, "EXTERNAL_SERVICE_ERROR") class ErrorResponse: """Standard error response format.""" - + def __init__( self, message: str, status_code: int = 500, error_code: str = "UNKNOWN_ERROR", details: Optional[Dict[str, Any]] = None, - request_id: Optional[str] = None + request_id: Optional[str] = None, ): self.message = message self.status_code = status_code self.error_code = error_code self.details = details or {} self.request_id = request_id - + def to_dict(self) -> Dict[str, Any]: """Convert error response to dictionary.""" response = { "error": { "code": self.error_code, "message": self.message, - "status_code": self.status_code + "status_code": self.status_code, } } - + if self.details: response["error"]["details"] = self.details - + if self.request_id: response["error"]["request_id"] = self.request_id - + return response - + def to_json_response(self) -> JSONResponse: """Convert to FastAPI JSONResponse.""" - return JSONResponse( - status_code=self.status_code, - content=self.to_dict() - ) + return JSONResponse(status_code=self.status_code, content=self.to_dict()) class ErrorHandler: """Error handler with logging and formatting.""" - + def __init__(self, logger: Optional[logging.Logger] = None, debug: bool = False): self.logger = logger or logging.getLogger(__name__) self.debug = debug - + def handle_runapi_exception(self, request: Request, exc: RunApiException) -> JSONResponse: """Handle RunApi custom exceptions.""" self.logger.warning( f"RunApi exception: {exc.error_code} - {exc.message}", - extra={"status_code": exc.status_code, "details": exc.details} + extra={"status_code": exc.status_code, "details": exc.details}, ) - + error_response = ErrorResponse( message=exc.message, status_code=exc.status_code, error_code=exc.error_code, details=exc.details, - request_id=getattr(request.state, "request_id", None) + request_id=getattr(request.state, "request_id", None), ) - + return error_response.to_json_response() - + def handle_http_exception(self, request: Request, exc: HTTPException) -> JSONResponse: """Handle FastAPI HTTP exceptions.""" self.logger.warning(f"HTTP exception: {exc.status_code} - {exc.detail}") - + error_response = ErrorResponse( message=str(exc.detail), status_code=exc.status_code, error_code="HTTP_ERROR", - request_id=getattr(request.state, "request_id", None) + request_id=getattr(request.state, "request_id", None), ) - + return error_response.to_json_response() - + def handle_validation_exception(self, request: Request, exc: Exception) -> JSONResponse: """Handle Pydantic validation exceptions.""" self.logger.warning(f"Validation exception: {str(exc)}") - + details = {} if hasattr(exc, "errors"): details = {"validation_errors": exc.errors()} - + error_response = ErrorResponse( message="Request validation failed", status_code=422, error_code="VALIDATION_ERROR", details=details, - request_id=getattr(request.state, "request_id", None) + request_id=getattr(request.state, "request_id", None), ) - + return error_response.to_json_response() - + def handle_generic_exception(self, request: Request, exc: Exception) -> JSONResponse: """Handle generic exceptions.""" - self.logger.error( - f"Unhandled exception: {type(exc).__name__} - {str(exc)}", - exc_info=True - ) - + self.logger.error(f"Unhandled exception: {type(exc).__name__} - {str(exc)}", exc_info=True) + details = {} if self.debug: details = { "exception_type": type(exc).__name__, "exception_message": str(exc), - "traceback": traceback.format_exc().split('\n') + "traceback": traceback.format_exc().split("\n"), } - + error_response = ErrorResponse( message="An unexpected error occurred" if not self.debug else str(exc), status_code=500, error_code="INTERNAL_ERROR", details=details, - request_id=getattr(request.state, "request_id", None) + request_id=getattr(request.state, "request_id", None), ) - + return error_response.to_json_response() @@ -221,29 +218,33 @@ def handle_generic_exception(self, request: Request, exc: Exception) -> JSONResp def setup_error_handlers(app, logger: Optional[logging.Logger] = None, debug: bool = False): """Setup error handlers for a FastAPI application.""" handler = ErrorHandler(logger, debug) - + @app.exception_handler(RunApiException) async def runapi_exception_handler(request: Request, exc: RunApiException): return handler.handle_runapi_exception(request, exc) - + @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): return handler.handle_http_exception(request, exc) - + @app.exception_handler(StarletteHTTPException) async def starlette_http_exception_handler(request: Request, exc: StarletteHTTPException): - return handler.handle_http_exception(request, HTTPException(status_code=exc.status_code, detail=exc.detail)) - + return handler.handle_http_exception( + request, HTTPException(status_code=exc.status_code, detail=exc.detail) + ) + # Handle validation errors from Pydantic try: from pydantic import ValidationError as PydanticValidationError - + @app.exception_handler(PydanticValidationError) - async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError): + async def pydantic_validation_exception_handler( + request: Request, exc: PydanticValidationError + ): return handler.handle_validation_exception(request, exc) except ImportError: pass - + # Handle generic exceptions @app.exception_handler(Exception) async def generic_exception_handler(request: Request, exc: Exception): @@ -282,17 +283,11 @@ def raise_server_error(message: str = "Internal server error"): def create_error_response( - message: str, - status_code: int = 500, - error_code: str = "ERROR", - details: Dict[str, Any] = None + message: str, status_code: int = 500, error_code: str = "ERROR", details: Dict[str, Any] = None ) -> JSONResponse: """Create a standard error response.""" error_response = ErrorResponse( - message=message, - status_code=status_code, - error_code=error_code, - details=details + message=message, status_code=status_code, error_code=error_code, details=details ) return error_response.to_json_response() @@ -323,16 +318,22 @@ def conflict(message: str = "Conflict", details: Dict[str, Any] = None) -> JSONR return create_error_response(message, 409, "CONFLICT", details) -def unprocessable_entity(message: str = "Unprocessable Entity", details: Dict[str, Any] = None) -> JSONResponse: +def unprocessable_entity( + message: str = "Unprocessable Entity", details: Dict[str, Any] = None +) -> JSONResponse: """Return a 422 Unprocessable Entity response.""" return create_error_response(message, 422, "UNPROCESSABLE_ENTITY", details) -def rate_limited(message: str = "Rate limit exceeded", details: Dict[str, Any] = None) -> JSONResponse: +def rate_limited( + message: str = "Rate limit exceeded", details: Dict[str, Any] = None +) -> JSONResponse: """Return a 429 Too Many Requests response.""" return create_error_response(message, 429, "RATE_LIMITED", details) -def internal_error(message: str = "Internal server error", details: Dict[str, Any] = None) -> JSONResponse: +def internal_error( + message: str = "Internal server error", details: Dict[str, Any] = None +) -> JSONResponse: """Return a 500 Internal Server Error response.""" - return create_error_response(message, 500, "INTERNAL_ERROR", details) \ No newline at end of file + return create_error_response(message, 500, "INTERNAL_ERROR", details) diff --git a/runapi/middleware.py b/runapi/middleware.py index 9b6136a..c277121 100644 --- a/runapi/middleware.py +++ b/runapi/middleware.py @@ -1,19 +1,19 @@ -import time +import asyncio import json import logging -from typing import Callable, Dict, Any, List, Optional -from fastapi import Request, Response, HTTPException -from starlette.middleware.base import BaseHTTPMiddleware +import time +from typing import Any, Callable, Dict, List, Optional + +from fastapi import Request, Response from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware -from fastapi.responses import JSONResponse -from collections import defaultdict -import asyncio from fastapi.middleware.gzip import GZipMiddleware +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware class RunApiMiddleware(BaseHTTPMiddleware): """Base middleware class for RunApi framework.""" - + async def dispatch(self, request: Request, call_next: Callable) -> Response: """Override this method in subclasses.""" return await call_next(request) @@ -21,42 +21,42 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: class RequestLoggingMiddleware(RunApiMiddleware): """Middleware for logging HTTP requests and responses.""" - + def __init__(self, app, logger: Optional[logging.Logger] = None): super().__init__(app) self.logger = logger or logging.getLogger(__name__) - + async def dispatch(self, request: Request, call_next: Callable) -> Response: start_time = time.time() - + # Log request self.logger.info(f"Request: {request.method} {request.url}") - + response = await call_next(request) - + # Calculate processing time process_time = time.time() - start_time - + # Log response self.logger.info( f"Response: {response.status_code} - " f"Time: {process_time:.4f}s - " f"Size: {response.headers.get('content-length', 'unknown')}" ) - + response.headers["X-Process-Time"] = str(process_time) return response class RateLimitMiddleware(RunApiMiddleware): """Rate limiting middleware using Fixed Window Counter (O(1)).""" - + def __init__( - self, + self, app, calls: int = 100, period: int = 60, # seconds - key_func: Optional[Callable[[Request], str]] = None + key_func: Optional[Callable[[Request], str]] = None, ): super().__init__(app) self.calls = calls @@ -65,18 +65,18 @@ def __init__( # Store: {key: [count, start_time]} self.requests: Dict[str, List[float]] = {} self.lock = asyncio.Lock() - + def _default_key_func(self, request: Request) -> str: """Default key function using client IP.""" forwarded = request.headers.get("X-Forwarded-For") if forwarded: return forwarded.split(",")[0].strip() return request.client.host if request.client else "unknown" - + async def dispatch(self, request: Request, call_next: Callable) -> Response: key = self.key_func(request) current_time = time.time() - + async with self.lock: # Get current window state if key not in self.requests: @@ -85,7 +85,7 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: reset_time = current_time + self.period else: count, start_time = self.requests[key] - + if current_time > start_time + self.period: # New window self.requests[key] = [1, current_time] @@ -98,28 +98,30 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: status_code=429, content={ "error": "Rate limit exceeded", - "message": f"Maximum {self.calls} requests per {self.period} seconds" + "message": f"Maximum {self.calls} requests per {self.period} seconds", + }, + headers={ + "Retry-After": str(int(start_time + self.period - current_time)) }, - headers={"Retry-After": str(int(start_time + self.period - current_time))} ) - + self.requests[key][0] += 1 remaining = self.calls - self.requests[key][0] reset_time = start_time + self.period - + response = await call_next(request) - + # Add rate limit headers response.headers["X-RateLimit-Limit"] = str(self.calls) response.headers["X-RateLimit-Remaining"] = str(remaining) response.headers["X-RateLimit-Reset"] = str(int(reset_time)) - + return response class AuthMiddleware(RunApiMiddleware): """JWT-based authentication middleware.""" - + def __init__( self, app, @@ -128,7 +130,7 @@ def __init__( protected_paths: Optional[List[str]] = None, excluded_paths: Optional[List[str]] = None, header_name: str = "Authorization", - token_prefix: str = "Bearer " + token_prefix: str = "Bearer ", ): super().__init__(app) self.secret_key = secret_key @@ -137,149 +139,152 @@ def __init__( self.excluded_paths = excluded_paths or ["/docs", "/redoc", "/openapi.json"] self.header_name = header_name self.token_prefix = token_prefix - + def _is_protected_path(self, path: str) -> bool: """Check if path requires authentication.""" # If no protected paths specified, protect all except excluded if not self.protected_paths: return path not in self.excluded_paths - + # Check if path matches any protected pattern for pattern in self.protected_paths: if path.startswith(pattern): return True return False - + def _extract_token(self, request: Request) -> Optional[str]: """Extract JWT token from request headers.""" auth_header = request.headers.get(self.header_name) if not auth_header or not auth_header.startswith(self.token_prefix): return None - - return auth_header[len(self.token_prefix):].strip() - + + return auth_header[len(self.token_prefix) :].strip() + def _verify_token(self, token: str) -> Optional[Dict[str, Any]]: """Verify JWT token and return payload.""" try: # Note: In real implementation, you'd use python-jose or similar # This is a simplified version import base64 - import hmac import hashlib - - parts = token.split('.') + import hmac + + parts = token.split(".") if len(parts) != 3: return None - + header, payload, signature = parts - + # Verify signature (simplified) - expected_sig = base64.urlsafe_b64encode( - hmac.new( - self.secret_key.encode(), - f"{header}.{payload}".encode(), - hashlib.sha256 - ).digest() - ).decode().rstrip('=') - + expected_sig = ( + base64.urlsafe_b64encode( + hmac.new( + self.secret_key.encode(), f"{header}.{payload}".encode(), hashlib.sha256 + ).digest() + ) + .decode() + .rstrip("=") + ) + if not hmac.compare_digest(signature, expected_sig): return None - + # Decode payload - payload_data = json.loads( - base64.urlsafe_b64decode(payload + '==') - ) - + payload_data = json.loads(base64.urlsafe_b64decode(payload + "==")) + # Check expiration - if 'exp' in payload_data and payload_data['exp'] < time.time(): + if "exp" in payload_data and payload_data["exp"] < time.time(): return None - + return payload_data - + except Exception: return None - + async def dispatch(self, request: Request, call_next: Callable) -> Response: path = request.url.path - + # Skip authentication for excluded paths if not self._is_protected_path(path): return await call_next(request) - + # Extract and verify token token = self._extract_token(request) if not token: return JSONResponse( status_code=401, - content={"error": "Authentication required", "message": "Missing or invalid token"} + content={"error": "Authentication required", "message": "Missing or invalid token"}, ) - + payload = self._verify_token(token) if not payload: return JSONResponse( status_code=401, - content={"error": "Authentication failed", "message": "Invalid or expired token"} + content={"error": "Authentication failed", "message": "Invalid or expired token"}, ) - + # Add user info to request state request.state.user = payload - + return await call_next(request) class SecurityHeadersMiddleware(RunApiMiddleware): """Add security headers to responses.""" - + def __init__( self, app, include_server: bool = False, csp_policy: Optional[str] = None, - hsts_max_age: int = 31536000 # 1 year + hsts_max_age: int = 31536000, # 1 year ): super().__init__(app) self.include_server = include_server self.csp_policy = csp_policy self.hsts_max_age = hsts_max_age - + async def dispatch(self, request: Request, call_next: Callable) -> Response: response = await call_next(request) - + # Remove server header if requested if not self.include_server: if "Server" in response.headers: del response.headers["Server"] - + # Add security headers response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["X-XSS-Protection"] = "1; mode=block" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" - + # Add HSTS header for HTTPS if request.url.scheme == "https": - response.headers["Strict-Transport-Security"] = f"max-age={self.hsts_max_age}; includeSubDomains" - + response.headers["Strict-Transport-Security"] = ( + f"max-age={self.hsts_max_age}; includeSubDomains" + ) + # Add CSP header if policy is defined if self.csp_policy: response.headers["Content-Security-Policy"] = self.csp_policy - + return response class CompressionMiddleware(GZipMiddleware): """ Compression middleware using GZipMiddleware. - + This replaces the previous custom implementation to support streaming responses and better memory efficiency. """ + pass class CORSMiddleware: """CORS middleware wrapper for FastAPI's CORS middleware.""" - + def __init__( self, allow_origins: List[str] = None, @@ -287,7 +292,7 @@ def __init__( allow_methods: List[str] = None, allow_headers: List[str] = None, expose_headers: List[str] = None, - max_age: int = 600 + max_age: int = 600, ): self.allow_origins = allow_origins or ["*"] self.allow_credentials = allow_credentials @@ -295,7 +300,7 @@ def __init__( self.allow_headers = allow_headers or ["*"] self.expose_headers = expose_headers or [] self.max_age = max_age - + def get_middleware(self): """Get FastAPI CORS middleware instance.""" return FastAPICORSMiddleware( @@ -304,7 +309,7 @@ def get_middleware(self): allow_methods=self.allow_methods, allow_headers=self.allow_headers, expose_headers=self.expose_headers, - max_age=self.max_age + max_age=self.max_age, ) @@ -315,17 +320,11 @@ def create_rate_limit_middleware(app, calls: int = 100, period: int = 60): def create_auth_middleware( - app, - secret_key: str, - protected_paths: List[str] = None, - excluded_paths: List[str] = None + app, secret_key: str, protected_paths: List[str] = None, excluded_paths: List[str] = None ): """Create authentication middleware.""" return AuthMiddleware( - app, - secret_key=secret_key, - protected_paths=protected_paths, - excluded_paths=excluded_paths + app, secret_key=secret_key, protected_paths=protected_paths, excluded_paths=excluded_paths ) @@ -335,15 +334,9 @@ def create_logging_middleware(app, logger: logging.Logger = None): def create_security_middleware( - app, - include_server: bool = False, - csp_policy: str = None, - hsts_max_age: int = 31536000 + app, include_server: bool = False, csp_policy: str = None, hsts_max_age: int = 31536000 ): """Create security headers middleware.""" return SecurityHeadersMiddleware( - app, - include_server=include_server, - csp_policy=csp_policy, - hsts_max_age=hsts_max_age - ) \ No newline at end of file + app, include_server=include_server, csp_policy=csp_policy, hsts_max_age=hsts_max_age + ) diff --git a/runapi/repository.py b/runapi/repository.py new file mode 100644 index 0000000..4fe1a5c --- /dev/null +++ b/runapi/repository.py @@ -0,0 +1,521 @@ +# runapi/repository.py +""" +Repository pattern for RunAPI - Data access layer abstraction. + +Provides: +- BaseRepository abstract class with common CRUD operations +- InMemoryRepository for testing and prototyping +- SQLAlchemy integration (optional, if installed) +- Generic typing for type-safe repositories +""" + +import logging +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import ( + Any, + Dict, + Generic, + List, + Optional, + Protocol, + Type, + TypeVar, + runtime_checkable, +) + +logger = logging.getLogger("runapi.repository") + +# Type variables for generic repositories +T = TypeVar("T") # Entity type +ID = TypeVar("ID") # ID type (usually int or str) +CreateSchema = TypeVar("CreateSchema") +UpdateSchema = TypeVar("UpdateSchema") + + +# ============================================================================= +# Repository Protocol (Interface) +# ============================================================================= + + +@runtime_checkable +class RepositoryProtocol(Protocol[T, ID]): + """Protocol defining the repository interface.""" + + async def get(self, id: ID) -> Optional[T]: ... + async def get_all(self, skip: int = 0, limit: int = 100) -> List[T]: ... + async def create(self, data: Dict[str, Any]) -> T: ... + async def update(self, id: ID, data: Dict[str, Any]) -> Optional[T]: ... + async def delete(self, id: ID) -> bool: ... + async def count(self) -> int: ... + + +# ============================================================================= +# Base Repository (Abstract) +# ============================================================================= + + +class BaseRepository(ABC, Generic[T, ID]): + """ + Abstract base repository with common CRUD operations. + + Inherit from this class and implement the abstract methods + to create repositories for your data sources. + + Example: + class UserRepository(BaseRepository[User, int]): + async def get(self, id: int) -> Optional[User]: + # Implementation + pass + """ + + @abstractmethod + async def get(self, id: ID) -> Optional[T]: + """Get a single entity by ID.""" + pass + + @abstractmethod + async def get_all(self, skip: int = 0, limit: int = 100, **filters) -> List[T]: + """Get all entities with pagination and optional filters.""" + pass + + @abstractmethod + async def create(self, data: Dict[str, Any]) -> T: + """Create a new entity.""" + pass + + @abstractmethod + async def update(self, id: ID, data: Dict[str, Any]) -> Optional[T]: + """Update an existing entity.""" + pass + + @abstractmethod + async def delete(self, id: ID) -> bool: + """Delete an entity by ID. Returns True if deleted.""" + pass + + async def count(self, **filters) -> int: + """Count entities matching filters. Default implementation.""" + items = await self.get_all(skip=0, limit=999999, **filters) + return len(items) + + async def exists(self, id: ID) -> bool: + """Check if an entity exists.""" + return await self.get(id) is not None + + async def get_by(self, **filters) -> Optional[T]: + """Get a single entity matching filters.""" + items = await self.get_all(skip=0, limit=1, **filters) + return items[0] if items else None + + async def get_many_by(self, skip: int = 0, limit: int = 100, **filters) -> List[T]: + """Get multiple entities matching filters.""" + return await self.get_all(skip=skip, limit=limit, **filters) + + async def create_many(self, items: List[Dict[str, Any]]) -> List[T]: + """Create multiple entities. Default implementation.""" + return [await self.create(item) for item in items] + + async def update_many(self, ids: List[ID], data: Dict[str, Any]) -> List[T]: + """Update multiple entities. Default implementation.""" + results = [] + for id in ids: + result = await self.update(id, data) + if result: + results.append(result) + return results + + async def delete_many(self, ids: List[ID]) -> int: + """Delete multiple entities. Returns count of deleted.""" + count = 0 + for id in ids: + if await self.delete(id): + count += 1 + return count + + +# ============================================================================= +# In-Memory Repository (for testing/prototyping) +# ============================================================================= + + +class InMemoryRepository(BaseRepository[Dict[str, Any], int]): + """ + In-memory repository for testing and prototyping. + + Stores entities in a dictionary. Useful for: + - Unit testing without database + - Rapid prototyping + - Development before database setup + + Example: + repo = InMemoryRepository() + user = await repo.create({"name": "John", "email": "john@example.com"}) + users = await repo.get_all() + """ + + def __init__(self): + self._storage: Dict[int, Dict[str, Any]] = {} + self._id_counter = 0 + + def _next_id(self) -> int: + self._id_counter += 1 + return self._id_counter + + async def get(self, id: int) -> Optional[Dict[str, Any]]: + """Get entity by ID.""" + return self._storage.get(id) + + async def get_all(self, skip: int = 0, limit: int = 100, **filters) -> List[Dict[str, Any]]: + """Get all entities with pagination and filters.""" + items = list(self._storage.values()) + + # Apply filters + if filters: + items = [item for item in items if all(item.get(k) == v for k, v in filters.items())] + + # Apply pagination + return items[skip : skip + limit] + + async def create(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new entity.""" + entity = data.copy() + entity["id"] = self._next_id() + entity["created_at"] = datetime.now(timezone.utc) + entity["updated_at"] = datetime.now(timezone.utc) + self._storage[entity["id"]] = entity + return entity + + async def update(self, id: int, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update an existing entity.""" + if id not in self._storage: + return None + + entity = self._storage[id] + for key, value in data.items(): + if value is not None: # Only update non-None values + entity[key] = value + entity["updated_at"] = datetime.now(timezone.utc) + return entity + + async def delete(self, id: int) -> bool: + """Delete an entity.""" + if id in self._storage: + del self._storage[id] + return True + return False + + async def count(self, **filters) -> int: + """Count entities.""" + if not filters: + return len(self._storage) + return len(await self.get_all(**filters)) + + def clear(self): + """Clear all entities (useful for testing).""" + self._storage.clear() + self._id_counter = 0 + + +# ============================================================================= +# Typed In-Memory Repository +# ============================================================================= + + +class TypedInMemoryRepository(BaseRepository[T, int], Generic[T]): + """ + Type-safe in-memory repository using Pydantic models. + + Example: + from pydantic import BaseModel + + class User(BaseModel): + id: Optional[int] = None + name: str + email: str + + repo = TypedInMemoryRepository(User) + user = await repo.create({"name": "John", "email": "john@example.com"}) + # user is a User instance + """ + + def __init__(self, model_class: Type[T]): + self._model_class = model_class + self._storage: Dict[int, T] = {} + self._id_counter = 0 + + def _next_id(self) -> int: + self._id_counter += 1 + return self._id_counter + + def _to_model(self, data: Dict[str, Any]) -> T: + """Convert dict to model instance.""" + return self._model_class(**data) + + def _to_dict(self, model: T) -> Dict[str, Any]: + """Convert model to dict.""" + if hasattr(model, "model_dump"): + return model.model_dump() + elif hasattr(model, "dict"): + return model.dict() + elif hasattr(model, "__dict__"): + return {k: v for k, v in model.__dict__.items() if not k.startswith("_")} + raise TypeError(f"Cannot convert {type(model).__name__} to dict") + + async def get(self, id: int) -> Optional[T]: + """Get entity by ID.""" + return self._storage.get(id) + + async def get_all(self, skip: int = 0, limit: int = 100, **filters) -> List[T]: + """Get all entities with pagination and filters.""" + items = list(self._storage.values()) + + # Apply filters + if filters: + filtered = [] + for item in items: + item_dict = self._to_dict(item) + if all(item_dict.get(k) == v for k, v in filters.items()): + filtered.append(item) + items = filtered + + return items[skip : skip + limit] + + async def create(self, data: Dict[str, Any]) -> T: + """Create a new entity.""" + entity_data = data.copy() + entity_data["id"] = self._next_id() + + # Add timestamps if the model supports them + now = datetime.now(timezone.utc) + if "created_at" not in entity_data: + entity_data["created_at"] = now + if "updated_at" not in entity_data: + entity_data["updated_at"] = now + + entity = self._to_model(entity_data) + self._storage[entity_data["id"]] = entity + return entity + + async def update(self, id: int, data: Dict[str, Any]) -> Optional[T]: + """Update an existing entity.""" + if id not in self._storage: + return None + + existing = self._storage[id] + existing_dict = self._to_dict(existing) + + for key, value in data.items(): + if value is not None: + existing_dict[key] = value + + existing_dict["updated_at"] = datetime.now(timezone.utc) + + entity = self._to_model(existing_dict) + self._storage[id] = entity + return entity + + async def delete(self, id: int) -> bool: + """Delete an entity.""" + if id in self._storage: + del self._storage[id] + return True + return False + + async def count(self, **filters) -> int: + """Count entities.""" + if not filters: + return len(self._storage) + return len(await self.get_all(**filters)) + + def clear(self): + """Clear all entities.""" + self._storage.clear() + self._id_counter = 0 + + +# ============================================================================= +# SQLAlchemy Repository (Optional) +# ============================================================================= + +try: + from sqlalchemy import func, select + from sqlalchemy.ext.asyncio import AsyncSession + + SQLALCHEMY_AVAILABLE = True + + class SQLAlchemyRepository(BaseRepository[T, ID], Generic[T, ID]): + """ + SQLAlchemy-based repository for async database operations. + + Requires SQLAlchemy with async support. + + Example: + class UserRepository(SQLAlchemyRepository[User, int]): + def __init__(self, session: AsyncSession): + super().__init__(session, User) + + async with async_session() as session: + repo = UserRepository(session) + user = await repo.create({"name": "John", "email": "john@example.com"}) + """ + + def __init__(self, session: AsyncSession, model_class: Type[T]): + self.session = session + self.model_class = model_class + + async def get(self, id: ID) -> Optional[T]: + """Get entity by ID.""" + result = await self.session.get(self.model_class, id) + return result + + async def get_all(self, skip: int = 0, limit: int = 100, **filters) -> List[T]: + """Get all entities with pagination and filters.""" + query = select(self.model_class) + + # Apply filters + for key, value in filters.items(): + if hasattr(self.model_class, key): + query = query.where(getattr(self.model_class, key) == value) + + query = query.offset(skip).limit(limit) + result = await self.session.execute(query) + return list(result.scalars().all()) + + async def create(self, data: Dict[str, Any]) -> T: + """Create a new entity.""" + entity = self.model_class(**data) + self.session.add(entity) + await self.session.flush() + await self.session.refresh(entity) + return entity + + async def update(self, id: ID, data: Dict[str, Any]) -> Optional[T]: + """Update an existing entity.""" + entity = await self.get(id) + if not entity: + return None + + for key, value in data.items(): + if value is not None and hasattr(entity, key): + setattr(entity, key, value) + + await self.session.flush() + await self.session.refresh(entity) + return entity + + async def delete(self, id: ID) -> bool: + """Delete an entity.""" + entity = await self.get(id) + if not entity: + return False + + await self.session.delete(entity) + await self.session.flush() + return True + + async def count(self, **filters) -> int: + """Count entities.""" + query = select(func.count()).select_from(self.model_class) + + for key, value in filters.items(): + if hasattr(self.model_class, key): + query = query.where(getattr(self.model_class, key) == value) + + result = await self.session.execute(query) + return result.scalar() or 0 + + async def commit(self): + """Commit the current transaction.""" + await self.session.commit() + + async def rollback(self): + """Rollback the current transaction.""" + await self.session.rollback() + +except ImportError: + SQLALCHEMY_AVAILABLE = False + SQLAlchemyRepository = None # type: ignore + + +# ============================================================================= +# Repository Factory +# ============================================================================= + + +class RepositoryFactory: + """ + Factory for creating repository instances. + + Useful for dependency injection and testing. + + Example: + factory = RepositoryFactory() + factory.register("users", UserRepository) + + # In route handler + user_repo = factory.create("users", session=db_session) + """ + + _repositories: Dict[str, Type[BaseRepository]] = {} + + @classmethod + def register(cls, name: str, repository_class: Type[BaseRepository]) -> None: + """Register a repository class.""" + cls._repositories[name] = repository_class + logger.debug(f"Registered repository: {name}") + + @classmethod + def get(cls, name: str) -> Optional[Type[BaseRepository]]: + """Get a registered repository class.""" + return cls._repositories.get(name) + + @classmethod + def create(cls, name: str, **kwargs) -> BaseRepository: + """Create an instance of a registered repository.""" + repo_class = cls._repositories.get(name) + if not repo_class: + raise ValueError(f"Repository '{name}' not registered") + return repo_class(**kwargs) + + @classmethod + def list_repositories(cls) -> List[str]: + """List all registered repository names.""" + return list(cls._repositories.keys()) + + @classmethod + def clear(cls) -> None: + """Clear all registered repositories.""" + cls._repositories.clear() + + +# ============================================================================= +# Utility Functions +# ============================================================================= + + +def create_repository(model_class: Type[T], storage: str = "memory") -> BaseRepository[T, int]: + """ + Create a repository for a model. + + Args: + model_class: The Pydantic model class + storage: Storage backend ("memory" or "sqlalchemy") + + Returns: + A repository instance + """ + if storage == "memory": + return TypedInMemoryRepository(model_class) + elif storage == "sqlalchemy": + if not SQLALCHEMY_AVAILABLE: + raise ImportError( + "SQLAlchemy is required for SQLAlchemy repositories. " + "Install with: pip install sqlalchemy[asyncio]" + ) + raise ValueError( + "SQLAlchemy repositories require a session. " + "Use SQLAlchemyRepository directly with a session." + ) + else: + raise ValueError(f"Unknown storage backend: {storage}") diff --git a/runapi/schemas.py b/runapi/schemas.py new file mode 100644 index 0000000..481d889 --- /dev/null +++ b/runapi/schemas.py @@ -0,0 +1,382 @@ +# runapi/schemas.py +""" +Schema layer for RunAPI - Pydantic model management and auto-discovery. + +Provides: +- BaseSchema with common configurations +- Schema registry for auto-discovery +- Common mixins (timestamps, pagination, etc.) +- Utility functions for schema operations +""" + +import importlib.util +import logging +import sys +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + +logger = logging.getLogger("runapi.schemas") + +# Type variable for generic schemas +T = TypeVar("T") + + +# ============================================================================= +# Schema Registry +# ============================================================================= + + +class SchemaRegistry: + """ + Registry for auto-discovered schemas. + + Schemas placed in the `schemas/` directory are automatically discovered + and registered here for easy access across the application. + """ + + _schemas: Dict[str, Type[BaseModel]] = {} + _modules: Dict[str, Any] = {} + _loaded: bool = False + + @classmethod + def register(cls, name: str, schema: Type[BaseModel]) -> None: + """Register a schema by name.""" + cls._schemas[name] = schema + logger.debug(f"Registered schema: {name}") + + @classmethod + def get(cls, name: str) -> Optional[Type[BaseModel]]: + """Get a schema by name.""" + return cls._schemas.get(name) + + @classmethod + def get_all(cls) -> Dict[str, Type[BaseModel]]: + """Get all registered schemas.""" + return cls._schemas.copy() + + @classmethod + def get_module(cls, name: str) -> Optional[Any]: + """Get a loaded schema module by name.""" + return cls._modules.get(name) + + @classmethod + def clear(cls) -> None: + """Clear all registered schemas (useful for testing).""" + cls._schemas.clear() + cls._modules.clear() + cls._loaded = False + + +# ============================================================================= +# Base Schema Classes +# ============================================================================= + + +class BaseSchema(BaseModel): + """ + Base schema class with sensible defaults for API development. + + Features: + - Automatic ORM mode for SQLAlchemy compatibility + - Strict validation by default + - JSON-compatible serialization + + Example: + class UserResponse(BaseSchema): + id: int + email: str + created_at: datetime + """ + + model_config = ConfigDict( + from_attributes=True, # Enable ORM mode (formerly orm_mode) + validate_assignment=True, # Validate on attribute assignment + str_strip_whitespace=True, # Strip whitespace from strings + ser_json_timedelta="iso8601", # Pydantic v2 handles datetime ISO serialization by default + ) + + +class TimestampMixin(BaseModel): + """Mixin for models with timestamp fields.""" + + created_at: Optional[datetime] = Field(default=None, description="Creation timestamp") + updated_at: Optional[datetime] = Field(default=None, description="Last update timestamp") + + +class IDMixin(BaseModel): + """Mixin for models with an ID field.""" + + id: int = Field(..., description="Unique identifier") + + +# ============================================================================= +# Common Response Schemas +# ============================================================================= + + +class MessageResponse(BaseSchema): + """Simple message response.""" + + message: str = Field(..., description="Response message") + success: bool = Field(default=True, description="Operation success status") + + +class PaginatedResponse(BaseSchema, Generic[T]): + """ + Generic paginated response wrapper. + + Example: + class UserList(PaginatedResponse[UserResponse]): + pass + """ + + items: List[T] = Field(default_factory=list, description="List of items") + total: int = Field(..., description="Total number of items") + page: int = Field(default=1, ge=1, description="Current page number") + page_size: int = Field(default=20, ge=1, le=100, description="Items per page") + pages: int = Field(..., description="Total number of pages") + + @classmethod + def create( + cls, items: List[T], total: int, page: int = 1, page_size: int = 20 + ) -> "PaginatedResponse[T]": + """Factory method to create paginated response.""" + pages = (total + page_size - 1) // page_size if page_size > 0 else 0 + return cls(items=items, total=total, page=page, page_size=page_size, pages=pages) + + +class PaginationParams(BaseSchema): + """Query parameters for pagination.""" + + page: int = Field(default=1, ge=1, description="Page number") + page_size: int = Field(default=20, ge=1, le=100, description="Items per page") + + @property + def offset(self) -> int: + """Calculate offset for database queries.""" + return (self.page - 1) * self.page_size + + @property + def limit(self) -> int: + """Alias for page_size.""" + return self.page_size + + +class ErrorDetail(BaseSchema): + """Error detail schema.""" + + field: Optional[str] = Field(default=None, description="Field that caused the error") + message: str = Field(..., description="Error message") + code: Optional[str] = Field(default=None, description="Error code") + + +class ErrorResponse(BaseSchema): + """Standardized error response.""" + + error: str = Field(..., description="Error type") + message: str = Field(..., description="Error message") + details: Optional[List[ErrorDetail]] = Field(default=None, description="Detailed errors") + request_id: Optional[str] = Field(default=None, description="Request tracking ID") + + +# ============================================================================= +# Schema Discovery +# ============================================================================= + + +def load_schemas(schemas_path: Path = None, logger: logging.Logger = None) -> Dict[str, Any]: + """ + Load schemas from the schemas/ directory. + + Similar to route loading, this discovers all Python files in the schemas/ + directory and imports them, making their Pydantic models available. + + Args: + schemas_path: Path to schemas directory (defaults to ./schemas) + logger: Logger instance for debug output + + Returns: + Dictionary of loaded modules by name + """ + if logger is None: + logger = logging.getLogger("runapi.schemas") + + if schemas_path is None: + schemas_path = Path("schemas") + + if not schemas_path.exists(): + logger.debug(f"Schemas directory not found: {schemas_path}") + return {} + + loaded_modules = {} + _load_schemas_recursive(schemas_path, "", loaded_modules, logger) + + SchemaRegistry._modules = loaded_modules + SchemaRegistry._loaded = True + + return loaded_modules + + +def _load_schemas_recursive( + schemas_dir: Path, prefix: str, loaded_modules: Dict[str, Any], logger: logging.Logger +) -> None: + """Recursively load schema files from directory structure.""" + + for item in schemas_dir.iterdir(): + if item.is_dir(): + # Skip hidden directories and __pycache__ + if item.name.startswith(".") or item.name.startswith("__"): + continue + + # Recurse into subdirectories + new_prefix = f"{prefix}.{item.name}" if prefix else item.name + _load_schemas_recursive(item, new_prefix, loaded_modules, logger) + + elif item.suffix == ".py" and item.name != "__init__.py": + _load_schema_file(item, prefix, loaded_modules, logger) + + +def _load_schema_file( + schema_file: Path, prefix: str, loaded_modules: Dict[str, Any], logger: logging.Logger +) -> None: + """Load a single schema file and register its Pydantic models.""" + + try: + schema_name = schema_file.stem + module_name = f"schemas.{prefix}.{schema_name}".strip(".").replace("..", ".") + + # Import the module + spec = importlib.util.spec_from_file_location(module_name, schema_file) + if spec is None or spec.loader is None: + logger.warning(f"Could not load spec for schema {schema_file}") + return + + module = importlib.util.module_from_spec(spec) + + # Add to sys.modules so relative imports work + sys.modules[module_name] = module + spec.loader.exec_module(module) + + # Store the module + loaded_modules[module_name] = module + + # Find and register all Pydantic models in the module + for attr_name in dir(module): + attr = getattr(module, attr_name) + if ( + isinstance(attr, type) + and issubclass(attr, BaseModel) + and attr is not BaseModel + and not attr_name.startswith("_") + ): + # Register with full path and short name + full_name = f"{module_name}.{attr_name}" + SchemaRegistry.register(full_name, attr) + + # Check for short name collision before registering + existing = SchemaRegistry.get(attr_name) + if existing is not None and existing is not attr: + logger.warning( + f"Schema name collision: '{attr_name}' from {module_name} " + f"shadows existing schema. Use full path '{full_name}' to disambiguate." + ) + SchemaRegistry.register(attr_name, attr) # Also register short name + + logger.debug(f"Loaded schema module: {module_name}") + + except Exception as e: + logger.error(f"Failed to load schema {schema_file}: {e}") + + +def get_schema(name: str) -> Optional[Type[BaseModel]]: + """ + Get a registered schema by name. + + Args: + name: Schema name (e.g., "UserResponse" or "schemas.user.UserResponse") + + Returns: + The schema class or None if not found + """ + return SchemaRegistry.get(name) + + +def list_schemas() -> List[str]: + """List all registered schema names.""" + return list(SchemaRegistry.get_all().keys()) + + +# ============================================================================= +# Schema Utilities +# ============================================================================= + + +def create_response_model( + name: str, *, include_id: bool = True, include_timestamps: bool = True, **fields: Any +) -> Type[BaseSchema]: + """ + Dynamically create a response schema. + + Example: + UserResponse = create_response_model( + "UserResponse", + include_id=True, + include_timestamps=True, + email=(str, ...), + name=(str, None) + ) + """ + bases = [BaseSchema] + + if include_id: + bases.insert(0, IDMixin) + if include_timestamps: + bases.insert(0, TimestampMixin) + + return type(name, tuple(bases), {"__annotations__": fields}) + + +def create_create_model(name: str, **fields: Any) -> Type[BaseSchema]: + """ + Dynamically create a 'create' schema (no ID, no timestamps). + + Example: + UserCreate = create_create_model( + "UserCreate", + email=(str, ...), + password=(str, ...) + ) + """ + return type(name, (BaseSchema,), {"__annotations__": fields}) + + +def create_update_model(name: str, **fields: Any) -> Type[BaseSchema]: + """ + Dynamically create an 'update' schema (all fields optional). + + Example: + UserUpdate = create_update_model( + "UserUpdate", + email=str, + name=str + ) + """ + # Make all fields optional with None defaults + annotations = {} + field_defaults = {} + + for field_name, field_type in fields.items(): + if isinstance(field_type, tuple): + # If tuple provided, use first element as type + annotations[field_name] = Optional[field_type[0]] + else: + annotations[field_name] = Optional[field_type] + field_defaults[field_name] = None + + namespace = {"__annotations__": annotations} + namespace.update(field_defaults) + + return type(name, (BaseSchema,), namespace) diff --git a/runapi/service.py b/runapi/service.py new file mode 100644 index 0000000..7179542 --- /dev/null +++ b/runapi/service.py @@ -0,0 +1,463 @@ +# runapi/service.py +""" +Service layer for RunAPI - Business logic abstraction. + +Provides: +- BaseService abstract class for business logic +- CRUDService for common CRUD operations +- Service decorators for validation, transactions, etc. +- Dependency injection utilities +""" + +import logging +from abc import ABC, abstractmethod +from functools import wraps +from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Type, TypeVar + +from .errors import NotFoundError +from .repository import BaseRepository + +logger = logging.getLogger("runapi.service") + +# Type variables +T = TypeVar("T") # Entity type +ID = TypeVar("ID") # ID type +CreateSchema = TypeVar("CreateSchema") +UpdateSchema = TypeVar("UpdateSchema") + + +# ============================================================================= +# Base Service (Abstract) +# ============================================================================= + + +class BaseService(ABC, Generic[T, ID]): + """ + Abstract base service for business logic. + + Services sit between routes and repositories, handling: + - Business rules and validation + - Complex operations spanning multiple repositories + - Transaction coordination + - Authorization checks + + Example: + class UserService(BaseService[User, int]): + def __init__(self, user_repo: UserRepository): + self.repository = user_repo + + async def register(self, data: dict) -> User: + # Business logic here + if await self.repository.get_by(email=data['email']): + raise ValidationError("Email already exists") + return await self.repository.create(data) + """ + + repository: BaseRepository[T, ID] + + @abstractmethod + async def get(self, id: ID) -> T: + """Get an entity by ID.""" + pass + + @abstractmethod + async def get_all(self, skip: int = 0, limit: int = 100, **filters) -> List[T]: + """Get all entities with pagination.""" + pass + + @abstractmethod + async def create(self, data: Dict[str, Any]) -> T: + """Create a new entity.""" + pass + + @abstractmethod + async def update(self, id: ID, data: Dict[str, Any]) -> T: + """Update an existing entity.""" + pass + + @abstractmethod + async def delete(self, id: ID) -> bool: + """Delete an entity.""" + pass + + +# ============================================================================= +# CRUD Service (Ready-to-use) +# ============================================================================= + + +class CRUDService(BaseService[T, ID], Generic[T, ID]): + """ + Ready-to-use CRUD service with common operations. + + Provides standard CRUD operations with built-in: + - Not found error handling + - Pagination support + - Filter support + + Example: + class UserService(CRUDService[User, int]): + def __init__(self, repository: UserRepository): + super().__init__(repository) + + # Add custom business methods + async def deactivate(self, user_id: int) -> User: + return await self.update(user_id, {"is_active": False}) + """ + + def __init__(self, repository: BaseRepository[T, ID], entity_name: str = "Entity"): + """ + Initialize CRUD service. + + Args: + repository: The repository for data access + entity_name: Name used in error messages (e.g., "User", "Product") + """ + self.repository = repository + self.entity_name = entity_name + + async def get(self, id: ID) -> T: + """ + Get an entity by ID. + + Raises: + NotFoundError: If entity not found + """ + entity = await self.repository.get(id) + if entity is None: + raise NotFoundError(f"{self.entity_name} with id {id} not found") + return entity + + async def get_or_none(self, id: ID) -> Optional[T]: + """Get an entity by ID, returning None if not found.""" + return await self.repository.get(id) + + async def get_all(self, skip: int = 0, limit: int = 100, **filters) -> List[T]: + """Get all entities with pagination and optional filters.""" + return await self.repository.get_all(skip=skip, limit=limit, **filters) + + async def get_by(self, **filters) -> Optional[T]: + """Get a single entity matching filters.""" + return await self.repository.get_by(**filters) + + async def create(self, data: Dict[str, Any]) -> T: + """ + Create a new entity. + + Override this method to add validation logic. + """ + return await self.repository.create(data) + + async def update(self, id: ID, data: Dict[str, Any]) -> T: + """ + Update an existing entity. + + Raises: + NotFoundError: If entity not found + """ + # Verify exists + await self.get(id) # Raises NotFoundError if not found + + result = await self.repository.update(id, data) + if result is None: + raise NotFoundError(f"{self.entity_name} with id {id} not found") + return result + + async def delete(self, id: ID) -> bool: + """ + Delete an entity. + + Raises: + NotFoundError: If entity not found + """ + # Verify exists + await self.get(id) # Raises NotFoundError if not found + + return await self.repository.delete(id) + + async def exists(self, id: ID) -> bool: + """Check if an entity exists.""" + return await self.repository.exists(id) + + async def count(self, **filters) -> int: + """Count entities matching filters.""" + return await self.repository.count(**filters) + + # Bulk operations + + async def create_many(self, items: List[Dict[str, Any]]) -> List[T]: + """Create multiple entities.""" + return await self.repository.create_many(items) + + async def update_many(self, ids: List[ID], data: Dict[str, Any]) -> List[T]: + """Update multiple entities.""" + return await self.repository.update_many(ids, data) + + async def delete_many(self, ids: List[ID]) -> int: + """Delete multiple entities. Returns count of deleted.""" + return await self.repository.delete_many(ids) + + +# ============================================================================= +# Service with Validation +# ============================================================================= + + +class ValidatedService(CRUDService[T, ID], Generic[T, ID]): + """ + CRUD service with schema validation support. + + Validates input data against Pydantic schemas before operations. + + Example: + class UserService(ValidatedService[User, int]): + create_schema = UserCreate + update_schema = UserUpdate + + def __init__(self, repository: UserRepository): + super().__init__(repository, "User") + """ + + create_schema: Optional[Type] = None + update_schema: Optional[Type] = None + + def _validate_create(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Validate data for create operation.""" + if self.create_schema: + validated = self.create_schema(**data) + if hasattr(validated, "model_dump"): + return validated.model_dump(exclude_unset=True) + elif hasattr(validated, "dict"): + return validated.dict(exclude_unset=True) + return data + + def _validate_update(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Validate data for update operation.""" + if self.update_schema: + validated = self.update_schema(**data) + if hasattr(validated, "model_dump"): + return validated.model_dump(exclude_unset=True, exclude_none=True) + elif hasattr(validated, "dict"): + return validated.dict(exclude_unset=True, exclude_none=True) + return data + + async def create(self, data: Dict[str, Any]) -> T: + """Create with validation.""" + validated_data = self._validate_create(data) + return await self.repository.create(validated_data) + + async def update(self, id: ID, data: Dict[str, Any]) -> T: + """Update with validation.""" + await self.get(id) # Verify exists + validated_data = self._validate_update(data) + result = await self.repository.update(id, validated_data) + if result is None: + raise NotFoundError(f"{self.entity_name} with id {id} not found") + return result + + +# ============================================================================= +# Service Decorators +# ============================================================================= + + +def validate_input(schema: Type): + """ + Decorator to validate input data against a Pydantic schema. + + Example: + class UserService(CRUDService): + @validate_input(UserCreate) + async def create(self, data: dict) -> User: + return await self.repository.create(data) + """ + + def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: + @wraps(func) + async def wrapper(self, data: Dict[str, Any], *args, **kwargs): + validated = schema(**data) + if hasattr(validated, "model_dump"): + validated_data = validated.model_dump(exclude_unset=True) + elif hasattr(validated, "dict"): + validated_data = validated.dict(exclude_unset=True) + else: + validated_data = data + return await func(self, validated_data, *args, **kwargs) + + return wrapper + + return decorator + + +def require_exists(entity_name: str = "Entity"): + """ + Decorator to ensure an entity exists before operation. + + Example: + class UserService(CRUDService): + @require_exists("User") + async def update(self, id: int, data: dict) -> User: + return await self.repository.update(id, data) + """ + + def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: + @wraps(func) + async def wrapper(self, id: Any, *args, **kwargs): + exists = await self.repository.exists(id) + if not exists: + raise NotFoundError(f"{entity_name} with id {id} not found") + return await func(self, id, *args, **kwargs) + + return wrapper + + return decorator + + +def log_operation(operation_name: str = None): + """ + Decorator to log service operations. + + Example: + class UserService(CRUDService): + @log_operation("create_user") + async def create(self, data: dict) -> User: + return await self.repository.create(data) + """ + + def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: + @wraps(func) + async def wrapper(self, *args, **kwargs): + op_name = operation_name or func.__name__ + logger.info(f"Starting operation: {op_name}") + try: + result = await func(self, *args, **kwargs) + logger.info(f"Completed operation: {op_name}") + return result + except Exception as e: + logger.error(f"Failed operation: {op_name} - {e}") + raise + + return wrapper + + return decorator + + +# ============================================================================= +# Service Factory +# ============================================================================= + + +class ServiceFactory: + """ + Factory for creating and managing service instances. + + Useful for dependency injection and testing. + + Example: + factory = ServiceFactory() + factory.register("users", UserService, user_repository) + + # In route handler + user_service = factory.get("users") + """ + + _services: Dict[str, Any] = {} + _factories: Dict[str, tuple] = {} + + @classmethod + def register(cls, name: str, service_class: Type[BaseService], *args, **kwargs) -> None: + """ + Register a service factory. + + Args: + name: Service name for lookup + service_class: The service class + *args, **kwargs: Arguments to pass to service constructor + """ + cls._factories[name] = (service_class, args, kwargs) + logger.debug(f"Registered service factory: {name}") + + @classmethod + def get(cls, name: str) -> BaseService: + """ + Get or create a service instance. + + Creates a new instance on first access, then returns cached instance. + """ + if name not in cls._services: + if name not in cls._factories: + raise ValueError(f"Service '{name}' not registered") + service_class, args, kwargs = cls._factories[name] + cls._services[name] = service_class(*args, **kwargs) + return cls._services[name] + + @classmethod + def create(cls, name: str) -> BaseService: + """Create a new service instance (not cached).""" + if name not in cls._factories: + raise ValueError(f"Service '{name}' not registered") + service_class, args, kwargs = cls._factories[name] + return service_class(*args, **kwargs) + + @classmethod + def list_services(cls) -> List[str]: + """List all registered service names.""" + return list(cls._factories.keys()) + + @classmethod + def clear(cls) -> None: + """Clear all registered services and instances.""" + cls._services.clear() + cls._factories.clear() + + +# ============================================================================= +# Dependency Injection Helper +# ============================================================================= + + +def create_service_dependency( + service_class: Type[BaseService], repository_class: Type[BaseRepository], **service_kwargs +) -> Callable: + """ + Create a FastAPI dependency for a service. + + Example: + from fastapi import Depends + + get_user_service = create_service_dependency(UserService, UserRepository) + + async def get_users(service: UserService = Depends(get_user_service)): + return await service.get_all() + """ + # Cache the repository instance + _repository = None + _service = None + + def get_service() -> BaseService: + nonlocal _repository, _service + if _repository is None: + _repository = repository_class() + if _service is None: + _service = service_class(_repository, **service_kwargs) + return _service + + return get_service + + +# ============================================================================= +# Utility Functions +# ============================================================================= + + +def create_crud_service( + repository: BaseRepository[T, ID], entity_name: str = "Entity" +) -> CRUDService[T, ID]: + """ + Quick factory to create a CRUD service. + + Example: + user_repo = UserRepository() + user_service = create_crud_service(user_repo, "User") + """ + return CRUDService(repository, entity_name) diff --git a/tests/final_test.py b/tests/final_test.py index f91991b..f9197d3 100644 --- a/tests/final_test.py +++ b/tests/final_test.py @@ -2,11 +2,12 @@ Final Comprehensive Test Suite for RunApi Framework Tests all major functionality and ensures everything works correctly """ + import os import sys import tempfile -import shutil from pathlib import Path + from fastapi.testclient import TestClient @@ -15,8 +16,7 @@ def test_framework_installation(): print("๐Ÿงช Testing RunApi installation...") try: import runapi - from runapi import create_runapi_app, RunApiConfig, get_config - from runapi import JSONResponse, ValidationError, create_access_token + print("โœ… RunApi framework imports successfully!") print(f" Version: {getattr(runapi, '__version__', 'unknown')}") return True @@ -30,9 +30,9 @@ def test_cli_functionality(): print("๐Ÿงช Testing CLI functionality...") try: import subprocess - result = subprocess.run(['runapi', '--help'], - capture_output=True, text=True, timeout=10) - if result.returncode == 0 and 'RunApi' in result.stdout: + + result = subprocess.run(["runapi", "--help"], capture_output=True, text=True, timeout=10) + if result.returncode == 0 and "RunApi" in result.stdout: print("โœ… CLI is working correctly!") return True else: @@ -48,25 +48,21 @@ def test_basic_app_creation(): print("๐Ÿงช Testing basic app creation...") try: from runapi import create_runapi_app, get_config - + # Test app creation - app = create_runapi_app( - title="Test API", - description="Test Description", - version="1.0.0" - ) - + app = create_runapi_app(title="Test API", description="Test Description", version="1.0.0") + fastapi_app = app.get_app() assert fastapi_app.title == "Test API" assert fastapi_app.description == "Test Description" assert fastapi_app.version == "1.0.0" - + # Test configuration config = get_config() - assert hasattr(config, 'debug') - assert hasattr(config, 'host') - assert hasattr(config, 'port') - + assert hasattr(config, "debug") + assert hasattr(config, "host") + assert hasattr(config, "port") + print("โœ… Basic app creation and configuration works!") return True except Exception as e: @@ -77,29 +73,29 @@ def test_basic_app_creation(): def test_file_based_routing(): """Test file-based routing with actual route files""" print("๐Ÿงช Testing file-based routing...") - + with tempfile.TemporaryDirectory() as temp_dir: try: temp_path = Path(temp_dir) routes_path = temp_path / "routes" routes_path.mkdir() (routes_path / "__init__.py").touch() - + # Create index route - index_content = ''' + index_content = """ from runapi import JSONResponse async def get(): return JSONResponse({"message": "Hello from index", "route": "index"}) -''' - (routes_path / "index.py").write_text(index_content, encoding='utf-8') - +""" + (routes_path / "index.py").write_text(index_content, encoding="utf-8") + # Create API directory and route api_path = routes_path / "api" api_path.mkdir() (api_path / "__init__.py").touch() - - api_content = ''' + + api_content = """ from runapi import JSONResponse, Request async def get(): @@ -107,65 +103,67 @@ async def get(): async def post(request: Request): return JSONResponse({"message": "API endpoint", "method": "POST"}) -''' - (api_path / "test.py").write_text(api_content, encoding='utf-8') - +""" + (api_path / "test.py").write_text(api_content, encoding="utf-8") + # Create dynamic route users_path = routes_path / "users" users_path.mkdir() (users_path / "__init__.py").touch() - - dynamic_content = ''' + + dynamic_content = """ from runapi import JSONResponse, Request async def get(request: Request): user_id = request.path_params.get("id", "unknown") return JSONResponse({"user_id": user_id, "method": "GET"}) -''' - (users_path / "[id].py").write_text(dynamic_content, encoding='utf-8') - +""" + (users_path / "[id].py").write_text(dynamic_content, encoding="utf-8") + # Change to temp directory and test old_cwd = os.getcwd() try: os.chdir(temp_dir) - + from runapi import create_runapi_app + app = create_runapi_app() - + with TestClient(app.get_app()) as client: # Test index route response = client.get("/") assert response.status_code == 200 data = response.json() assert data["route"] == "index" - + # Test API route response = client.get("/api/test") assert response.status_code == 200 data = response.json() assert data["method"] == "GET" - + # Test POST to API route response = client.post("/api/test") assert response.status_code == 200 data = response.json() assert data["method"] == "POST" - + # Test dynamic route response = client.get("/users/123") assert response.status_code == 200 data = response.json() assert data["user_id"] == "123" - + print("โœ… File-based routing works correctly!") return True - + finally: os.chdir(old_cwd) - + except Exception as e: print(f"โŒ File-based routing test failed: {e}") import traceback + print(traceback.format_exc()) return False @@ -175,20 +173,20 @@ def test_middleware_and_security(): print("๐Ÿงช Testing middleware and security...") try: from runapi import create_runapi_app - + app = create_runapi_app() - + with TestClient(app.get_app()) as client: # Test that security headers are added response = client.get("/docs") assert response.status_code == 200 - + # Check for security headers headers = response.headers assert "X-Content-Type-Options" in headers assert headers["X-Content-Type-Options"] == "nosniff" assert "X-Frame-Options" in headers - + print("โœ… Middleware and security features work!") return True except Exception as e: @@ -200,8 +198,8 @@ def test_error_handling(): """Test error handling system""" print("๐Ÿงช Testing error handling...") try: - from runapi import ValidationError, NotFoundError, create_error_response - + from runapi import ValidationError, create_error_response + # Test custom exceptions try: raise ValidationError("Test validation error", {"field": "test"}) @@ -209,11 +207,11 @@ def test_error_handling(): assert e.status_code == 400 assert e.error_code == "VALIDATION_ERROR" assert e.details["field"] == "test" - + # Test error response creation response = create_error_response("Test error", 404, "TEST_ERROR") assert response.status_code == 404 - + print("โœ… Error handling system works!") return True except Exception as e: @@ -224,36 +222,37 @@ def test_error_handling(): def test_generated_project(): """Test that generated projects work correctly""" print("๐Ÿงช Testing generated project functionality...") - + # Test the current test-project if it exists if os.path.exists("test-project") and os.path.exists("test-project/main.py"): try: import sys + sys.path.insert(0, "test-project") - + # Import the generated project's app from main import app - + with TestClient(app) as client: # Test index route response = client.get("/") assert response.status_code == 200 data = response.json() assert "message" in data - + # Test API routes if they exist response = client.get("/api/hello") if response.status_code == 200: data = response.json() assert "message" in data - + # Test docs response = client.get("/docs") assert response.status_code == 200 - + print("โœ… Generated project works correctly!") return True - + except Exception as e: print(f"โŒ Generated project test failed: {e}") return False @@ -270,31 +269,30 @@ def test_documentation_generation(): print("๐Ÿงช Testing API documentation generation...") try: from runapi import create_runapi_app - + app = create_runapi_app( - title="Documentation Test API", - description="Testing automatic docs generation" + title="Documentation Test API", description="Testing automatic docs generation" ) - + with TestClient(app.get_app()) as client: # Test OpenAPI JSON response = client.get("/openapi.json") assert response.status_code == 200 - + openapi_data = response.json() assert "openapi" in openapi_data assert "info" in openapi_data assert openapi_data["info"]["title"] == "Documentation Test API" - + # Test Swagger UI response = client.get("/docs") assert response.status_code == 200 assert "text/html" in response.headers["content-type"] - + # Test ReDoc response = client.get("/redoc") assert response.status_code == 200 - + print("โœ… API documentation generation works!") return True except Exception as e: @@ -307,7 +305,7 @@ def run_comprehensive_tests(): print("๐Ÿš€ RunApi Framework - Final Comprehensive Test Suite") print("=" * 60) print() - + tests = [ ("Framework Installation", test_framework_installation), ("CLI Functionality", test_cli_functionality), @@ -318,15 +316,15 @@ def run_comprehensive_tests(): ("Generated Project", test_generated_project), ("Documentation Generation", test_documentation_generation), ] - + passed = 0 failed = 0 results = [] - + for test_name, test_func in tests: print(f"Running: {test_name}") print("-" * 40) - + try: if test_func(): passed += 1 @@ -337,24 +335,24 @@ def run_comprehensive_tests(): except Exception as e: failed += 1 results.append((test_name, f"โŒ ERROR: {e}")) - + print() - + # Print final summary print("=" * 60) print("๐Ÿ FINAL TEST RESULTS") print("=" * 60) - + for test_name, result in results: print(f"{result:<10} {test_name}") - + print() - print(f"๐Ÿ“Š Summary:") + print("๐Ÿ“Š Summary:") print(f" โœ… Passed: {passed}/{len(tests)}") print(f" โŒ Failed: {failed}/{len(tests)}") - print(f" ๐Ÿ“ˆ Success Rate: {(passed/len(tests)*100):.1f}%") + print(f" ๐Ÿ“ˆ Success Rate: {(passed / len(tests) * 100):.1f}%") print() - + if failed == 0: print("๐ŸŽ‰ ALL TESTS PASSED! RunApi framework is working perfectly!") print("๐Ÿš€ The framework is ready for production use!") @@ -362,7 +360,7 @@ def run_comprehensive_tests(): print("โœจ Features successfully tested:") print(" โ€ข File-based routing with dynamic routes") print(" โ€ข Middleware system with security features") - print(" โ€ข Configuration management") + print(" โ€ข Configuration management") print(" โ€ข Error handling and custom exceptions") print(" โ€ข CLI tools for project management") print(" โ€ข Automatic API documentation") @@ -376,4 +374,4 @@ def run_comprehensive_tests(): if __name__ == "__main__": success = run_comprehensive_tests() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/tests/simple_test.py b/tests/simple_test.py index dce8a5f..6e15cdb 100644 --- a/tests/simple_test.py +++ b/tests/simple_test.py @@ -1,17 +1,14 @@ """ Simple test script for RunApi framework basic functionality """ -import os + import sys -from pathlib import Path def test_imports(): """Test basic imports""" print("๐Ÿงช Testing imports...") try: - from runapi import create_runapi_app, RunApiConfig, get_config - from runapi import JSONResponse, ValidationError, create_access_token print("โœ… All imports successful!") return True except Exception as e: @@ -24,12 +21,12 @@ def test_config(): print("๐Ÿงช Testing configuration...") try: from runapi import RunApiConfig - + # Test basic config creation config = RunApiConfig() - assert hasattr(config, 'debug') - assert hasattr(config, 'host') - assert hasattr(config, 'port') + assert hasattr(config, "debug") + assert hasattr(config, "host") + assert hasattr(config, "port") print("โœ… Configuration test passed!") return True except Exception as e: @@ -42,10 +39,10 @@ def test_app_creation(): print("๐Ÿงช Testing app creation...") try: from runapi import create_runapi_app - + app = create_runapi_app(title="Test API") fastapi_app = app.get_app() - + assert fastapi_app.title == "Test API" print("โœ… App creation test passed!") return True @@ -64,6 +61,7 @@ def test_jwt_auth(): return True except Exception as e: import traceback + print(f"โŒ JWT authentication test failed: {e}") print(f"Traceback: {traceback.format_exc()}") return False @@ -74,18 +72,18 @@ def test_error_handling(): print("๐Ÿงช Testing error handling...") try: from runapi import ValidationError, create_error_response - + # Test custom exception try: raise ValidationError("Test error") except ValidationError as e: assert e.status_code == 400 assert e.error_code == "VALIDATION_ERROR" - + # Test error response response = create_error_response("Test", 404, "TEST_ERROR") assert response.status_code == 404 - + print("โœ… Error handling test passed!") return True except Exception as e: @@ -97,30 +95,31 @@ def test_basic_routing(): """Test basic routing with TestClient""" print("๐Ÿงช Testing basic routing...") try: - from runapi import create_runapi_app from fastapi import APIRouter from fastapi.testclient import TestClient - + + from runapi import create_runapi_app + # Create app runapi_app = create_runapi_app(title="Test API") app = runapi_app.get_app() - + # Add a simple test route router = APIRouter() - + @router.get("/test") async def test_endpoint(): return {"message": "test successful"} - + app.include_router(router) - + # Test with client client = TestClient(app) response = client.get("/test") - + assert response.status_code == 200 assert response.json()["message"] == "test successful" - + print("โœ… Basic routing test passed!") return True except Exception as e: @@ -133,10 +132,10 @@ def test_cli_functionality(): print("๐Ÿงช Testing CLI functionality...") try: from runapi.cli import app as cli_app - + # Test that CLI app is created assert cli_app is not None - + print("โœ… CLI functionality test passed!") return True except Exception as e: @@ -147,20 +146,20 @@ def test_cli_functionality(): def run_all_tests(): """Run all simple tests""" print("๐Ÿš€ Running RunApi Simple Tests\n") - + tests = [ test_imports, - test_config, + test_config, test_app_creation, test_jwt_auth, test_error_handling, test_basic_routing, test_cli_functionality, ] - + passed = 0 failed = 0 - + for test in tests: try: if test(): @@ -171,11 +170,11 @@ def run_all_tests(): print(f"โŒ Test {test.__name__} crashed: {e}") failed += 1 print() # Add space between tests - - print(f"๐Ÿ“Š Results:") + + print("๐Ÿ“Š Results:") print(f"โœ… Passed: {passed}/{len(tests)}") print(f"โŒ Failed: {failed}/{len(tests)}") - + if failed == 0: print("๐ŸŽ‰ All tests passed! RunApi framework basic functionality is working!") return True @@ -186,4 +185,4 @@ def run_all_tests(): if __name__ == "__main__": success = run_all_tests() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/tests/test_runapi.py b/tests/test_runapi.py index c0cc980..1294eca 100644 --- a/tests/test_runapi.py +++ b/tests/test_runapi.py @@ -6,11 +6,8 @@ import asyncio import os import tempfile -import shutil from pathlib import Path -import json -import httpx -import pytest + from fastapi.testclient import TestClient @@ -20,9 +17,7 @@ def test_basic_app_creation(): from runapi import create_runapi_app - app = create_runapi_app( - title="Test API", description="Test RunApi API", version="1.0.0" - ) + app = create_runapi_app(title="Test API", description="Test RunApi API", version="1.0.0") fastapi_app = app.get_app() @@ -47,7 +42,7 @@ def test_configuration_system(): config = RunApiConfig() - assert config.debug == True + assert config.debug assert config.host == "0.0.0.0" assert config.port == 9000 assert config.secret_key == "test-secret-key" @@ -59,7 +54,7 @@ def test_error_handling(): """Test error handling system""" print("๐Ÿงช Testing error handling...") - from runapi import ValidationError, NotFoundError, create_error_response + from runapi import ValidationError, create_error_response # Test custom exceptions try: @@ -83,17 +78,23 @@ def test_authentication_system(): """Test JWT authentication system""" print("๐Ÿงช Testing authentication system...") - from runapi import create_access_token, verify_token + # Set a proper secret key for testing BEFORE importing + os.environ["SECRET_KEY"] = "test-secret-key-at-least-32-characters-long" + + # Use JWTManager directly with a custom secret key to avoid config caching issues + from runapi.auth import JWTManager + + jwt_manager = JWTManager(secret_key="test-secret-key-at-least-32-characters-long") # Test token creation and verification user_data = {"sub": "user123", "username": "testuser", "roles": ["user"]} - token = create_access_token(user_data) + token = jwt_manager.create_access_token(user_data) assert isinstance(token, str) assert len(token.split(".")) == 3 # JWT has 3 parts # Test token verification - payload = verify_token(token) + payload = jwt_manager.verify_token(token) assert payload is not None assert payload["sub"] == "user123" assert payload["username"] == "testuser" @@ -179,7 +180,7 @@ def test_middleware_system(): """Test middleware system""" print("๐Ÿงช Testing middleware system...") - from runapi import create_runapi_app, RunApiMiddleware + from runapi import RunApiMiddleware, create_runapi_app # Custom test middleware class TestMiddleware(RunApiMiddleware): @@ -294,11 +295,7 @@ def test_cors_configuration(): with TestClient(app.get_app()) as client: # Test preflight request response = client.options( - "/", - headers={ - "Origin": "http://localhost:3000", - "Access-Control-Request-Method": "GET", - }, + "/", headers={"Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET"} ) # Should allow the request @@ -339,154 +336,743 @@ def test_static_file_serving(): print("โœ… Static file serving test passed!") -def run_all_tests(): - """Run all tests""" - print("๐Ÿš€ Starting RunApi Framework Tests\n") +def test_schema_system(): + """Test schema base classes and utilities""" + print("๐Ÿงช Testing schema system...") - tests = [ - test_basic_app_creation, - test_configuration_system, - test_error_handling, - test_authentication_system, - test_file_based_routing, - test_middleware_system, - test_dynamic_routes, - test_cors_configuration, - test_static_file_serving, - test_router_discovery, - test_nested_routing_behavior, + from datetime import datetime + from typing import Optional + + from runapi import ( + BaseSchema, + IDMixin, + MessageResponse, + PaginatedResponse, + PaginationParams, + TimestampMixin, + ) + + # Test BaseSchema + class UserResponse(BaseSchema, IDMixin, TimestampMixin): + email: str + name: Optional[str] = None + + user = UserResponse(id=1, email="test@example.com", name="Test User", created_at=datetime.now()) + + assert user.id == 1 + assert user.email == "test@example.com" + assert user.name == "Test User" + assert user.created_at is not None + + # Test MessageResponse + msg = MessageResponse(message="Operation successful") + assert msg.message == "Operation successful" + assert msg.success + + # Test PaginationParams + params = PaginationParams(page=2, page_size=10) + assert params.offset == 10 # (2-1) * 10 + assert params.limit == 10 + + # Test PaginatedResponse + items = [user] + paginated = PaginatedResponse.create(items=items, total=100, page=1, page_size=10) + assert paginated.total == 100 + assert paginated.pages == 10 + assert len(paginated.items) == 1 + + print("โœ… Schema system test passed!") + + +def test_schema_auto_discovery(): + """Test schema auto-discovery from schemas/ folder""" + print("๐Ÿงช Testing schema auto-discovery...") + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + schemas_path = temp_path / "schemas" + schemas_path.mkdir() + (schemas_path / "__init__.py").touch() + + # Create a test schema file + user_schema = """ +from pydantic import BaseModel, Field +from typing import Optional + +class UserCreate(BaseModel): + email: str + name: str + +class UserResponse(BaseModel): + id: int + email: str + name: Optional[str] = None +""" + + (schemas_path / "user.py").write_text(user_schema, encoding="utf-8") + + # Create nested schema + api_schemas = schemas_path / "api" + api_schemas.mkdir() + (api_schemas / "__init__.py").touch() + + product_schema = """ +from pydantic import BaseModel + +class ProductCreate(BaseModel): + name: str + price: float + +class ProductResponse(BaseModel): + id: int + name: str + price: float +""" + + (api_schemas / "product.py").write_text(product_schema, encoding="utf-8") + + old_cwd = os.getcwd() + try: + os.chdir(temp_dir) + + from runapi.schemas import SchemaRegistry, get_schema, list_schemas, load_schemas + + # Clear registry before test + SchemaRegistry.clear() + + # Load schemas + loaded = load_schemas(schemas_path) + + assert len(loaded) >= 2, f"Expected at least 2 modules, got {len(loaded)}" + + # Check registry + user_create = get_schema("UserCreate") + assert user_create is not None, "UserCreate schema not found" + + user_response = get_schema("UserResponse") + assert user_response is not None, "UserResponse schema not found" + + product_create = get_schema("ProductCreate") + assert product_create is not None, "ProductCreate schema not found" + + # Test schema functionality + user = user_create(email="test@example.com", name="Test User") + assert user.email == "test@example.com" + + # Test list_schemas + schema_names = list_schemas() + assert "UserCreate" in schema_names + assert "ProductResponse" in schema_names + + finally: + os.chdir(old_cwd) + SchemaRegistry.clear() + + print("โœ… Schema auto-discovery test passed!") + + +def test_schema_integration_with_routes(): + """Test using schemas in route handlers""" + print("๐Ÿงช Testing schema integration with routes...") + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create schemas directory + schemas_path = temp_path / "schemas" + schemas_path.mkdir() + (schemas_path / "__init__.py").touch() + + user_schema = """ +from pydantic import BaseModel, Field +from typing import Optional + +class UserCreate(BaseModel): + email: str + name: str + +class UserResponse(BaseModel): + id: int + email: str + name: str +""" + (schemas_path / "user.py").write_text(user_schema, encoding="utf-8") + + # Create routes directory + routes_path = temp_path / "routes" + routes_path.mkdir() + (routes_path / "__init__.py").touch() + + api_path = routes_path / "api" + api_path.mkdir() + (api_path / "__init__.py").touch() + + # Create route that uses schemas + users_route = ''' +from runapi import JSONResponse, Request +from schemas.user import UserCreate, UserResponse + +async def get(): + """Get list of users.""" + users = [ + {"id": 1, "email": "user1@example.com", "name": "User One"}, + {"id": 2, "email": "user2@example.com", "name": "User Two"}, ] + return JSONResponse([UserResponse(**u).model_dump() for u in users]) - passed = 0 - failed = 0 +async def post(request: Request): + """Create a new user.""" + body = await request.json() + user_data = UserCreate(**body) + # Simulate user creation + new_user = UserResponse(id=123, email=user_data.email, name=user_data.name) + return JSONResponse(new_user.model_dump(), status_code=201) +''' + (api_path / "users.py").write_text(users_route, encoding="utf-8") - for test in tests: + old_cwd = os.getcwd() try: - test() - passed += 1 - except Exception as e: - print(f"โŒ Test {test.__name__} failed: {e}") - failed += 1 + os.chdir(temp_dir) - print(f"\n๐Ÿ“Š Test Results:") - print(f"โœ… Passed: {passed}") - print(f"โŒ Failed: {failed}") - print(f"๐Ÿ“ˆ Success Rate: {passed/(passed+failed)*100:.1f}%") + # Add temp_dir to path so schemas can be imported + import sys - if failed == 0: - print("\n๐ŸŽ‰ All tests passed! RunApi framework is working correctly.") - else: - print(f"\nโš ๏ธ {failed} test(s) failed. Please check the output above.") + sys.path.insert(0, temp_dir) - return failed == 0 + from runapi import create_runapi_app + + app = create_runapi_app() + fastapi_app = app.get_app() + + with TestClient(fastapi_app) as client: + # Test GET users + response = client.get("/api/users") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["email"] == "user1@example.com" + + # Test POST user + new_user_data = {"email": "new@example.com", "name": "New User"} + response = client.post("/api/users", json=new_user_data) + assert response.status_code == 201 + data = response.json() + assert data["id"] == 123 + assert data["email"] == "new@example.com" + assert data["name"] == "New User" + + finally: + os.chdir(old_cwd) + if temp_dir in sys.path: + sys.path.remove(temp_dir) + + print("โœ… Schema integration with routes test passed!") + + +def test_repository_in_memory(): + """Test InMemoryRepository basic operations""" + print("๐Ÿงช Testing InMemoryRepository...") + + from runapi import InMemoryRepository + + async def run_tests(): + repo = InMemoryRepository() + + # Test create + user1 = await repo.create({"name": "John", "email": "john@example.com"}) + assert user1["id"] == 1 + assert user1["name"] == "John" + assert "created_at" in user1 + + user2 = await repo.create({"name": "Jane", "email": "jane@example.com"}) + assert user2["id"] == 2 + + # Test get + fetched = await repo.get(1) + assert fetched["name"] == "John" + + # Test get_all + all_users = await repo.get_all() + assert len(all_users) == 2 + + # Test get_all with filters + johns = await repo.get_all(name="John") + assert len(johns) == 1 + assert johns[0]["name"] == "John" + + # Test update + updated = await repo.update(1, {"name": "Johnny"}) + assert updated["name"] == "Johnny" + assert updated["email"] == "john@example.com" + + # Test count + count = await repo.count() + assert count == 2 + + # Test exists + assert await repo.exists(1) + assert not await repo.exists(999) + + # Test delete + deleted = await repo.delete(1) + assert deleted + + remaining = await repo.get_all() + assert len(remaining) == 1 + + # Test get_by + found = await repo.get_by(email="jane@example.com") + assert found["name"] == "Jane" + + # Clear for next tests + repo.clear() + assert await repo.count() == 0 + + asyncio.run(run_tests()) + print("โœ… InMemoryRepository test passed!") + + +def test_typed_repository(): + """Test TypedInMemoryRepository with Pydantic models""" + print("๐Ÿงช Testing TypedInMemoryRepository...") + + from datetime import datetime + from typing import Optional + + from pydantic import BaseModel + + from runapi import TypedInMemoryRepository + + class User(BaseModel): + id: Optional[int] = None + name: str + email: str + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + async def run_tests(): + repo = TypedInMemoryRepository(User) + # Test create - returns User model instance + user = await repo.create({"name": "Alice", "email": "alice@example.com"}) + assert isinstance(user, User) + assert user.id == 1 + assert user.name == "Alice" -def test_router_discovery(): - """Test router discovery and route detection""" - print("๐Ÿงช Testing router discovery...") + # Test get - returns User model instance + fetched = await repo.get(1) + assert isinstance(fetched, User) + assert fetched.email == "alice@example.com" + + # Test update - returns User model instance + updated = await repo.update(1, {"name": "Alicia"}) + assert isinstance(updated, User) + assert updated.name == "Alicia" + + # Test get_all - returns list of User instances + all_users = await repo.get_all() + assert all(isinstance(u, User) for u in all_users) + + repo.clear() + + asyncio.run(run_tests()) + print("โœ… TypedInMemoryRepository test passed!") + + +def test_repository_factory(): + """Test RepositoryFactory registration and creation""" + print("๐Ÿงช Testing RepositoryFactory...") + + from runapi import InMemoryRepository, RepositoryFactory + + # Clear any existing registrations + RepositoryFactory.clear() + + # Test register + RepositoryFactory.register("users", InMemoryRepository) + + # Test get + repo_class = RepositoryFactory.get("users") + assert repo_class == InMemoryRepository + + # Test list + repos = RepositoryFactory.list_repositories() + assert "users" in repos + + # Test create + repo = RepositoryFactory.create("users") + assert isinstance(repo, InMemoryRepository) + + # Clean up + RepositoryFactory.clear() + + print("โœ… RepositoryFactory test passed!") + + +def test_repository_with_routes(): + """Test using repositories in route handlers""" + print("๐Ÿงช Testing repository integration with routes...") with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) + + # Create routes directory routes_path = temp_path / "routes" routes_path.mkdir() + (routes_path / "__init__.py").touch() - # Create standard routes - # routes/index.py -> / - (routes_path / "index.py").write_text( - "from runapi import JSONResponse\n" - 'async def get(): return JSONResponse({"route": "index"})', - encoding="utf-8", - ) + api_path = routes_path / "api" + api_path.mkdir() + (api_path / "__init__.py").touch() - # routes/about.py -> /about - (routes_path / "about.py").write_text( - "from runapi import JSONResponse\n" - 'async def get(): return JSONResponse({"route": "about"})', - encoding="utf-8", - ) + # Create route that uses repository + items_route = ''' +from runapi import JSONResponse, Request, InMemoryRepository +from datetime import datetime + +# Create repository instance +items_repo = InMemoryRepository() + +def serialize_item(item): + """Convert datetime objects to ISO format strings.""" + result = {} + for k, v in item.items(): + if isinstance(v, datetime): + result[k] = v.isoformat() + else: + result[k] = v + return result + +async def get(): + """Get all items.""" + items = await items_repo.get_all() + return JSONResponse([serialize_item(i) for i in items]) + +async def post(request: Request): + """Create a new item.""" + body = await request.json() + item = await items_repo.create(body) + return JSONResponse(serialize_item(item), status_code=201) +''' + (api_path / "items.py").write_text(items_route, encoding="utf-8") old_cwd = os.getcwd() try: os.chdir(temp_dir) + from runapi import create_runapi_app app = create_runapi_app() fastapi_app = app.get_app() with TestClient(fastapi_app) as client: - # Test index route detection - resp = client.get("/") - assert resp.status_code == 200 - assert resp.json()["route"] == "index" + # Initially empty + response = client.get("/api/items") + assert response.status_code == 200 + assert response.json() == [] + + # Create item + response = client.post("/api/items", json={"name": "Test Item", "price": 9.99}) + assert response.status_code == 201 + data = response.json() + assert data["id"] == 1 + assert data["name"] == "Test Item" - # Test specific route detection - resp = client.get("/about") - assert resp.status_code == 200 - assert resp.json()["route"] == "about" + # Now should have one item + response = client.get("/api/items") + assert response.status_code == 200 + items = response.json() + assert len(items) == 1 finally: os.chdir(old_cwd) - print("โœ… Router discovery test passed!") + print("โœ… Repository integration with routes test passed!") + + +def test_crud_service(): + """Test CRUDService basic operations""" + print("๐Ÿงช Testing CRUDService...") + + from runapi import CRUDService, InMemoryRepository, NotFoundError + + async def run_tests(): + repo = InMemoryRepository() + service = CRUDService(repo, "User") + + # Test create + user = await service.create({"name": "John", "email": "john@example.com"}) + assert user["id"] == 1 + assert user["name"] == "John" + + # Test get + fetched = await service.get(1) + assert fetched["name"] == "John" + + # Test get - not found + try: + await service.get(999) + raise AssertionError("Should have raised NotFoundError") + except NotFoundError as e: + assert "999" in str(e) + + # Test get_or_none + result = await service.get_or_none(999) + assert result is None + + # Test get_all + await service.create({"name": "Jane", "email": "jane@example.com"}) + all_users = await service.get_all() + assert len(all_users) == 2 + + # Test update + updated = await service.update(1, {"name": "Johnny"}) + assert updated["name"] == "Johnny" + + # Test update - not found + try: + await service.update(999, {"name": "Nobody"}) + raise AssertionError("Should have raised NotFoundError") + except NotFoundError: + pass + + # Test delete + deleted = await service.delete(1) + assert deleted + + # Test delete - not found + try: + await service.delete(999) + raise AssertionError("Should have raised NotFoundError") + except NotFoundError: + pass + + # Test exists + assert await service.exists(2) + assert not await service.exists(999) + + # Test count + count = await service.count() + assert count == 1 + + repo.clear() + + asyncio.run(run_tests()) + print("โœ… CRUDService test passed!") + + +def test_validated_service(): + """Test ValidatedService with schema validation""" + print("๐Ÿงช Testing ValidatedService...") + + from typing import Optional + + from pydantic import BaseModel, Field + + from runapi import InMemoryRepository, ValidatedService + + class UserCreate(BaseModel): + name: str = Field(..., min_length=1) + email: str + + class UserUpdate(BaseModel): + name: Optional[str] = None + email: Optional[str] = None + + class UserService(ValidatedService): + create_schema = UserCreate + update_schema = UserUpdate + + async def run_tests(): + repo = InMemoryRepository() + service = UserService(repo, "User") + + # Test create with validation + user = await service.create({"name": "Alice", "email": "alice@example.com"}) + assert user["name"] == "Alice" + + # Test create with invalid data + try: + await service.create({"name": "", "email": "test@example.com"}) + raise AssertionError("Should have raised validation error") + except Exception: + pass # Pydantic validation error expected + + # Test update with validation + updated = await service.update(user["id"], {"name": "Alicia"}) + assert updated["name"] == "Alicia" + repo.clear() -def test_nested_routing_behavior(): - """Test nested routing behavior""" - print("๐Ÿงช Testing nested routing behavior...") + asyncio.run(run_tests()) + print("โœ… ValidatedService test passed!") + + +def test_service_factory(): + """Test ServiceFactory registration and creation""" + print("๐Ÿงช Testing ServiceFactory...") + + from runapi import CRUDService, InMemoryRepository, ServiceFactory + + # Clear any existing registrations + ServiceFactory.clear() + + # Create repository + repo = InMemoryRepository() + + # Test register + ServiceFactory.register("users", CRUDService, repo, "User") + + # Test list + services = ServiceFactory.list_services() + assert "users" in services + + # Test get (creates and caches) + service1 = ServiceFactory.get("users") + service2 = ServiceFactory.get("users") + assert service1 is service2 # Same instance + + # Test create (new instance) + service3 = ServiceFactory.create("users") + assert service3 is not service1 # Different instance + + # Clean up + ServiceFactory.clear() + + print("โœ… ServiceFactory test passed!") + + +def test_service_with_routes(): + """Test using services in route handlers""" + print("๐Ÿงช Testing service integration with routes...") with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) + + # Create routes directory routes_path = temp_path / "routes" routes_path.mkdir() + (routes_path / "__init__.py").touch() - # Create nested structure: routes/api/v1/users.py - api_path = routes_path / "api" / "v1" - api_path.mkdir(parents=True) - (routes_path / "api" / "__init__.py").touch() - (routes_path / "api" / "v1" / "__init__.py").touch() + api_path = routes_path / "api" + api_path.mkdir() + (api_path / "__init__.py").touch() - (api_path / "users.py").write_text( - "from runapi import JSONResponse\n" - 'async def get(): return JSONResponse({"path": "users"})', - encoding="utf-8", - ) + # Create route that uses service + products_route = ''' +from runapi import JSONResponse, Request, InMemoryRepository, CRUDService, NotFoundError +from datetime import datetime + +# Setup service layer +product_repo = InMemoryRepository() +product_service = CRUDService(product_repo, "Product") + +def serialize(item): + """Convert datetime objects to ISO format strings.""" + result = {} + for k, v in item.items(): + if isinstance(v, datetime): + result[k] = v.isoformat() + else: + result[k] = v + return result - # routes/blog/posts.py - blog_path = routes_path / "blog" - blog_path.mkdir() - (blog_path / "__init__.py").touch() +async def get(): + """Get all products.""" + products = await product_service.get_all() + return JSONResponse([serialize(p) for p in products]) - (blog_path / "posts.py").write_text( - "from runapi import JSONResponse\n" - 'async def get(): return JSONResponse({"section": "blog", "type": "posts"})', - encoding="utf-8", - ) +async def post(request: Request): + """Create a new product.""" + body = await request.json() + product = await product_service.create(body) + return JSONResponse(serialize(product), status_code=201) +''' + (api_path / "products.py").write_text(products_route, encoding="utf-8") old_cwd = os.getcwd() try: os.chdir(temp_dir) + from runapi import create_runapi_app app = create_runapi_app() fastapi_app = app.get_app() with TestClient(fastapi_app) as client: - # Test nested API route - resp = client.get("/api/v1/users") - assert resp.status_code == 200 - assert resp.json()["path"] == "users" - - # Test blog route - resp = client.get("/blog/posts") - assert resp.status_code == 200 - data = resp.json() - assert data["section"] == "blog" - assert data["type"] == "posts" + # Initially empty + response = client.get("/api/products") + assert response.status_code == 200 + assert response.json() == [] + + # Create product via service + response = client.post("/api/products", json={"name": "Widget", "price": 19.99}) + assert response.status_code == 201 + data = response.json() + assert data["id"] == 1 + assert data["name"] == "Widget" + + # Verify created + response = client.get("/api/products") + assert response.status_code == 200 + products = response.json() + assert len(products) == 1 finally: os.chdir(old_cwd) - print("โœ… Nested routing behavior test passed!") + print("โœ… Service integration with routes test passed!") + + +def run_all_tests(): + """Run all tests""" + print("๐Ÿš€ Starting RunApi Framework Tests\n") + + tests = [ + test_basic_app_creation, + test_configuration_system, + test_error_handling, + test_authentication_system, + test_file_based_routing, + test_middleware_system, + test_dynamic_routes, + test_cors_configuration, + test_static_file_serving, + test_schema_system, + test_schema_auto_discovery, + test_schema_integration_with_routes, + test_repository_in_memory, + test_typed_repository, + test_repository_factory, + test_repository_with_routes, + test_crud_service, + test_validated_service, + test_service_factory, + test_service_with_routes, + ] + + passed = 0 + failed = 0 + + for test in tests: + try: + test() + passed += 1 + except Exception as e: + print(f"โŒ Test {test.__name__} failed: {e}") + failed += 1 + + print("\n๐Ÿ“Š Test Results:") + print(f"โœ… Passed: {passed}") + print(f"โŒ Failed: {failed}") + print(f"๐Ÿ“ˆ Success Rate: {passed / (passed + failed) * 100:.1f}%") + + if failed == 0: + print("\n๐ŸŽ‰ All tests passed! RunApi framework is working correctly.") + else: + print(f"\nโš ๏ธ {failed} test(s) failed. Please check the output above.") + + return failed == 0 if __name__ == "__main__": diff --git a/tests/test_runapi_installation.py b/tests/test_runapi_installation.py index 83fc7b6..9476c0d 100644 --- a/tests/test_runapi_installation.py +++ b/tests/test_runapi_installation.py @@ -5,14 +5,11 @@ """ import os -import sys +import shutil import subprocess +import sys import tempfile -import shutil -import time -import signal from pathlib import Path -import threading class RunApiTester: @@ -21,36 +18,31 @@ def __init__(self): self.original_cwd = os.getcwd() self.success_count = 0 self.total_tests = 0 - + def log(self, message, status="INFO"): """Log test messages with status""" - status_symbols = { - "INFO": "โ„น๏ธ", - "SUCCESS": "โœ…", - "ERROR": "โŒ", - "WARNING": "โš ๏ธ" - } + status_symbols = {"INFO": "โ„น๏ธ", "SUCCESS": "โœ…", "ERROR": "โŒ", "WARNING": "โš ๏ธ"} symbol = status_symbols.get(status, "โ€ข") print(f"{symbol} {message}") - + def run_command(self, cmd, timeout=30, expect_success=True): """Run a command and return result""" try: result = subprocess.run( - cmd, - shell=True, - capture_output=True, - text=True, + cmd, + shell=True, + capture_output=True, + text=True, timeout=timeout, - cwd=self.test_dir or self.original_cwd + cwd=self.test_dir or self.original_cwd, ) - + if expect_success and result.returncode != 0: self.log(f"Command failed: {cmd}", "ERROR") self.log(f"STDOUT: {result.stdout}", "ERROR") self.log(f"STDERR: {result.stderr}", "ERROR") return False, result - + return True, result except subprocess.TimeoutExpired: self.log(f"Command timed out: {cmd}", "WARNING") @@ -58,12 +50,12 @@ def run_command(self, cmd, timeout=30, expect_success=True): except Exception as e: self.log(f"Command exception: {cmd} - {e}", "ERROR") return False, None - + def test_cli_available(self): """Test 1: Check if runapi CLI is available""" self.total_tests += 1 self.log("Testing CLI availability...") - + success, result = self.run_command("runapi --help") if success and "RunApi" in result.stdout: self.log("CLI is available and working", "SUCCESS") @@ -72,48 +64,42 @@ def test_cli_available(self): else: self.log("CLI is not available or not working", "ERROR") return False - + def test_project_creation(self): """Test 2: Create a new project""" self.total_tests += 1 self.log("Testing project creation...") - + # Create temporary directory self.test_dir = tempfile.mkdtemp(prefix="runapi_test_") os.chdir(self.test_dir) - + success, result = self.run_command("runapi init testproject") if success and Path("testproject").exists(): self.log("Project created successfully", "SUCCESS") self.success_count += 1 - + # Change to project directory self.test_dir = str(Path(self.test_dir) / "testproject") os.chdir(self.test_dir) - + return True else: self.log("Project creation failed", "ERROR") return False - + def test_project_structure(self): """Test 3: Verify project structure""" self.total_tests += 1 self.log("Testing project structure...") - - required_files = [ - "main.py", - ".env", - "README.md", - "routes/index.py", - "routes/api/hello.py" - ] - + + required_files = ["main.py", ".env", "README.md", "routes/index.py", "routes/api/hello.py"] + missing_files = [] for file_path in required_files: if not Path(file_path).exists(): missing_files.append(file_path) - + if not missing_files: self.log("All required files present", "SUCCESS") self.success_count += 1 @@ -121,13 +107,13 @@ def test_project_structure(self): else: self.log(f"Missing files: {missing_files}", "ERROR") return False - + def test_main_import(self): """Test 4: Test if main.py can be imported""" self.total_tests += 1 self.log("Testing main.py import...") - - success, result = self.run_command('python -c "import main; print(\'SUCCESS\')"') + + success, result = self.run_command("python -c \"import main; print('SUCCESS')\"") if success and "SUCCESS" in result.stdout: self.log("main.py imports successfully", "SUCCESS") self.success_count += 1 @@ -137,13 +123,13 @@ def test_main_import(self): if result: self.log(f"Error: {result.stderr}", "ERROR") return False - + def test_runapi_import(self): """Test 5: Test if runapi package can be imported""" self.total_tests += 1 self.log("Testing runapi package import...") - - success, result = self.run_command('python -c "import runapi; print(\'SUCCESS\')"') + + success, result = self.run_command("python -c \"import runapi; print('SUCCESS')\"") if success and "SUCCESS" in result.stdout: self.log("runapi package imports successfully", "SUCCESS") self.success_count += 1 @@ -153,13 +139,13 @@ def test_runapi_import(self): if result: self.log(f"Error: {result.stderr}", "ERROR") return False - + def test_app_creation(self): """Test 6: Test app creation""" self.total_tests += 1 self.log("Testing app creation...") - - test_script = ''' + + test_script = """ import sys sys.path.insert(0, ".") try: @@ -171,8 +157,8 @@ def test_app_creation(self): except Exception as e: print(f"ERROR: {e}") sys.exit(1) -''' - +""" + success, result = self.run_command(f'python -c "{test_script}"') if success and "SUCCESS: App created" in result.stdout: self.log("App creation successful", "SUCCESS") @@ -183,14 +169,14 @@ def test_app_creation(self): if result: self.log(f"Error: {result.stderr}", "ERROR") return False - + def test_uvicorn_direct(self): """Test 7: Test uvicorn directly""" self.total_tests += 1 self.log("Testing uvicorn direct import...") - + # Test if uvicorn can import the main:app - test_script = ''' + test_script = """ import sys import importlib.util sys.path.insert(0, ".") @@ -200,21 +186,21 @@ def test_uvicorn_direct(self): spec = importlib.util.spec_from_file_location("main", "main.py") main_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(main_module) - + if hasattr(main_module, "app"): print("SUCCESS: main:app accessible") else: print("ERROR: main.app not found") sys.exit(1) - + except Exception as e: print(f"ERROR: {e}") sys.exit(1) -''' - +""" + success, result = self.run_command(f'python -c "{test_script}"') if success and "SUCCESS: main:app accessible" in result.stdout: - self.log("uvicorn can access main:app", "SUCCESS") + self.log("uvicorn can access main:app", "SUCCESS") self.success_count += 1 return True else: @@ -222,14 +208,14 @@ def test_uvicorn_direct(self): if result: self.log(f"Error: {result.stderr}", "ERROR") return False - + def test_server_startup(self): """Test 8: Test if server can start (without running indefinitely)""" self.total_tests += 1 self.log("Testing server startup (quick test)...") - + # Create a test script that starts the server and immediately stops it - test_script = ''' + test_script = """ import sys import os import threading @@ -240,7 +226,7 @@ def test_server(): try: import uvicorn import main - + # Test if we can create a server instance config = uvicorn.Config("main:app", host="127.0.0.1", port=8999) server = uvicorn.Server(config) @@ -254,8 +240,8 @@ def test_server(): sys.exit(0) else: sys.exit(1) -''' - +""" + success, result = self.run_command(f'python -c "{test_script}"') if success and "SUCCESS: Server can be created" in result.stdout: self.log("Server startup test passed", "SUCCESS") @@ -266,14 +252,14 @@ def test_server(): if result: self.log(f"Error: {result.stderr}", "ERROR") return False - + def test_cli_dev_dry_run(self): """Test 9: Test CLI dev command validation (without actual server start)""" self.total_tests += 1 self.log("Testing CLI dev command validation...") - + # We'll test the CLI's pre-validation logic - test_script = ''' + test_script = """ import sys import os sys.path.insert(0, ".") @@ -284,28 +270,28 @@ def test_cli_dev_dry_run(self): if not os.path.exists("main.py"): print("ERROR: main.py not found") sys.exit(1) - + # Check if main can be imported import importlib.util spec = importlib.util.spec_from_file_location("main", "main.py") if spec is None: print("ERROR: Cannot load main.py") sys.exit(1) - + main_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(main_module) - + if not hasattr(main_module, "app"): print("ERROR: main.py does not have app attribute") sys.exit(1) - + print("SUCCESS: CLI validation passed") - + except Exception as e: print(f"ERROR: {e}") sys.exit(1) -''' - +""" + success, result = self.run_command(f'python -c "{test_script}"') if success and "SUCCESS: CLI validation passed" in result.stdout: self.log("CLI dev command validation passed", "SUCCESS") @@ -316,26 +302,30 @@ def test_cli_dev_dry_run(self): if result: self.log(f"Error: {result.stderr}", "ERROR") return False - + def cleanup(self): """Clean up test directory""" try: os.chdir(self.original_cwd) if self.test_dir and Path(self.test_dir).exists(): # Go up to temp directory and remove the whole test dir - test_root = Path(self.test_dir).parents[0] if "testproject" in self.test_dir else Path(self.test_dir) + test_root = ( + Path(self.test_dir).parents[0] + if "testproject" in self.test_dir + else Path(self.test_dir) + ) shutil.rmtree(test_root, ignore_errors=True) self.log("Test directory cleaned up", "INFO") except Exception as e: self.log(f"Cleanup warning: {e}", "WARNING") - + def run_all_tests(self): """Run all tests""" self.log("๐Ÿš€ Starting RunApi Installation Tests", "INFO") self.log(f"Python: {sys.executable}", "INFO") self.log(f"Working directory: {os.getcwd()}", "INFO") print("-" * 60) - + try: # Run tests in sequence tests = [ @@ -349,7 +339,7 @@ def run_all_tests(self): self.test_server_startup, self.test_cli_dev_dry_run, ] - + for i, test in enumerate(tests, 1): self.log(f"Running test {i}/{len(tests)}: {test.__name__}", "INFO") try: @@ -357,15 +347,18 @@ def run_all_tests(self): except Exception as e: self.log(f"Test {test.__name__} threw exception: {e}", "ERROR") print("-" * 40) - + finally: self.cleanup() - + # Results print("=" * 60) self.log("๐Ÿ TEST RESULTS", "INFO") - self.log(f"Passed: {self.success_count}/{self.total_tests}", "SUCCESS" if self.success_count == self.total_tests else "WARNING") - + self.log( + f"Passed: {self.success_count}/{self.total_tests}", + "SUCCESS" if self.success_count == self.total_tests else "WARNING", + ) + if self.success_count == self.total_tests: self.log("๐ŸŽ‰ All tests passed! RunApi is working correctly.", "SUCCESS") return True @@ -378,4 +371,4 @@ def run_all_tests(self): if __name__ == "__main__": tester = RunApiTester() success = tester.run_all_tests() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1)