diff --git a/.gitignore b/.gitignore index 92a6143..7be981c 100644 --- a/.gitignore +++ b/.gitignore @@ -193,4 +193,5 @@ cython_debug/ .cursorignore .cursorindexingignore -alembic/versions/ \ No newline at end of file +alembic/versions/ +.DS_Store \ No newline at end of file diff --git a/config/database.py b/config/database.py index a614eae..b3052d1 100644 --- a/config/database.py +++ b/config/database.py @@ -1,5 +1,5 @@ from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.orm import sessionmaker, DeclarativeBase import os from dotenv import load_dotenv import warnings @@ -13,8 +13,19 @@ # Read the database URL from the environment variable DATABASE_URL = os.getenv("DATABASE_URL") if DATABASE_URL is not None: - # Create the SQLAlchemy engine with the database URL - engine = create_engine(DATABASE_URL) + # Create the SQLAlchemy engine with the database URL. + # pool_size / max_overflow cap concurrent DB connections so one slow + # query can't starve the whole server. statement_timeout kills any + # single query that runs longer than 60 s so a runaway request doesn't + # hold a connection indefinitely. + engine = create_engine( + DATABASE_URL, + pool_size=10, + max_overflow=5, + pool_timeout=30, + pool_pre_ping=True, + connect_args={"options": "-c statement_timeout=60000"}, + ) # Create a session factory SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) else: @@ -23,7 +34,8 @@ # Base is the base class for all the SQLAlchemy ORM models. # It tells SQLAlchemy that a model maps to a real table. # Without inheriting from Base, the class won’t be recognized by SQLAlchemy’s ORM. -Base = declarative_base() +class Base(DeclarativeBase): + pass # Dependency to get a DB session for FastAPI routes (used in controllers) def get_db(): diff --git a/controllers/disco_controller.py b/controllers/disco_controller.py index 4b71af5..1a48838 100644 --- a/controllers/disco_controller.py +++ b/controllers/disco_controller.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request from sqlalchemy.orm import Session from services.disco_service import DiscoService from dtos.generic_response_dto import GenericResponseDTO, build_url @@ -56,16 +56,24 @@ async def get_events( None, description="Total number of Atlas probes active in the reported stream (ASN, Country, or geographical area)."), ongoing: Optional[str] = Query( None, description="Deprecated, this value is unused"), + include_probe_details: bool = Query( + False, description="Include per-probe details in the response."), page: Optional[int] = Query( 1, ge=1, description="A page number within the paginated result set."), ordering: Optional[str] = Query( None, description="Which field to use when ordering the results") ) -> GenericResponseDTO[DiscoEventsDTO]: """ - List network disconnections detected with RIPE Atlas. + List network disconnections detected with RIPE Atlas. These events have different levels of granularity - it can be at a network level (AS), city, or country level. """ + if not any([starttime, starttime__gte, starttime__lte, endtime, endtime__gte, endtime__lte]): + raise HTTPException( + status_code=400, + detail="At least one time parameter is required: starttime, starttime__gte, starttime__lte, endtime, endtime__gte, or endtime__lte." + ) + events_data, total_count = DiscoController.service.get_disco_events( db, streamname=streamname, @@ -86,6 +94,7 @@ async def get_events( totalprobes_gte=totalprobes__gte, totalprobes_lte=totalprobes__lte, ongoing=ongoing, + include_probe_details=include_probe_details, page=page, order_by=ordering ) diff --git a/docs/add_new_endpoint.md b/docs/add_new_endpoint.md index 5b2aa19..0bf875e 100644 --- a/docs/add_new_endpoint.md +++ b/docs/add_new_endpoint.md @@ -19,6 +19,55 @@ Create a service file in the `services/` directory or modify an existing one. Th ### 3. **Create the Repository** Add a repository file in the `repositories/` directory or modify an existing one. Ensure it handles pagination and ordering using `offset` and `limit`. +Use the SQLAlchemy 2.0 `select()` style — **not** the legacy `db.query()` API. + +Example: +```python +# filepath: repositories/new_entity_repository.py +from sqlalchemy.orm import Session +from sqlalchemy import select, func +from models.new_entity_model import NewEntity +from typing import Optional, List, Tuple +from utils import page_size + + +class NewEntityRepository: + def get_all( + self, + db: Session, + field1: Optional[str] = None, + page: int = 1, + order_by: Optional[str] = None, + ) -> Tuple[List[NewEntity], int]: + stmt = select(NewEntity) + + if field1: + stmt = stmt.where(NewEntity.field1 == field1) + + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) + + if order_by and hasattr(NewEntity, order_by): + stmt = stmt.order_by(getattr(NewEntity, order_by)) + + offset = (page - 1) * page_size + results = db.scalars(stmt.offset(offset).limit(page_size)).all() + + return results, total_count +``` + +If the model has a relationship that needs to be loaded eagerly alongside the main query, use `contains_eager` with `of_type()`: +```python +from sqlalchemy.orm import contains_eager, aliased + +RelatedModel = aliased(NewEntity.related_relation.property.mapper.class_) +stmt = ( + select(NewEntity) + .join(NewEntity.related_relation.of_type(RelatedModel)) + .options(contains_eager(NewEntity.related_relation.of_type(RelatedModel))) +) +# then add .where() clauses and call db.scalars(...).unique().all() +``` + --- ### 4. **Define the Model** @@ -88,14 +137,13 @@ Add a DTO in the `dtos/` directory to define the structure of the response. Example: ```python # filepath: dtos/new_entity_dto.py -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class NewEntityDTO(BaseModel): field1: str field2: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) ``` --- diff --git a/dtos/country_dto.py b/dtos/country_dto.py index 268f47d..61eca5f 100644 --- a/dtos/country_dto.py +++ b/dtos/country_dto.py @@ -1,8 +1,7 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class CountryDTO(BaseModel): code: str name: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/disco_events_dto.py b/dtos/disco_events_dto.py index f6aeaad..30e7b01 100644 --- a/dtos/disco_events_dto.py +++ b/dtos/disco_events_dto.py @@ -1,5 +1,5 @@ from dtos.disco_probes_dto import DiscoProbesDTO -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime from typing import List, Optional @@ -13,13 +13,12 @@ class DiscoEventsDTO(BaseModel): nbdiscoprobes: int totalprobes: int ongoing: bool - discoprobes: List[DiscoProbesDTO] + discoprobes: Optional[List[DiscoProbesDTO]] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) @staticmethod - def from_model(disco_event): + def from_model(disco_event, include_probe_details: bool = False): return DiscoEventsDTO( id=disco_event.id, streamtype=disco_event.streamtype, @@ -30,6 +29,6 @@ def from_model(disco_event): nbdiscoprobes=disco_event.nbdiscoprobes, totalprobes=disco_event.totalprobes, ongoing=disco_event.ongoing, - discoprobes=[DiscoProbesDTO.from_orm( - probe) for probe in disco_event.probes] + discoprobes=[DiscoProbesDTO.model_validate(probe) for probe in disco_event.probes] + if include_probe_details else None, ) diff --git a/dtos/disco_probes_dto.py b/dtos/disco_probes_dto.py index 030b9ff..95e09dc 100644 --- a/dtos/disco_probes_dto.py +++ b/dtos/disco_probes_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -13,5 +13,4 @@ class DiscoProbesDTO(BaseModel): lat: float lon: float - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/hegemony_alarms_dto.py b/dtos/hegemony_alarms_dto.py index c9469eb..b1b1e20 100644 --- a/dtos/hegemony_alarms_dto.py +++ b/dtos/hegemony_alarms_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -11,5 +11,4 @@ class HegemonyAlarmsDTO(BaseModel): asn_name: str originasn_name: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/hegemony_cone_dto.py b/dtos/hegemony_cone_dto.py index d92089a..02d277a 100644 --- a/dtos/hegemony_cone_dto.py +++ b/dtos/hegemony_cone_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -8,5 +8,4 @@ class HegemonyConeDTO(BaseModel): conesize: int af: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/hegemony_country_dto.py b/dtos/hegemony_country_dto.py index 43d6717..3f84dab 100644 --- a/dtos/hegemony_country_dto.py +++ b/dtos/hegemony_country_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -13,5 +13,4 @@ class HegemonyCountryDTO(BaseModel): weightscheme: str transitonly: bool - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/hegemony_dto.py b/dtos/hegemony_dto.py index efb4d51..9edbee7 100644 --- a/dtos/hegemony_dto.py +++ b/dtos/hegemony_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -11,5 +11,4 @@ class HegemonyDTO(BaseModel): asn_name: str originasn_name: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/hegemony_prefix_dto.py b/dtos/hegemony_prefix_dto.py index 0264ce7..c0a266f 100644 --- a/dtos/hegemony_prefix_dto.py +++ b/dtos/hegemony_prefix_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -20,5 +20,4 @@ class HegemonyPrefixDTO(BaseModel): originasn_name: str asn_name: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/metis_atlas_deployment_dto.py b/dtos/metis_atlas_deployment_dto.py index 482a4bf..21ae16a 100644 --- a/dtos/metis_atlas_deployment_dto.py +++ b/dtos/metis_atlas_deployment_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -11,5 +11,4 @@ class MetisAtlasDeploymentDTO(BaseModel): nbsamples: int asn_name: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/metis_atlas_selection_dto.py b/dtos/metis_atlas_selection_dto.py index 4b582de..eed57d4 100644 --- a/dtos/metis_atlas_selection_dto.py +++ b/dtos/metis_atlas_selection_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -10,5 +10,4 @@ class MetisAtlasSelectionDTO(BaseModel): af: int asn_name: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/network_delay_alarms_dto.py b/dtos/network_delay_alarms_dto.py index 6416fd7..09ae87c 100644 --- a/dtos/network_delay_alarms_dto.py +++ b/dtos/network_delay_alarms_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -12,8 +12,7 @@ class NetworkDelayAlarmsDTO(BaseModel): endpoint_af: int deviation: float - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) @staticmethod def from_model(atlas_delay_alarm): diff --git a/dtos/network_delay_dto.py b/dtos/network_delay_dto.py index 5e6ce95..26449cd 100644 --- a/dtos/network_delay_dto.py +++ b/dtos/network_delay_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -17,8 +17,7 @@ class NetworkDelayDTO(BaseModel): hop: int nbrealrtts: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) @staticmethod def from_model(atlasDelay): diff --git a/dtos/network_delay_locations_dto.py b/dtos/network_delay_locations_dto.py index 6613843..097bcea 100644 --- a/dtos/network_delay_locations_dto.py +++ b/dtos/network_delay_locations_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class NetworkDelayLocationsDTO(BaseModel): @@ -6,5 +6,4 @@ class NetworkDelayLocationsDTO(BaseModel): name: str af: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) diff --git a/dtos/networks_dto.py b/dtos/networks_dto.py index 237859b..44a9901 100644 --- a/dtos/networks_dto.py +++ b/dtos/networks_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class NetworksDTO(BaseModel): @@ -8,8 +8,7 @@ class NetworksDTO(BaseModel): delay_forwarding: bool disco: bool - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) @staticmethod def from_model(asn): diff --git a/dtos/tr_hegemony_dto.py b/dtos/tr_hegemony_dto.py index 9c5c809..31719e7 100644 --- a/dtos/tr_hegemony_dto.py +++ b/dtos/tr_hegemony_dto.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from datetime import datetime @@ -14,8 +14,7 @@ class TRHegemonyDTO(BaseModel): af: int nbsamples: int - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) @staticmethod def from_model(tr_hegemony): diff --git a/main.py b/main.py index ed82620..52ad66e 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,13 @@ import importlib import pkgutil -from fastapi import FastAPI +from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware +from fastapi.routing import APIRoute +from fastapi.responses import JSONResponse from controllers import __path__ as controllers_path from dotenv import load_dotenv import os +from starlette.routing import Match try: load_dotenv() @@ -32,11 +35,40 @@ root_path="" if PROXY_PATH is None else f"/{PROXY_PATH}", title="IHR API", description=description, - version="v1.13", + version="v2.0", redoc_url=None, - swagger_ui_parameters={ "defaultModelsExpandDepth": -1 } + swagger_ui_parameters={ "defaultModelsExpandDepth": -1 }, ) +@app.middleware("http") +async def reject_unknown_query_params(request: Request, call_next): + scope = request.scope + allowed_params = None + for route in app.routes: + if not isinstance(route, APIRoute): + continue + # Use Starlette's own matching — handles path params, methods, everything + match, _ = route.matches(scope) + if match == Match.FULL: + allowed_params = {p.name for p in route.dependant.query_params} + break + if allowed_params is not None: + extra = set(request.query_params.keys()) - allowed_params + if extra: + return JSONResponse( + { + "error": "invalid_query_params", + "unexpected": sorted(extra), + "allowed": sorted(allowed_params), + }, + status_code=400, + ) + return await call_next(request) + +@app.get("/favicon.ico", include_in_schema=False) +async def favicon(): + return Response(status_code=204) + # Automatically import and register all routers inside "controllers" for _, module_name, _ in pkgutil.iter_modules(controllers_path): module = importlib.import_module(f"controllers.{module_name}") @@ -54,4 +86,4 @@ allow_origins=origins, allow_methods=["GET"], allow_headers=["*"], -) +) \ No newline at end of file diff --git a/repositories/atlas_delay_alarms_repository.py b/repositories/atlas_delay_alarms_repository.py index 694c04d..c58d7d1 100644 --- a/repositories/atlas_delay_alarms_repository.py +++ b/repositories/atlas_delay_alarms_repository.py @@ -1,10 +1,9 @@ from sqlalchemy.orm import Session, aliased, contains_eager -from sqlalchemy import and_, or_ +from sqlalchemy import select, func, and_, or_ from models.atlas_delay_alarms import AtlasDelayAlarms from datetime import datetime from typing import List, Optional, Tuple from utils import page_size -from sqlalchemy import func class AtlasDelayAlarmsRepository: @@ -30,39 +29,38 @@ def get_alarms( """ Get network delay alarms with all possible filters. """ - Startpoint = aliased( - AtlasDelayAlarms.startpoint_relation.property.mapper.class_) - Endpoint = aliased( - AtlasDelayAlarms.endpoint_relation.property.mapper.class_) - - query = db.query(AtlasDelayAlarms)\ - .join(Startpoint, AtlasDelayAlarms.startpoint_relation)\ - .join(Endpoint, AtlasDelayAlarms.endpoint_relation)\ - .options( - contains_eager(AtlasDelayAlarms.startpoint_relation, alias=Startpoint), - contains_eager(AtlasDelayAlarms.endpoint_relation, alias=Endpoint) - ) - + Startpoint = aliased(AtlasDelayAlarms.startpoint_relation.property.mapper.class_) + Endpoint = aliased(AtlasDelayAlarms.endpoint_relation.property.mapper.class_) + + stmt = ( + select(AtlasDelayAlarms) + .join(AtlasDelayAlarms.startpoint_relation.of_type(Startpoint)) + .join(AtlasDelayAlarms.endpoint_relation.of_type(Endpoint)) + .options( + contains_eager(AtlasDelayAlarms.startpoint_relation.of_type(Startpoint)), + contains_eager(AtlasDelayAlarms.endpoint_relation.of_type(Endpoint)) + ) + ) # If no time filters specified, get rows with max timebin if not timebin and not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(AtlasDelayAlarms.timebin)).scalar() - query = query.filter(AtlasDelayAlarms.timebin == max_timebin) - + max_timebin = db.scalar(select(func.max(AtlasDelayAlarms.timebin))) + stmt = stmt.where(AtlasDelayAlarms.timebin == max_timebin) + if timebin: - query = query.filter(AtlasDelayAlarms.timebin == timebin) + stmt = stmt.where(AtlasDelayAlarms.timebin == timebin) if timebin_gte: - query = query.filter(AtlasDelayAlarms.timebin >= timebin_gte) + stmt = stmt.where(AtlasDelayAlarms.timebin >= timebin_gte) if timebin_lte: - query = query.filter(AtlasDelayAlarms.timebin <= timebin_lte) + stmt = stmt.where(AtlasDelayAlarms.timebin <= timebin_lte) if startpoint_names: names = startpoint_names.split('|') - query = query.filter(Startpoint.name.in_(names)) + stmt = stmt.where(Startpoint.name.in_(names)) if startpoint_type: - query = query.filter(Startpoint.type == startpoint_type) + stmt = stmt.where(Startpoint.type == startpoint_type) if startpoint_af: - query = query.filter(Startpoint.af == startpoint_af) + stmt = stmt.where(Startpoint.af == startpoint_af) if startpoint_key: startpoint_conditions = [] for key in startpoint_key.split('|'): @@ -83,15 +81,15 @@ def get_alarms( startpoint_conditions.append(and_(*conditions)) if startpoint_conditions: - query = query.filter(or_(*startpoint_conditions)) + stmt = stmt.where(or_(*startpoint_conditions)) if endpoint_names: names = endpoint_names.split('|') - query = query.filter(Endpoint.name.in_(names)) + stmt = stmt.where(Endpoint.name.in_(names)) if endpoint_type: - query = query.filter(Endpoint.type == endpoint_type) + stmt = stmt.where(Endpoint.type == endpoint_type) if endpoint_af: - query = query.filter(Endpoint.af == endpoint_af) + stmt = stmt.where(Endpoint.af == endpoint_af) if endpoint_key: endpoint_conditions = [] for key in endpoint_key.split('|'): @@ -112,19 +110,19 @@ def get_alarms( endpoint_conditions.append(and_(*conditions)) if endpoint_conditions: - query = query.filter(or_(*endpoint_conditions)) + stmt = stmt.where(or_(*endpoint_conditions)) if deviation_gte: - query = query.filter(AtlasDelayAlarms.deviation >= deviation_gte) + stmt = stmt.where(AtlasDelayAlarms.deviation >= deviation_gte) if deviation_lte: - query = query.filter(AtlasDelayAlarms.deviation <= deviation_lte) + stmt = stmt.where(AtlasDelayAlarms.deviation <= deviation_lte) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) if order_by and hasattr(AtlasDelayAlarms, order_by): - query = query.order_by(getattr(AtlasDelayAlarms, order_by)) + stmt = stmt.order_by(getattr(AtlasDelayAlarms, order_by)) offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/atlas_delay_repository.py b/repositories/atlas_delay_repository.py index 0cac3ca..1222319 100644 --- a/repositories/atlas_delay_repository.py +++ b/repositories/atlas_delay_repository.py @@ -1,10 +1,9 @@ from sqlalchemy.orm import Session, aliased, contains_eager -from sqlalchemy import and_, or_ +from sqlalchemy import select, func, and_, or_ from models.atlas_delay import AtlasDelay from datetime import datetime from typing import List, Optional, Tuple from utils import page_size -from sqlalchemy import func class AtlasDelayRepository: @@ -34,39 +33,39 @@ def get_delays( # Create SQLAlchemy aliases for the AtlasLocation table, used in both startpoint and endpoint relationships. # This is necessary because we are joining the same table (AtlasLocation) twice in the query, # and SQL requires different aliases for each instance to avoid ambiguity. - Startpoint = aliased( - AtlasDelay.startpoint_relation.property.mapper.class_) + Startpoint = aliased(AtlasDelay.startpoint_relation.property.mapper.class_) Endpoint = aliased(AtlasDelay.endpoint_relation.property.mapper.class_) - query = db.query(AtlasDelay)\ - .join(Startpoint, AtlasDelay.startpoint_relation)\ - .join(Endpoint, AtlasDelay.endpoint_relation)\ - .options( - contains_eager(AtlasDelay.startpoint_relation, alias=Startpoint), - contains_eager(AtlasDelay.endpoint_relation, alias=Endpoint) - ) - + stmt = ( + select(AtlasDelay) + .join(AtlasDelay.startpoint_relation.of_type(Startpoint)) + .join(AtlasDelay.endpoint_relation.of_type(Endpoint)) + .options( + contains_eager(AtlasDelay.startpoint_relation.of_type(Startpoint)), + contains_eager(AtlasDelay.endpoint_relation.of_type(Endpoint)) + ) + ) # If no time filters specified, get rows with max timebin if not timebin and not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(AtlasDelay.timebin)).scalar() - query = query.filter(AtlasDelay.timebin == max_timebin) - + max_timebin = db.scalar(select(func.max(AtlasDelay.timebin))) + stmt = stmt.where(AtlasDelay.timebin == max_timebin) + # Apply timebin filters if timebin: - query = query.filter(AtlasDelay.timebin == timebin) + stmt = stmt.where(AtlasDelay.timebin == timebin) if timebin_gte: - query = query.filter(AtlasDelay.timebin >= timebin_gte) + stmt = stmt.where(AtlasDelay.timebin >= timebin_gte) if timebin_lte: - query = query.filter(AtlasDelay.timebin <= timebin_lte) + stmt = stmt.where(AtlasDelay.timebin <= timebin_lte) if startpoint_names: names = startpoint_names.split('|') - query = query.filter(Startpoint.name.in_(names)) + stmt = stmt.where(Startpoint.name.in_(names)) if startpoint_type: - query = query.filter(Startpoint.type == startpoint_type) + stmt = stmt.where(Startpoint.type == startpoint_type) if startpoint_af: - query = query.filter(Startpoint.af == startpoint_af) + stmt = stmt.where(Startpoint.af == startpoint_af) if startpoint_key: startpoint_conditions = [] for key in startpoint_key.split('|'): @@ -87,15 +86,15 @@ def get_delays( startpoint_conditions.append(and_(*conditions)) if startpoint_conditions: - query = query.filter(or_(*startpoint_conditions)) + stmt = stmt.where(or_(*startpoint_conditions)) if endpoint_names: names = endpoint_names.split('|') - query = query.filter(Endpoint.name.in_(names)) + stmt = stmt.where(Endpoint.name.in_(names)) if endpoint_type: - query = query.filter(Endpoint.type == endpoint_type) + stmt = stmt.where(Endpoint.type == endpoint_type) if endpoint_af: - query = query.filter(Endpoint.af == endpoint_af) + stmt = stmt.where(Endpoint.af == endpoint_af) if endpoint_key: endpoint_conditions = [] for key in endpoint_key.split('|'): @@ -116,23 +115,23 @@ def get_delays( endpoint_conditions.append(and_(*conditions)) if endpoint_conditions: - query = query.filter(or_(*endpoint_conditions)) + stmt = stmt.where(or_(*endpoint_conditions)) if median: - query = query.filter(AtlasDelay.median == median) + stmt = stmt.where(AtlasDelay.median == median) if median_gte: - query = query.filter(AtlasDelay.median >= median_gte) + stmt = stmt.where(AtlasDelay.median >= median_gte) if median_lte: - query = query.filter(AtlasDelay.median <= median_lte) + stmt = stmt.where(AtlasDelay.median <= median_lte) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(AtlasDelay, order_by): - query = query.order_by(getattr(AtlasDelay, order_by)) + stmt = stmt.order_by(getattr(AtlasDelay, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/atlas_location_repository.py b/repositories/atlas_location_repository.py index 4d30db9..f2f5251 100644 --- a/repositories/atlas_location_repository.py +++ b/repositories/atlas_location_repository.py @@ -1,4 +1,5 @@ from sqlalchemy.orm import Session +from sqlalchemy import select, func from models.atlas_location import AtlasLocation from typing import Optional, List, Tuple from utils import page_size @@ -14,24 +15,24 @@ def get_all( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[AtlasLocation], int]: - query = db.query(AtlasLocation) + stmt = select(AtlasLocation) # Apply filters if name: - query = query.filter(AtlasLocation.name.ilike(f"%{name}%")) + stmt = stmt.where(AtlasLocation.name.ilike(f"%{name}%")) if type: - query = query.filter(AtlasLocation.type == type) + stmt = stmt.where(AtlasLocation.type == type) if af: - query = query.filter(AtlasLocation.af == af) + stmt = stmt.where(AtlasLocation.af == af) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(AtlasLocation, order_by): - query = query.order_by(getattr(AtlasLocation, order_by)) + stmt = stmt.order_by(getattr(AtlasLocation, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).all() return results, total_count diff --git a/repositories/country_repository.py b/repositories/country_repository.py index 76a3b4a..be07fd9 100644 --- a/repositories/country_repository.py +++ b/repositories/country_repository.py @@ -1,7 +1,7 @@ from sqlalchemy.orm import Session +from sqlalchemy import select, func, asc from models.country import Country -from typing import Optional, List, Tuple # Added Tuple for return type -from sqlalchemy import asc +from typing import Optional, List, Tuple from utils import page_size @@ -11,32 +11,32 @@ def get_all( db: Session, code: Optional[str] = None, name: Optional[str] = None, - page: int = 1, # Page number, defaults to 1 - order_by: Optional[str] = None, # Column name to sort by - ) -> Tuple[List[Country], int]: # Returns list of countries and total count + page: int = 1, + order_by: Optional[str] = None, + ) -> Tuple[List[Country], int]: """ Retrieves countries with pagination and ordering at database level. Returns: Tuple[List[Country], total_count] """ # Initialize base query - query = db.query(Country) + stmt = select(Country) # Apply filters if provided if code: - query = query.filter(Country.code == code) + stmt = stmt.where(Country.code == code) if name: - query = query.filter(Country.name.ilike(f"%{name}%")) + stmt = stmt.where(Country.name.ilike(f"%{name}%")) # Executes getting total count of countries - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if specified if order_by and hasattr(Country, order_by): - query = query.order_by(asc(getattr(Country, order_by))) + stmt = stmt.order_by(asc(getattr(Country, order_by))) # Calculate offset based on page number and size offset = (page - 1) * page_size # Apply pagination and execute query - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).all() return results, total_count diff --git a/repositories/disco_events_repository.py b/repositories/disco_events_repository.py index 39d5a82..9354325 100644 --- a/repositories/disco_events_repository.py +++ b/repositories/disco_events_repository.py @@ -1,5 +1,5 @@ -from sqlalchemy.orm import Session, joinedload -from sqlalchemy import and_ +from sqlalchemy.orm import Session, joinedload, noload +from sqlalchemy import select, func from models.disco_events import DiscoEvents from datetime import datetime from typing import List, Optional, Tuple @@ -28,63 +28,67 @@ def get_disco_events( totalprobes_gte: Optional[int] = None, totalprobes_lte: Optional[int] = None, ongoing: Optional[str] = None, + include_probe_details: bool = False, page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[DiscoEvents], int]: - query = db.query(DiscoEvents) + # Build conditions as a list so the count query can reuse them without + # the joinedload option (which would expand rows and give a wrong count). + conditions = [] if streamname: - query = query.filter(DiscoEvents.streamname == streamname) + conditions.append(DiscoEvents.streamname == streamname) if streamtype: - query = query.filter(DiscoEvents.streamtype == streamtype) + conditions.append(DiscoEvents.streamtype == streamtype) if starttime: - query = query.filter(DiscoEvents.starttime == starttime) + conditions.append(DiscoEvents.starttime == starttime) if starttime_gte: - query = query.filter(DiscoEvents.starttime >= starttime_gte) + conditions.append(DiscoEvents.starttime >= starttime_gte) if starttime_lte: - query = query.filter(DiscoEvents.starttime <= starttime_lte) + conditions.append(DiscoEvents.starttime <= starttime_lte) if endtime: - query = query.filter(DiscoEvents.endtime == endtime) + conditions.append(DiscoEvents.endtime == endtime) if endtime_gte: - query = query.filter(DiscoEvents.endtime >= endtime_gte) + conditions.append(DiscoEvents.endtime >= endtime_gte) if endtime_lte: - query = query.filter(DiscoEvents.endtime <= endtime_lte) + conditions.append(DiscoEvents.endtime <= endtime_lte) if avglevel: - query = query.filter(DiscoEvents.avglevel == avglevel) + conditions.append(DiscoEvents.avglevel == avglevel) if avglevel_gte: - query = query.filter(DiscoEvents.avglevel >= avglevel_gte) + conditions.append(DiscoEvents.avglevel >= avglevel_gte) if avglevel_lte: - query = query.filter(DiscoEvents.avglevel <= avglevel_lte) + conditions.append(DiscoEvents.avglevel <= avglevel_lte) if nbdiscoprobes: - query = query.filter(DiscoEvents.nbdiscoprobes == nbdiscoprobes) + conditions.append(DiscoEvents.nbdiscoprobes == nbdiscoprobes) if nbdiscoprobes_gte: - query = query.filter( - DiscoEvents.nbdiscoprobes >= nbdiscoprobes_gte) + conditions.append(DiscoEvents.nbdiscoprobes >= nbdiscoprobes_gte) if nbdiscoprobes_lte: - query = query.filter( - DiscoEvents.nbdiscoprobes <= nbdiscoprobes_lte) + conditions.append(DiscoEvents.nbdiscoprobes <= nbdiscoprobes_lte) if totalprobes: - query = query.filter(DiscoEvents.totalprobes == totalprobes) + conditions.append(DiscoEvents.totalprobes == totalprobes) if totalprobes_gte: - query = query.filter(DiscoEvents.totalprobes >= totalprobes_gte) + conditions.append(DiscoEvents.totalprobes >= totalprobes_gte) if totalprobes_lte: - query = query.filter(DiscoEvents.totalprobes <= totalprobes_lte) + conditions.append(DiscoEvents.totalprobes <= totalprobes_lte) if ongoing: - query = query.filter(DiscoEvents.ongoing == ongoing) + conditions.append(DiscoEvents.ongoing == ongoing) - total_count = query.count() + total_count = db.scalar(select(func.count(DiscoEvents.id)).where(*conditions)) + + load_opt = joinedload(DiscoEvents.probes) if include_probe_details else noload(DiscoEvents.probes) + stmt = select(DiscoEvents).where(*conditions).options(load_opt) if order_by and hasattr(DiscoEvents, order_by): - query = query.order_by(getattr(DiscoEvents, order_by)) + stmt = stmt.order_by(getattr(DiscoEvents, order_by)) offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/hegemony_alarms_repository.py b/repositories/hegemony_alarms_repository.py index 070c466..b6f4a16 100644 --- a/repositories/hegemony_alarms_repository.py +++ b/repositories/hegemony_alarms_repository.py @@ -1,9 +1,9 @@ from datetime import datetime from sqlalchemy.orm import Session, aliased, contains_eager +from sqlalchemy import select, func from models.hegemony_alarms import HegemonyAlarms from typing import Optional, List, Tuple from utils import page_size -from sqlalchemy import func class HegemonyAlarmsRepository: @@ -22,45 +22,46 @@ def get_all( ) -> Tuple[List[HegemonyAlarms], int]: ASN = aliased(HegemonyAlarms.asn_relation.property.mapper.class_) OriginASN = aliased(HegemonyAlarms.originasn_relation.property.mapper.class_) - - query = db.query(HegemonyAlarms)\ - .join(ASN, HegemonyAlarms.asn_relation)\ - .join(OriginASN, HegemonyAlarms.originasn_relation)\ - .options( - contains_eager(HegemonyAlarms.asn_relation, alias=ASN), - contains_eager(HegemonyAlarms.originasn_relation, alias=OriginASN) - ) + + stmt = ( + select(HegemonyAlarms) + .join(HegemonyAlarms.asn_relation.of_type(ASN)) + .join(HegemonyAlarms.originasn_relation.of_type(OriginASN)) + .options( + contains_eager(HegemonyAlarms.asn_relation.of_type(ASN)), + contains_eager(HegemonyAlarms.originasn_relation.of_type(OriginASN)) + ) + ) # If no time filters specified, get rows with max timebin if not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(HegemonyAlarms.timebin)).scalar() - query = query.filter(HegemonyAlarms.timebin == max_timebin) - + max_timebin = db.scalar(select(func.max(HegemonyAlarms.timebin))) + stmt = stmt.where(HegemonyAlarms.timebin == max_timebin) + # Apply filters if timebin_gte: - query = query.filter(HegemonyAlarms.timebin >= timebin_gte) + stmt = stmt.where(HegemonyAlarms.timebin >= timebin_gte) if timebin_lte: - query = query.filter(HegemonyAlarms.timebin <= timebin_lte) + stmt = stmt.where(HegemonyAlarms.timebin <= timebin_lte) if asn_ids: - query = query.filter(HegemonyAlarms.asn.in_(asn_ids)) + stmt = stmt.where(HegemonyAlarms.asn.in_(asn_ids)) if originasn_ids: - query = query.filter( - HegemonyAlarms.originasn.in_(originasn_ids)) + stmt = stmt.where(HegemonyAlarms.originasn.in_(originasn_ids)) if af is not None: - query = query.filter(HegemonyAlarms.af == af) + stmt = stmt.where(HegemonyAlarms.af == af) if deviation_gte: - query = query.filter(HegemonyAlarms.deviation >= deviation_gte) + stmt = stmt.where(HegemonyAlarms.deviation >= deviation_gte) if deviation_lte: - query = query.filter(HegemonyAlarms.deviation <= deviation_lte) + stmt = stmt.where(HegemonyAlarms.deviation <= deviation_lte) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(HegemonyAlarms, order_by): - query = query.order_by(getattr(HegemonyAlarms, order_by)) + stmt = stmt.order_by(getattr(HegemonyAlarms, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/hegemony_cone_repository.py b/repositories/hegemony_cone_repository.py index 52b8da7..fe03f8d 100644 --- a/repositories/hegemony_cone_repository.py +++ b/repositories/hegemony_cone_repository.py @@ -1,9 +1,9 @@ from datetime import datetime, timedelta from sqlalchemy.orm import Session +from sqlalchemy import select, func from models.hegemony_cone import HegemonyCone from typing import Optional, List, Tuple from utils import page_size -from sqlalchemy import func class HegemonyConeRepository: @@ -17,33 +17,33 @@ def get_all( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[HegemonyCone], int]: - query = db.query(HegemonyCone) + stmt = select(HegemonyCone) # If no time filters specified, get rows with max timebin if not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(HegemonyCone.timebin)).scalar() - query = query.filter(HegemonyCone.timebin == max_timebin) - + max_timebin = db.scalar(select(func.max(HegemonyCone.timebin))) + stmt = stmt.where(HegemonyCone.timebin == max_timebin) + # Apply filters if timebin_gte: - query = query.filter(HegemonyCone.timebin >= timebin_gte) + stmt = stmt.where(HegemonyCone.timebin >= timebin_gte) if timebin_lte: - query = query.filter(HegemonyCone.timebin <= timebin_lte) + stmt = stmt.where(HegemonyCone.timebin <= timebin_lte) if asn_ids: - query = query.filter(HegemonyCone.asn.in_(asn_ids)) + stmt = stmt.where(HegemonyCone.asn.in_(asn_ids)) if af: - query = query.filter(HegemonyCone.af == af) + stmt = stmt.where(HegemonyCone.af == af) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(HegemonyCone, order_by): - query = query.order_by(getattr(HegemonyCone, order_by)) + stmt = stmt.order_by(getattr(HegemonyCone, order_by)) else: - query = query.order_by(HegemonyCone.timebin) + stmt = stmt.order_by(HegemonyCone.timebin) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).all() return results, total_count diff --git a/repositories/hegemony_country_repository.py b/repositories/hegemony_country_repository.py index 200b0a8..01fa5a7 100644 --- a/repositories/hegemony_country_repository.py +++ b/repositories/hegemony_country_repository.py @@ -1,9 +1,9 @@ from datetime import datetime from sqlalchemy.orm import Session, contains_eager, aliased +from sqlalchemy import select, func from models.hegemony_country import HegemonyCountry from typing import Optional, List, Tuple from utils import page_size -from sqlalchemy import func class HegemonyCountryRepository: @@ -24,48 +24,48 @@ def get_all( order_by: Optional[str] = None ) -> Tuple[List[HegemonyCountry], int]: ASN = aliased(HegemonyCountry.asn_relation.property.mapper.class_) - - query = db.query(HegemonyCountry)\ - .join(ASN, HegemonyCountry.asn_relation)\ - .options( - contains_eager(HegemonyCountry.asn_relation, alias=ASN), - ) + + stmt = ( + select(HegemonyCountry) + .join(HegemonyCountry.asn_relation.of_type(ASN)) + .options(contains_eager(HegemonyCountry.asn_relation.of_type(ASN))) + ) # If no time filters specified, get rows with max timebin if not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(HegemonyCountry.timebin)).scalar() - query = query.filter(HegemonyCountry.timebin == max_timebin) + max_timebin = db.scalar(select(func.max(HegemonyCountry.timebin))) + stmt = stmt.where(HegemonyCountry.timebin == max_timebin) # Apply filters if timebin_gte: - query = query.filter(HegemonyCountry.timebin >= timebin_gte) + stmt = stmt.where(HegemonyCountry.timebin >= timebin_gte) if timebin_lte: - query = query.filter(HegemonyCountry.timebin <= timebin_lte) + stmt = stmt.where(HegemonyCountry.timebin <= timebin_lte) if asn_ids: - query = query.filter(HegemonyCountry.asn.in_(asn_ids)) + stmt = stmt.where(HegemonyCountry.asn.in_(asn_ids)) if countries: - query = query.filter(HegemonyCountry.country.in_(countries)) + stmt = stmt.where(HegemonyCountry.country.in_(countries)) if af is not None: - query = query.filter(HegemonyCountry.af == af) + stmt = stmt.where(HegemonyCountry.af == af) if weightscheme is not None: - query = query.filter(HegemonyCountry.weightscheme == weightscheme) + stmt = stmt.where(HegemonyCountry.weightscheme == weightscheme) if transitonly is not None: - query = query.filter(HegemonyCountry.transitonly == transitonly) + stmt = stmt.where(HegemonyCountry.transitonly == transitonly) if hege is not None: - query = query.filter(HegemonyCountry.hege == hege) + stmt = stmt.where(HegemonyCountry.hege == hege) if hege_gte is not None: - query = query.filter(HegemonyCountry.hege >= hege_gte) + stmt = stmt.where(HegemonyCountry.hege >= hege_gte) if hege_lte is not None: - query = query.filter(HegemonyCountry.hege <= hege_lte) + stmt = stmt.where(HegemonyCountry.hege <= hege_lte) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(HegemonyCountry, order_by): - query = query.order_by(getattr(HegemonyCountry, order_by)) + stmt = stmt.order_by(getattr(HegemonyCountry, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/hegemony_prefix_repository.py b/repositories/hegemony_prefix_repository.py index 589f8cd..5fe4b45 100644 --- a/repositories/hegemony_prefix_repository.py +++ b/repositories/hegemony_prefix_repository.py @@ -1,9 +1,9 @@ from datetime import datetime from sqlalchemy.orm import Session, contains_eager, aliased +from sqlalchemy import select, func from models.hegemony_prefix import HegemonyPrefix from typing import Optional, List, Tuple from utils import page_size -from sqlalchemy import func class HegemonyPrefixRepository: @@ -28,67 +28,64 @@ def get_all( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[HegemonyPrefix], int]: - ASN = aliased(HegemonyPrefix.asn_relation.property.mapper.class_) OriginASN = aliased(HegemonyPrefix.originasn_relation.property.mapper.class_) - - query = db.query(HegemonyPrefix)\ - .join(ASN, HegemonyPrefix.asn_relation)\ - .join(OriginASN, HegemonyPrefix.originasn_relation)\ - .options( - contains_eager(HegemonyPrefix.asn_relation, alias=ASN), - contains_eager(HegemonyPrefix.originasn_relation, alias=OriginASN) - ) + + stmt = ( + select(HegemonyPrefix) + .join(HegemonyPrefix.asn_relation.of_type(ASN)) + .join(HegemonyPrefix.originasn_relation.of_type(OriginASN)) + .options( + contains_eager(HegemonyPrefix.asn_relation.of_type(ASN)), + contains_eager(HegemonyPrefix.originasn_relation.of_type(OriginASN)) + ) + ) + # If no time filters specified, get rows with max timebin if not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(HegemonyPrefix.timebin)).scalar() - query = query.filter(HegemonyPrefix.timebin == max_timebin) + max_timebin = db.scalar(select(func.max(HegemonyPrefix.timebin))) + stmt = stmt.where(HegemonyPrefix.timebin == max_timebin) # Apply filters if timebin_gte: - query = query.filter(HegemonyPrefix.timebin >= timebin_gte) + stmt = stmt.where(HegemonyPrefix.timebin >= timebin_gte) if timebin_lte: - query = query.filter(HegemonyPrefix.timebin <= timebin_lte) + stmt = stmt.where(HegemonyPrefix.timebin <= timebin_lte) if prefixes: - query = query.filter(HegemonyPrefix.prefix.in_(prefixes)) + stmt = stmt.where(HegemonyPrefix.prefix.in_(prefixes)) if asn_ids: - query = query.filter(HegemonyPrefix.asn.in_(asn_ids)) + stmt = stmt.where(HegemonyPrefix.asn.in_(asn_ids)) if originasn_ids: - query = query.filter(HegemonyPrefix.originasn.in_(originasn_ids)) + stmt = stmt.where(HegemonyPrefix.originasn.in_(originasn_ids)) if countries: - query = query.filter(HegemonyPrefix.country.in_(countries)) + stmt = stmt.where(HegemonyPrefix.country.in_(countries)) if rpki_status: - query = query.filter( - HegemonyPrefix.rpki_status.contains(rpki_status)) + stmt = stmt.where(HegemonyPrefix.rpki_status.contains(rpki_status)) if irr_status: - query = query.filter( - HegemonyPrefix.irr_status.contains(irr_status)) + stmt = stmt.where(HegemonyPrefix.irr_status.contains(irr_status)) if delegated_prefix_status: - query = query.filter( - HegemonyPrefix.delegated_prefix_status.contains(delegated_prefix_status)) + stmt = stmt.where(HegemonyPrefix.delegated_prefix_status.contains(delegated_prefix_status)) if delegated_asn_status: - query = query.filter( - HegemonyPrefix.delegated_asn_status.contains(delegated_asn_status)) + stmt = stmt.where(HegemonyPrefix.delegated_asn_status.contains(delegated_asn_status)) if af is not None: - query = query.filter(HegemonyPrefix.af == af) + stmt = stmt.where(HegemonyPrefix.af == af) if hege is not None: - query = query.filter(HegemonyPrefix.hege == hege) + stmt = stmt.where(HegemonyPrefix.hege == hege) if hege_gte is not None: - query = query.filter(HegemonyPrefix.hege >= hege_gte) + stmt = stmt.where(HegemonyPrefix.hege >= hege_gte) if hege_lte is not None: - query = query.filter(HegemonyPrefix.hege <= hege_lte) + stmt = stmt.where(HegemonyPrefix.hege <= hege_lte) if origin_only: - query = query.filter( - HegemonyPrefix.originasn == HegemonyPrefix.asn) + stmt = stmt.where(HegemonyPrefix.originasn == HegemonyPrefix.asn) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(HegemonyPrefix, order_by): - query = query.order_by(getattr(HegemonyPrefix, order_by)) + stmt = stmt.order_by(getattr(HegemonyPrefix, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/hegemony_repository.py b/repositories/hegemony_repository.py index 7a13a24..065dbc3 100644 --- a/repositories/hegemony_repository.py +++ b/repositories/hegemony_repository.py @@ -1,9 +1,9 @@ from datetime import datetime from sqlalchemy.orm import Session, contains_eager, aliased +from sqlalchemy import select, func from models.hegemony import Hegemony from typing import Optional, List, Tuple from utils import page_size -from sqlalchemy import func class HegemonyRepository: @@ -21,48 +21,50 @@ def get_all( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[Hegemony], int]: - ASN = aliased(Hegemony.asn_relation.property.mapper.class_) OriginASN = aliased(Hegemony.originasn_relation.property.mapper.class_) - - query = db.query(Hegemony)\ - .join(ASN, Hegemony.asn_relation)\ - .join(OriginASN, Hegemony.originasn_relation)\ - .options( - contains_eager(Hegemony.asn_relation, alias=ASN), - contains_eager(Hegemony.originasn_relation, alias=OriginASN) - ) - + + stmt = ( + select(Hegemony) + .join(Hegemony.asn_relation.of_type(ASN)) + .join(Hegemony.originasn_relation.of_type(OriginASN)) + .options( + contains_eager(Hegemony.asn_relation.of_type(ASN)), + contains_eager(Hegemony.originasn_relation.of_type(OriginASN)) + ) + ) + # If no time filters specified, get rows with max timebin if not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(Hegemony.timebin)).scalar() - query = query.filter(Hegemony.timebin == max_timebin) + max_timebin = db.scalar(select(func.max(Hegemony.timebin))) + stmt = stmt.where(Hegemony.timebin == max_timebin) + # Apply filters if timebin_gte: - query = query.filter(Hegemony.timebin >= timebin_gte) + stmt = stmt.where(Hegemony.timebin >= timebin_gte) if timebin_lte: - query = query.filter(Hegemony.timebin <= timebin_lte) + stmt = stmt.where(Hegemony.timebin <= timebin_lte) if asn_ids: - query = query.filter(Hegemony.asn.in_(asn_ids)) + stmt = stmt.where(Hegemony.asn.in_(asn_ids)) if originasn_ids: - query = query.filter(Hegemony.originasn.in_(originasn_ids)) + stmt = stmt.where(Hegemony.originasn.in_(originasn_ids)) if af is not None: - query = query.filter(Hegemony.af == af) + stmt = stmt.where(Hegemony.af == af) if hege is not None: - query = query.filter(Hegemony.hege == hege) + stmt = stmt.where(Hegemony.hege == hege) if hege_gte: - query = query.filter(Hegemony.hege >= hege_gte) + stmt = stmt.where(Hegemony.hege >= hege_gte) if hege_lte: - query = query.filter(Hegemony.hege <= hege_lte) + stmt = stmt.where(Hegemony.hege <= hege_lte) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(Hegemony, order_by): - query = query.order_by(getattr(Hegemony, order_by)) + stmt = stmt.order_by(getattr(Hegemony, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/metis_atlas_deployment_repository.py b/repositories/metis_atlas_deployment_repository.py index 8e525f1..6518df4 100644 --- a/repositories/metis_atlas_deployment_repository.py +++ b/repositories/metis_atlas_deployment_repository.py @@ -1,9 +1,9 @@ from datetime import datetime from sqlalchemy.orm import Session, contains_eager +from sqlalchemy import select, func from models.metis_atlas_deployment import MetisAtlasDeployment from typing import Optional, List, Tuple from utils import page_size -from sqlalchemy import func class MetisAtlasDeploymentRepository: @@ -21,41 +21,43 @@ def get_all( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[MetisAtlasDeployment], int]: - query = db.query(MetisAtlasDeployment)\ - .join(MetisAtlasDeployment.asn_relation)\ + stmt = ( + select(MetisAtlasDeployment) + .join(MetisAtlasDeployment.asn_relation) .options(contains_eager(MetisAtlasDeployment.asn_relation)) + ) # If no time filters specified, get rows with max timebin if not timebin and not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(MetisAtlasDeployment.timebin)).scalar() - query = query.filter(MetisAtlasDeployment.timebin == max_timebin) - + max_timebin = db.scalar(select(func.max(MetisAtlasDeployment.timebin))) + stmt = stmt.where(MetisAtlasDeployment.timebin == max_timebin) + # Apply filters if timebin: - query = query.filter(MetisAtlasDeployment.timebin == timebin) + stmt = stmt.where(MetisAtlasDeployment.timebin == timebin) if timebin_gte: - query = query.filter(MetisAtlasDeployment.timebin >= timebin_gte) + stmt = stmt.where(MetisAtlasDeployment.timebin >= timebin_gte) if timebin_lte: - query = query.filter(MetisAtlasDeployment.timebin <= timebin_lte) + stmt = stmt.where(MetisAtlasDeployment.timebin <= timebin_lte) if rank: - query = query.filter(MetisAtlasDeployment.rank == rank) + stmt = stmt.where(MetisAtlasDeployment.rank == rank) if rank_lte: - query = query.filter(MetisAtlasDeployment.rank <= rank_lte) + stmt = stmt.where(MetisAtlasDeployment.rank <= rank_lte) if rank_gte: - query = query.filter(MetisAtlasDeployment.rank >= rank_gte) + stmt = stmt.where(MetisAtlasDeployment.rank >= rank_gte) if metric: - query = query.filter(MetisAtlasDeployment.metric == metric) + stmt = stmt.where(MetisAtlasDeployment.metric == metric) if af: - query = query.filter(MetisAtlasDeployment.af == af) + stmt = stmt.where(MetisAtlasDeployment.af == af) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(MetisAtlasDeployment, order_by): - query = query.order_by(getattr(MetisAtlasDeployment, order_by)) + stmt = stmt.order_by(getattr(MetisAtlasDeployment, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/metis_atlas_selection_repository.py b/repositories/metis_atlas_selection_repository.py index 348af87..68fece1 100644 --- a/repositories/metis_atlas_selection_repository.py +++ b/repositories/metis_atlas_selection_repository.py @@ -1,9 +1,9 @@ from datetime import datetime from sqlalchemy.orm import Session, contains_eager +from sqlalchemy import select, func from models.metis_atlas_selection import MetisAtlasSelection from typing import Optional, List, Tuple from utils import page_size -from sqlalchemy import func class MetisAtlasSelectionRepository: @@ -21,42 +21,43 @@ def get_all( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[MetisAtlasSelection], int]: - query = db.query(MetisAtlasSelection)\ - .join(MetisAtlasSelection.asn_relation)\ + stmt = ( + select(MetisAtlasSelection) + .join(MetisAtlasSelection.asn_relation) .options(contains_eager(MetisAtlasSelection.asn_relation)) + ) # If no time filters specified, get rows with max timebin if not timebin and not timebin_gte and not timebin_lte: - max_timebin = db.query( - func.max(MetisAtlasSelection.timebin)).scalar() - query = query.filter(MetisAtlasSelection.timebin == max_timebin) + max_timebin = db.scalar(select(func.max(MetisAtlasSelection.timebin))) + stmt = stmt.where(MetisAtlasSelection.timebin == max_timebin) # Apply filters if timebin: - query = query.filter(MetisAtlasSelection.timebin == timebin) + stmt = stmt.where(MetisAtlasSelection.timebin == timebin) if timebin_gte: - query = query.filter(MetisAtlasSelection.timebin >= timebin_gte) + stmt = stmt.where(MetisAtlasSelection.timebin >= timebin_gte) if timebin_lte: - query = query.filter(MetisAtlasSelection.timebin <= timebin_lte) + stmt = stmt.where(MetisAtlasSelection.timebin <= timebin_lte) if rank: - query = query.filter(MetisAtlasSelection.rank == rank) + stmt = stmt.where(MetisAtlasSelection.rank == rank) if rank_lte: - query = query.filter(MetisAtlasSelection.rank <= rank_lte) + stmt = stmt.where(MetisAtlasSelection.rank <= rank_lte) if rank_gte: - query = query.filter(MetisAtlasSelection.rank >= rank_gte) + stmt = stmt.where(MetisAtlasSelection.rank >= rank_gte) if metric: - query = query.filter(MetisAtlasSelection.metric == metric) + stmt = stmt.where(MetisAtlasSelection.metric == metric) if af: - query = query.filter(MetisAtlasSelection.af == af) + stmt = stmt.where(MetisAtlasSelection.af == af) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(MetisAtlasSelection, order_by): - query = query.order_by(getattr(MetisAtlasSelection, order_by)) + stmt = stmt.order_by(getattr(MetisAtlasSelection, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/repositories/networks_repository.py b/repositories/networks_repository.py index d7add3e..d90dd46 100644 --- a/repositories/networks_repository.py +++ b/repositories/networks_repository.py @@ -1,5 +1,5 @@ from sqlalchemy.orm import Session -from sqlalchemy import or_, String +from sqlalchemy import select, func, or_, String from models.asn import ASN from typing import Optional, List, Tuple from utils import page_size @@ -17,17 +17,17 @@ def get_all( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[ASN], int]: - query = db.query(ASN) + stmt = select(ASN) # Apply filters if name: - query = query.filter(ASN.name.ilike(f"%{name}%")) + stmt = stmt.where(ASN.name.ilike(f"%{name}%")) if numbers: - query = query.filter(ASN.number.in_(numbers)) + stmt = stmt.where(ASN.number.in_(numbers)) if number_gte: - query = query.filter(ASN.number >= number_gte) + stmt = stmt.where(ASN.number >= number_gte) if number_lte: - query = query.filter(ASN.number <= number_lte) + stmt = stmt.where(ASN.number <= number_lte) if search: # Handle AS/IX prefix in search search_value = search @@ -36,20 +36,19 @@ def get_all( search_value = str(int(search[2:])) except ValueError: pass - - query = query.filter(or_( + stmt = stmt.where(or_( ASN.number.cast(String).contains(search_value), ASN.name.ilike(f"%{search}%") )) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) # Apply ordering if order_by and hasattr(ASN, order_by): - query = query.order_by(getattr(ASN, order_by)) + stmt = stmt.order_by(getattr(ASN, order_by)) # Apply pagination offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).all() return results, total_count diff --git a/repositories/tr_hegemony_repository.py b/repositories/tr_hegemony_repository.py index 27a5462..e4560b0 100644 --- a/repositories/tr_hegemony_repository.py +++ b/repositories/tr_hegemony_repository.py @@ -1,10 +1,9 @@ from sqlalchemy.orm import Session, aliased, contains_eager -from sqlalchemy import and_, or_ +from sqlalchemy import select, func, and_, or_ from models.tr_hegemony import TRHegemony from datetime import datetime from typing import List, Optional, Tuple from utils import page_size -from sqlalchemy import func class TRHegemonyRepository: @@ -27,63 +26,62 @@ def get_tr_hegemony( page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[TRHegemony], int]: - Origin = aliased(TRHegemony.origin_relation.property.mapper.class_) - Dependency = aliased( - TRHegemony.dependency_relation.property.mapper.class_) + Dependency = aliased(TRHegemony.dependency_relation.property.mapper.class_) - query = db.query(TRHegemony)\ - .join(Origin, TRHegemony.origin_relation)\ - .join(Dependency, TRHegemony.dependency_relation)\ + stmt = ( + select(TRHegemony) + .join(TRHegemony.origin_relation.of_type(Origin)) + .join(TRHegemony.dependency_relation.of_type(Dependency)) .options( - contains_eager(TRHegemony.origin_relation, alias=Origin), - contains_eager(TRHegemony.dependency_relation, alias=Dependency) + contains_eager(TRHegemony.origin_relation.of_type(Origin)), + contains_eager(TRHegemony.dependency_relation.of_type(Dependency)) ) + ) # If no time filters specified, get rows with max timebin if not timebin and not timebin_gte and not timebin_lte: - max_timebin = db.query(func.max(TRHegemony.timebin)).scalar() - query = query.filter(TRHegemony.timebin == max_timebin) - + max_timebin = db.scalar(select(func.max(TRHegemony.timebin))) + stmt = stmt.where(TRHegemony.timebin == max_timebin) + if timebin: - query = query.filter(TRHegemony.timebin == timebin) + stmt = stmt.where(TRHegemony.timebin == timebin) if timebin_gte: - query = query.filter(TRHegemony.timebin >= timebin_gte) + stmt = stmt.where(TRHegemony.timebin >= timebin_gte) if timebin_lte: - query = query.filter(TRHegemony.timebin <= timebin_lte) + stmt = stmt.where(TRHegemony.timebin <= timebin_lte) if origin_names: names = origin_names.split('|') - query = query.filter(Origin.name.in_(names)) + stmt = stmt.where(Origin.name.in_(names)) if origin_type: - query = query.filter(Origin.type == origin_type) + stmt = stmt.where(Origin.type == origin_type) if origin_af: - query = query.filter(Origin.af == origin_af) + stmt = stmt.where(Origin.af == origin_af) if dependency_names: names = dependency_names.split('|') - query = query.filter(Dependency.name.in_(names)) + stmt = stmt.where(Dependency.name.in_(names)) if dependency_type: - query = query.filter(Dependency.type == dependency_type) + stmt = stmt.where(Dependency.type == dependency_type) if dependency_af: - query = query.filter(Dependency.af == dependency_af) + stmt = stmt.where(Dependency.af == dependency_af) if hege: - query = query.filter(TRHegemony.hege == hege) + stmt = stmt.where(TRHegemony.hege == hege) if hege_gte: - query = query.filter(TRHegemony.hege >= hege_gte) + stmt = stmt.where(TRHegemony.hege >= hege_gte) if hege_lte: - query = query.filter(TRHegemony.hege <= hege_lte) - + stmt = stmt.where(TRHegemony.hege <= hege_lte) if af: - query = query.filter(TRHegemony.af == af) + stmt = stmt.where(TRHegemony.af == af) - total_count = query.count() + total_count = db.scalar(select(func.count()).select_from(stmt.subquery())) if order_by and hasattr(TRHegemony, order_by): - query = query.order_by(getattr(TRHegemony, order_by)) + stmt = stmt.order_by(getattr(TRHegemony, order_by)) offset = (page - 1) * page_size - results = query.offset(offset).limit(page_size).all() + results = db.scalars(stmt.offset(offset).limit(page_size)).unique().all() return results, total_count diff --git a/requirements.txt b/requirements.txt index 20a8242..4d8b2a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -fastapi==0.115.11 -pydantic==2.10.6 -python-dotenv==1.0.1 -SQLAlchemy==2.0.38 -uvicorn==0.34.0 -psycopg2==2.9.10 -alembic==1.15.1 \ No newline at end of file +fastapi~=0.135.2 +pydantic~=2.12.5 +python-dotenv~=1.2.2 +SQLAlchemy~=2.0.48 +uvicorn~=0.42.0 +psycopg2~=2.9.11 +alembic~=1.18.4 \ No newline at end of file diff --git a/services/disco_service.py b/services/disco_service.py index 8a965dc..d994a3f 100644 --- a/services/disco_service.py +++ b/services/disco_service.py @@ -30,6 +30,7 @@ def get_disco_events( totalprobes_gte: Optional[int] = None, totalprobes_lte: Optional[int] = None, ongoing: Optional[str] = None, + include_probe_details: bool = False, page: int = 1, order_by: Optional[str] = None ) -> Tuple[List[DiscoEventsDTO], int]: @@ -54,8 +55,9 @@ def get_disco_events( totalprobes_gte=totalprobes_gte, totalprobes_lte=totalprobes_lte, ongoing=ongoing, + include_probe_details=include_probe_details, page=page, order_by=order_by ) - return [DiscoEventsDTO.from_model(event) for event in events_data], total_count + return [DiscoEventsDTO.from_model(event, include_probe_details=include_probe_details) for event in events_data], total_count