220 changes: 207 additions & 13 deletions datajunction-server/datajunction_server/api/preaggregations.py
@@ -31,6 +31,7 @@
from datajunction_server.database.column import Column
from datajunction_server.database.availabilitystate import AvailabilityState
from datajunction_server.database.dimensionlink import DimensionLink
from datajunction_server.database.measure import FrozenMeasure
from datajunction_server.database.preaggregation import (
PreAggregation,
VALID_PREAGG_STRATEGIES,
@@ -43,13 +44,15 @@
DJQueryServiceClientException,
)
from datajunction_server.internal.access.authentication.http import SecureAPIRouter
from datajunction_server.models.node_type import NodeNameVersion
from datajunction_server.models.node_type import NodeNameVersion, NodeType
from datajunction_server.models.dialect import Dialect
from datajunction_server.models.materialization import MaterializationStrategy
from datajunction_server.models.preaggregation import (
BackfillRequest,
BackfillInput,
BackfillResponse,
BulkDeactivateWorkflowsResponse,
DeactivatedWorkflowInfo,
GrainMode,
DEFAULT_SCHEDULE,
PlanPreAggregationsRequest,
@@ -59,12 +62,13 @@
PreAggMaterializationInput,
UpdatePreAggregationAvailabilityRequest,
WorkflowResponse,
WorkflowStatus,
WorkflowUrl,
)
from datajunction_server.construction.build_v3.preagg_matcher import (
get_temporal_partitions,
)
from datajunction_server.models.decompose import PreAggMeasure
from datajunction_server.models.decompose import MetricRef, PreAggMeasure
from datajunction_server.models.node_type import NodeType
from datajunction_server.models.query import ColumnMetadata, V3ColumnMetadata
from datajunction_server.service_clients import QueryServiceClient
@@ -131,15 +135,54 @@ async def _get_upstream_source_tables(
return []


def _preagg_to_info(preagg: PreAggregation) -> PreAggregationInfo:
async def _preagg_to_info(
preagg: PreAggregation,
session: AsyncSession,
) -> PreAggregationInfo:
"""Convert a PreAggregation ORM object to a PreAggregationInfo response model."""
# Look up related metrics from FrozenMeasure relationships for each measure
measures_with_metrics: list[PreAggMeasure] = []
all_related_metrics: set[str] = set()

# Fetch all frozen measures in a single query to avoid N+1
measure_names = [measure.name for measure in preagg.measures or []]
frozen_measures = await FrozenMeasure.get_by_names(session, measure_names)
frozen_measures_map = {fm.name: fm for fm in frozen_measures}

for measure in preagg.measures or []:
# Find which metrics use this measure
measure_metrics: list[MetricRef] = []
frozen = frozen_measures_map.get(measure.name)
if frozen:
for nr in frozen.used_by_node_revisions:
if nr.type == NodeType.METRIC: # pragma: no branch
measure_metrics.append(
MetricRef(name=nr.name, display_name=nr.display_name),
)
all_related_metrics.add(nr.name)

# Create new PreAggMeasure with used_by_metrics populated
measures_with_metrics.append(
PreAggMeasure(
name=measure.name,
expression=measure.expression,
aggregation=measure.aggregation,
merge=measure.merge,
rule=measure.rule,
expr_hash=measure.expr_hash,
used_by_metrics=sorted(measure_metrics, key=lambda m: m.name)
if measure_metrics
else None,
),
)

return PreAggregationInfo(
id=preagg.id,
node_revision_id=preagg.node_revision_id,
node_name=preagg.node_revision.name,
node_version=preagg.node_revision.version,
grain_columns=preagg.grain_columns,
measures=preagg.measures,
measures=measures_with_metrics,
columns=preagg.columns,
sql=preagg.sql,
grain_group_hash=preagg.grain_group_hash,
@@ -151,6 +194,7 @@ def _preagg_to_info(preagg: PreAggregation) -> PreAggregationInfo:
status=preagg.status,
materialized_table_ref=preagg.materialized_table_ref,
max_partition=preagg.max_partition,
related_metrics=sorted(all_related_metrics) if all_related_metrics else None,
created_at=preagg.created_at,
updated_at=preagg.updated_at,
)
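The batched lookup above is the heart of this helper: instead of one awaited query per measure, all names go through a single `IN (...)` select. A minimal sketch of the before/after (measure names are hypothetical; `get_by_name` assumed to take the analogous `(session, name)` arguments):

    # N+1 pattern avoided -- one round trip per measure:
    # frozen = [await FrozenMeasure.get_by_name(session, n) for n in names]

    # Batched -- a single query for all names:
    frozen = await FrozenMeasure.get_by_names(session, ["m_total_price", "m_order_count"])
    frozen_map = {fm.name: fm for fm in frozen}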
@@ -195,6 +239,10 @@ async def list_preaggregations(
default=None,
description="Filter by status: 'pending' or 'active'",
),
include_stale: bool = Query(
default=False,
description="Include pre-aggs from older node versions (stale)",
),
limit: int = Query(default=50, ge=1, le=100),
offset: int = Query(default=0, ge=0),
*,
@@ -246,8 +294,16 @@ async def list_preaggregations(
f"Version '{node_version}' not found for node '{node_name}'",
)
stmt = stmt.where(PreAggregation.node_revision_id == target_revision.id)
elif include_stale:
# Include all revisions for this node
all_revisions_stmt = select(NodeRevision.id).where(
NodeRevision.node_id == node.id,
)
stmt = stmt.where(
PreAggregation.node_revision_id.in_(all_revisions_stmt),
)
else:
# Use latest version
# Use latest version only (default)
stmt = stmt.where(PreAggregation.node_revision_id == node.current.id)
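    # Illustrative requests for the new flag (route assumed from this
    # router's prefix; node name is hypothetical):
    #   GET /preaggs?node_name=default.repair_orders                    -> latest version only
    #   GET /preaggs?node_name=default.repair_orders&include_stale=true -> all revisions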

# Filter by grain_group_hash (direct lookup)
@@ -369,8 +425,8 @@ async def list_preaggregations(
)
preaggs = [p for p in preaggs if p.status == status]

# Convert to response models
items = [_preagg_to_info(p) for p in preaggs]
# Convert to response models (with related metrics lookup)
items = [await _preagg_to_info(p, session) for p in preaggs]

return PreAggregationListResponse(
items=items,
@@ -409,7 +465,7 @@ async def get_preaggregation(
if not preagg:
raise DJDoesNotExistException(f"Pre-aggregation with ID {preagg_id} not found")

return _preagg_to_info(preagg)
return await _preagg_to_info(preagg, session)


# =============================================================================
@@ -601,7 +657,7 @@ async def plan_preaggregations(
loaded_preaggs = list(result.scalars().unique().all())

return PlanPreAggregationsResponse(
preaggs=[_preagg_to_info(p) for p in loaded_preaggs],
preaggs=[await _preagg_to_info(p, session) for p in loaded_preaggs],
)


@@ -803,7 +859,7 @@ async def materialize_preaggregation(
labeled_urls.append(WorkflowUrl(label="workflow", url=url))
preagg.workflow_urls = labeled_urls

preagg.workflow_status = "active"
preagg.workflow_status = WorkflowStatus.ACTIVE
# Also update schedule if it wasn't set (using default)
if not preagg.schedule:
preagg.schedule = schedule
@@ -817,7 +873,7 @@

# Return pre-agg info with workflow URLs
await session.refresh(preagg, ["node_revision", "availability"])
return _preagg_to_info(preagg)
return await _preagg_to_info(preagg, session)


class UpdatePreAggregationConfigRequest(BaseModel):
@@ -888,7 +944,7 @@ async def update_preaggregation_config(
preagg.lookback_window,
)

return _preagg_to_info(preagg)
return await _preagg_to_info(preagg, session)


# =============================================================================
@@ -980,6 +1036,144 @@ async def delete_preagg_workflow(
)


@router.delete(
"/preaggs/workflows",
response_model=BulkDeactivateWorkflowsResponse,
name="Bulk Deactivate Workflows",
)
async def bulk_deactivate_preagg_workflows(
node_name: str = Query(
description="Node name to deactivate workflows for (required)",
),
stale_only: bool = Query(
default=False,
description="If true, only deactivate workflows for stale pre-aggs "
"(pre-aggs built for non-current node versions)",
),
*,
session: AsyncSession = Depends(get_session),
request: Request,
query_service_client: QueryServiceClient = Depends(get_query_service_client),
) -> BulkDeactivateWorkflowsResponse:
"""
Bulk deactivate workflows for pre-aggregations of a node.

This is useful for cleaning up stale pre-aggregations after a node
has been updated. When stale_only=true, only deactivates workflows
for pre-aggs that were built for older node versions.

Staleness is determined by comparing the pre-agg's node_revision_id
to the node's current revision.
"""
# Get the node and its current revision
node = await Node.get_by_name(
session,
node_name,
options=[
load_only(Node.id),
joinedload(Node.current).load_only(NodeRevision.id),
],
)
if not node:
raise DJDoesNotExistException(f"Node '{node_name}' not found")

current_revision_id = node.current.id if node.current else None

# Build query for pre-aggs with active workflows
stmt = (
select(PreAggregation)
.options(joinedload(PreAggregation.node_revision))
.join(PreAggregation.node_revision)
.where(
NodeRevision.node_id == node.id,
PreAggregation.workflow_status == WorkflowStatus.ACTIVE,
)
)

# If stale_only, filter to non-current revisions
if stale_only and current_revision_id:
stmt = stmt.where(PreAggregation.node_revision_id != current_revision_id)

result = await session.execute(stmt)
preaggs = result.scalars().all()

if not preaggs:
return BulkDeactivateWorkflowsResponse(
deactivated_count=0,
deactivated=[],
skipped_count=0,
message="No active workflows found matching criteria",
)

deactivated = []
skipped_count = 0
request_headers = dict(request.headers)

for preagg in preaggs:
if not preagg.workflow_urls: # pragma: no cover
skipped_count += 1
continue

# Compute output_table for workflow identification
output_table = _compute_output_table(
preagg.node_revision.name,
preagg.grain_group_hash,
)

# Extract workflow name from URLs if available
workflow_name = None
if preagg.workflow_urls: # pragma: no branch
for wf_url in preagg.workflow_urls: # pragma: no branch
if (
hasattr(wf_url, "label") and wf_url.label == "scheduled"
): # pragma: no branch
# Extract workflow name from URL path
workflow_name = wf_url.url.split("/")[-1] if wf_url.url else None
break
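                    # e.g. a "scheduled" URL like .../workflows/dj__preagg__ab12cd
                    # (hypothetical) yields workflow_name "dj__preagg__ab12cd"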

try:
query_service_client.deactivate_preagg_workflow(
output_table,
request_headers=request_headers,
)

# Clear workflow state
preagg.strategy = None
preagg.schedule = None
preagg.lookback_window = None
preagg.workflow_urls = None
preagg.workflow_status = None

deactivated.append(
DeactivatedWorkflowInfo(
id=preagg.id,
workflow_name=workflow_name,
),
)

_logger.info(
"Bulk deactivate: deactivated workflow for preagg_id=%s",
preagg.id,
)
except Exception as e: # pragma: no cover
_logger.warning(
"Bulk deactivate: failed to deactivate workflow for preagg_id=%s: %s",
preagg.id,
str(e),
)
# Continue with other pre-aggs even if one fails

await session.commit()

return BulkDeactivateWorkflowsResponse(
deactivated_count=len(deactivated),
deactivated=deactivated,
skipped_count=skipped_count,
message=f"Deactivated {len(deactivated)} workflow(s) for node '{node_name}'"
+ (" (stale only)" if stale_only else ""),
)
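For reference, a hypothetical client call against this endpoint (base URL and node name are illustrative):

    import httpx

    resp = httpx.delete(
        "http://localhost:8000/preaggs/workflows",
        params={"node_name": "default.repair_orders", "stale_only": "true"},
    )
    body = resp.json()
    print(body["deactivated_count"], body["message"])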


# =============================================================================
# Backfill & Run Endpoints
# =============================================================================
@@ -1206,4 +1400,4 @@ async def update_preaggregation_availability(
await session.commit()
await session.refresh(preagg, ["node_revision", "availability"])

return _preagg_to_info(preagg)
return await _preagg_to_info(preagg, session)
21 changes: 21 additions & 0 deletions datajunction-server/datajunction_server/database/measure.py
@@ -151,6 +151,27 @@ async def get_by_name(
result = await session.execute(statement)
return result.unique().scalar_one_or_none()

@classmethod
async def get_by_names(
cls,
session: AsyncSession,
names: list[str],
) -> list["FrozenMeasure"]:
"""
Get multiple frozen measures by name in a single query.
"""
if not names:
return [] # pragma: no cover
statement = (
select(FrozenMeasure)
.where(FrozenMeasure.name.in_(names))
.options(
selectinload(FrozenMeasure.used_by_node_revisions),
)
)
result = await session.execute(statement)
return list(result.unique().scalars().all())
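    # Note: selectinload emits one extra SELECT ... WHERE ... IN query for
    # used_by_node_revisions rather than joining it into the parent query,
    # so callers like _preagg_to_info can iterate the relationship without
    # per-row lazy-load round trips.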

@classmethod
async def find_by(
cls,
6 changes: 6 additions & 0 deletions datajunction-server/datajunction_server/internal/nodes.py
@@ -2786,6 +2786,12 @@ async def revalidate_node(
await session.commit()
await session.refresh(node.current) # type: ignore
await session.refresh(node, ["current"])

# For metric nodes, derive frozen measures (ensures they exist even for
# metrics created via deployment or updated after initial creation)
if current_node_revision.type == NodeType.METRIC and background_tasks:
background_tasks.add_task(derive_frozen_measures, node.current.id) # type: ignore

return node_validator
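For context, `add_task` defers `derive_frozen_measures` until after the response is sent, keeping the derivation off the revalidation request path. A minimal sketch of the FastAPI pattern (illustrative names and route, not this repo's actual endpoint):

    from fastapi import BackgroundTasks, FastAPI

    app = FastAPI()

    async def derive(node_revision_id: int) -> None:
        ...  # runs after the response has been returned to the client

    @app.post("/nodes/{node_revision_id}/revalidate")
    async def revalidate(node_revision_id: int, background_tasks: BackgroundTasks):
        background_tasks.add_task(derive, node_revision_id)
        return {"status": "ok"}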

