77 changes: 77 additions & 0 deletions examples/flower-dashboard/README.rst
@@ -0,0 +1,77 @@
Flower Dashboard
================

An interactive Streamlit dashboard that visualises metrics from a Flower federated learning experiment. The app embraces a plug-and-play workflow: drop in a metrics export and immediately inspect client participation, model convergence, and client-level anomalies.

Features
--------

* 📊 Line charts for loss and accuracy trends across rounds.
* 🙋 Client participation breakdown that highlights which clients finished, straggled, or dropped.
* 🧮 Round-by-round contribution analysis using example counts (or aggregation weights when available).
* 🚨 Automated alerts for stragglers, dropped clients, and anomalous client updates (z-score based).
* ⚙️ Configurable detection thresholds directly from the sidebar.

Getting started
---------------

1. Install dependencies::

      pip install -r requirements.txt

2. Launch the dashboard::

      streamlit run streamlit_app.py

3. Load metrics

   * Toggle the "Use bundled sample data" switch to explore the included ``sample_metrics.json`` in ``assets``.
   * Or upload your own metrics export (Flower simulation history, aggregator callback logs, or a custom JSON that mirrors the sample schema).

Expected JSON schema
--------------------

The dashboard expects a JSON document with a list of rounds. Every round entry may contain global metrics (loss, accuracy, server time) and a list of ``clients`` with per-client metrics.

.. code-block:: json

    {
      "rounds": [
        {
          "round": 1,
          "loss": 1.85,
          "accuracy": 0.45,
          "server_time": 4.2,
          "clients": [
            {
              "client_id": "client_1",
              "status": "completed",
              "loss": 1.95,
              "accuracy": 0.40,
              "examples": 128,
              "duration": 12.1
            }
          ]
        }
      ]
    }

Generating metrics from Flower code
-----------------------------------

There are several ways to produce a compatible JSON file:

* Use ``History.to_json()`` from a Flower simulation and save the result to disk.
* Collect metrics in a strategy callback (for example ``aggregate_evaluate`` or ``aggregate_fit``) and export them when training finishes; a minimal sketch follows this list.
* Record client-side metrics (examples, loss, duration) in the ``fit`` return payload. The dashboard will automatically switch to aggregation weights if ``examples`` are missing.
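
A minimal sketch of the strategy-callback approach, assuming Flower ≥ 1.0, where ``aggregate_fit`` receives ``(server_round, results, failures)`` and each ``FitRes`` carries ``num_examples`` plus a client-reported ``metrics`` dict. ``DashboardFedAvg`` and the metric keys are illustrative, not part of the Flower API:

.. code-block:: python

    import json

    from flwr.server.strategy import FedAvg


    class DashboardFedAvg(FedAvg):
        """FedAvg that records per-round client metrics in the dashboard schema."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.round_records = []

        def aggregate_fit(self, server_round, results, failures):
            # Each entry in `results` is a (ClientProxy, FitRes) pair.
            clients = [
                {
                    "client_id": proxy.cid,
                    "status": "completed",
                    "loss": fit_res.metrics.get("loss"),
                    "accuracy": fit_res.metrics.get("accuracy"),
                    "examples": fit_res.num_examples,
                    "duration": fit_res.metrics.get("duration"),
                }
                for proxy, fit_res in results
            ]
            self.round_records.append({"round": server_round, "clients": clients})
            return super().aggregate_fit(server_round, results, failures)

        def export(self, path="metrics.json"):
            # Write the accumulated rounds in the schema shown above.
            with open(path, "w", encoding="utf-8") as f:
                json.dump({"rounds": self.round_records}, f, indent=2)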

Extending
---------

* Replace the static JSON loader with a websocket, database, or message-bus consumer to power near-real-time monitoring; a polling sketch follows this list.
* Add authentication and deploy the dashboard to a shared monitoring cluster via Streamlit Community Cloud or Flower Intelligence.
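
As a sketch of the first item, the static loader could poll an HTTP endpoint instead of a file. The endpoint URL and ``fetch_round_data`` are hypothetical, the endpoint is assumed to serve the same JSON schema, and ``requests`` would need to be added to ``requirements.txt``:

.. code-block:: python

    import requests
    import streamlit as st


    @st.cache_data(ttl=5, show_spinner=False)  # re-fetch at most every 5 seconds
    def fetch_round_data(url: str = "http://localhost:8000/metrics") -> dict:
        """Poll a metrics endpoint that serves the dashboard's JSON schema."""
        response = requests.get(url, timeout=2)
        response.raise_for_status()
        return response.json()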

License
-------

This example follows the main project license. See :code:`../../LICENSE`.
59 changes: 59 additions & 0 deletions examples/flower-dashboard/assets/sample_metrics.json
@@ -0,0 +1,59 @@
{
  "rounds": [
    {
      "round": 1,
      "loss": 1.85,
      "accuracy": 0.45,
      "server_time": 4.2,
      "clients": [
        {"client_id": "client_1", "status": "completed", "loss": 1.95, "accuracy": 0.40, "examples": 128, "duration": 12.1},
        {"client_id": "client_2", "status": "completed", "loss": 1.80, "accuracy": 0.48, "examples": 128, "duration": 13.7},
        {"client_id": "client_3", "status": "dropped", "loss": null, "accuracy": null, "examples": 0, "duration": 25.1}
      ]
    },
    {
      "round": 2,
      "loss": 1.60,
      "accuracy": 0.55,
      "server_time": 3.8,
      "clients": [
        {"client_id": "client_1", "status": "completed", "loss": 1.50, "accuracy": 0.52, "examples": 128, "duration": 11.4},
        {"client_id": "client_2", "status": "completed", "loss": 1.40, "accuracy": 0.56, "examples": 128, "duration": 10.1},
        {"client_id": "client_3", "status": "straggler", "loss": 1.95, "accuracy": 0.38, "examples": 128, "duration": 31.2}
      ]
    },
    {
      "round": 3,
      "loss": 1.35,
      "accuracy": 0.62,
      "server_time": 3.5,
      "clients": [
        {"client_id": "client_1", "status": "completed", "loss": 1.25, "accuracy": 0.60, "examples": 128, "duration": 10.2},
        {"client_id": "client_2", "status": "completed", "loss": 1.30, "accuracy": 0.59, "examples": 128, "duration": 10.0},
        {"client_id": "client_3", "status": "completed", "loss": 1.70, "accuracy": 0.50, "examples": 128, "duration": 19.7}
      ]
    },
    {
      "round": 4,
      "loss": 1.10,
      "accuracy": 0.70,
      "server_time": 3.1,
      "clients": [
        {"client_id": "client_1", "status": "completed", "loss": 1.05, "accuracy": 0.68, "examples": 128, "duration": 9.8},
        {"client_id": "client_2", "status": "completed", "loss": 1.00, "accuracy": 0.71, "examples": 128, "duration": 9.2},
        {"client_id": "client_3", "status": "completed", "loss": 1.40, "accuracy": 0.58, "examples": 128, "duration": 18.3}
      ]
    },
    {
      "round": 5,
      "loss": 0.95,
      "accuracy": 0.76,
      "server_time": 2.8,
      "clients": [
        {"client_id": "client_1", "status": "completed", "loss": 0.90, "accuracy": 0.74, "examples": 128, "duration": 9.0},
        {"client_id": "client_2", "status": "completed", "loss": 0.88, "accuracy": 0.77, "examples": 128, "duration": 8.9},
        {"client_id": "client_3", "status": "completed", "loss": 1.20, "accuracy": 0.62, "examples": 128, "duration": 17.6}
      ]
    }
  ]
}
4 changes: 4 additions & 0 deletions examples/flower-dashboard/requirements.txt
@@ -0,0 +1,4 @@
streamlit>=1.32.0
pandas>=2.1.0
altair>=5.0.0
numpy>=1.24.0
242 changes: 242 additions & 0 deletions examples/flower-dashboard/streamlit_app.py
@@ -0,0 +1,242 @@
"""Flower Dashboard: monitor federated learning rounds in real time."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st


st.set_page_config(
    page_title="Flower Dashboard",
    page_icon="🌸",
    layout="wide",
    menu_items={
        "About": "Interactive dashboard for monitoring Flower federated learning runs.",
    },
)


@st.cache_data(show_spinner=False)
def load_round_data(path: Path) -> Dict:
    """Load round metrics from a JSON file."""
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def parse_rounds(raw: Dict) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return per-round and per-client DataFrames from raw payload."""
    rounds = raw.get("rounds", [])
    if not rounds:
        return pd.DataFrame(), pd.DataFrame()

    rounds_df = pd.DataFrame(rounds)
    rounds_df["round"] = rounds_df["round"].astype(int)
    rounds_df = rounds_df.sort_values("round")

    client_rows: List[Dict] = []
    for round_info in rounds:
        round_number = round_info.get("round")
        for client in round_info.get("clients", []):
            client_rows.append({**client, "round": round_number})

    clients_df = pd.DataFrame(client_rows)
    if not clients_df.empty:
        clients_df["round"] = clients_df["round"].astype(int)
        clients_df = clients_df.sort_values(["round", "client_id"])
    return rounds_df, clients_df


def compute_stragglers(clients_df: pd.DataFrame, threshold: float) -> pd.DataFrame:
    """Flag clients whose duration falls at or above the configured quantile cutoff."""
    if clients_df.empty or "duration" not in clients_df:
        return pd.DataFrame(columns=clients_df.columns if not clients_df.empty else [])

    durations = clients_df["duration"].replace({None: np.nan}).astype(float)
    cutoff = durations.quantile(threshold)
    mask = durations >= cutoff
    stragglers = clients_df.loc[mask].copy()
    stragglers["duration_cutoff"] = round(cutoff, 2)
    return stragglers


def compute_anomalies(clients_df: pd.DataFrame, metric: str, z_thresh: float) -> pd.DataFrame:
    """Detect anomalous client updates using a z-score threshold."""
    if clients_df.empty or metric not in clients_df:
        return pd.DataFrame(columns=clients_df.columns if not clients_df.empty else [])

    metric_series = clients_df[metric].replace({None: np.nan}).astype(float).dropna()
    if metric_series.empty or metric_series.std(ddof=0) == 0:
        # Nothing to score: no values, or zero variance makes z-scores undefined.
        return pd.DataFrame(columns=clients_df.columns)

    zscores = (metric_series - metric_series.mean()) / metric_series.std(ddof=0)
    indices = zscores.index[np.abs(zscores) >= z_thresh]
    anomalies = clients_df.loc[indices].copy()
    anomalies["z_score"] = zscores.loc[indices].round(2)
    anomalies["metric"] = metric
    return anomalies


def build_loss_accuracy_chart(rounds_df: pd.DataFrame) -> alt.Chart:
    melted = rounds_df.melt(id_vars="round", value_vars=["loss", "accuracy"], var_name="metric")
    return (
        alt.Chart(melted)
        .mark_line(point=True)
        .encode(
            x=alt.X("round:O", title="Round"),
            y=alt.Y("value:Q", title="Metric value"),
            color=alt.Color("metric:N", title="Metric"),
            tooltip=["round", "metric", alt.Tooltip("value:Q", format=".3f")],
        )
        .properties(height=300)
    )


def build_participation_chart(clients_df: pd.DataFrame) -> alt.Chart:
    counts = clients_df.groupby(["round", "status"], dropna=False)["client_id"].count().reset_index()
    counts.rename(columns={"client_id": "count"}, inplace=True)
    return (
        alt.Chart(counts)
        .mark_bar()
        .encode(
            x=alt.X("round:O", title="Round"),
            y=alt.Y("count:Q", title="Client count"),
            color=alt.Color("status:N", title="Status"),
            tooltip=["round", "status", "count"],
        )
        .properties(height=300)
    )


def build_contribution_chart(clients_df: pd.DataFrame, contribution_field: str) -> alt.Chart:
    contributions = (
        clients_df.groupby(["round", "client_id"], dropna=False)[contribution_field]
        .sum()
        .reset_index()
    )
    contributions.rename(columns={contribution_field: "contribution"}, inplace=True)
    return (
        alt.Chart(contributions)
        .mark_bar()
        .encode(
            x=alt.X("round:O", title="Round"),
            y=alt.Y("contribution:Q", title=contribution_field.capitalize()),
            color=alt.Color("client_id:N", title="Client"),
            tooltip=["round", "client_id", alt.Tooltip("contribution:Q", format=".2f")],
        )
        .properties(height=300)
    )


st.title("🌸 Flower Dashboard")
st.caption(
    "Plug-and-play analytics for Flower federated learning runs. "
    "Drop a metrics export or connect to a live pipeline to inspect participation, metrics, and anomalies."
)

with st.sidebar:
    st.header("Data source")
    default_path = Path(__file__).parent / "assets" / "sample_metrics.json"
    use_sample = st.toggle("Use bundled sample data", value=True)
    uploaded = None
    if not use_sample:
        uploaded = st.file_uploader("Upload metrics JSON", type=["json"])

    refresh_button = st.button("Refresh data")

    st.divider()
    st.header("Detection settings")
    duration_quantile = st.slider(
        "Straggler quantile threshold", min_value=0.5, max_value=0.99, value=0.9, step=0.01
    )
    anomaly_metric = st.selectbox(
        "Metric for anomaly detection", options=["loss", "accuracy", "duration"], index=0
    )
    z_score_threshold = st.slider("Anomaly z-score", min_value=1.0, max_value=3.5, value=2.5, step=0.1)

if refresh_button:
    load_round_data.clear()  # drop the cached file contents so the next read hits disk
    st.rerun()

if uploaded is not None:
    raw_data = json.load(uploaded)
elif use_sample and default_path.exists():
    raw_data = load_round_data(default_path)
else:
    st.error("No data available. Please upload a metrics JSON file.")
    st.stop()

rounds_df, clients_df = parse_rounds(raw_data)

if rounds_df.empty:
    st.warning("No round-level metrics found in the provided file.")
    st.stop()

st.subheader("Training overview")
metrics_container = st.container()
col1, col2, col3, col4 = metrics_container.columns(4)
col1.metric("Rounds", int(rounds_df["round"].max()))
col2.metric("Best accuracy", f"{rounds_df['accuracy'].max():.3f}")
col3.metric("Final loss", f"{rounds_df['loss'].iloc[-1]:.3f}")
if "server_time" in rounds_df:
    col4.metric("Avg server time", f"{rounds_df['server_time'].mean():.2f}s")
else:
    col4.metric("Data source", "static")

loss_accuracy_chart = build_loss_accuracy_chart(rounds_df)
st.altair_chart(loss_accuracy_chart, use_container_width=True)

if clients_df.empty:
    st.warning("Client-level metrics missing. Participation and anomaly sections skipped.")
    st.stop()

st.subheader("Client participation")
st.altair_chart(build_participation_chart(clients_df), use_container_width=True)

st.subheader("Round contributions")
contribution_field = "examples" if "examples" in clients_df else "weight"
st.altair_chart(
build_contribution_chart(clients_df, contribution_field), use_container_width=True
)

st.subheader("Alerts")
col_a, col_b, col_c = st.columns(3)

with col_a:
    st.markdown("#### Stragglers")
    stragglers = compute_stragglers(clients_df, duration_quantile)
    if stragglers.empty:
        st.success("No stragglers detected with the current threshold.")
    else:
        st.dataframe(
            stragglers[["round", "client_id", "duration", "status"]],
            use_container_width=True,
        )

with col_b:
    st.markdown("#### Dropped clients")
    dropped = clients_df[clients_df["status"].str.lower().isin(["dropped", "failed", "timeout"])]
    if dropped.empty:
        st.success("No dropped clients.")
    else:
        st.dataframe(dropped[["round", "client_id", "status", "duration"]], use_container_width=True)

with col_c:
    st.markdown("#### Anomalous updates")
    anomalies = compute_anomalies(clients_df, anomaly_metric, z_score_threshold)
    if anomalies.empty:
        st.success("No anomalies detected with current settings.")
    else:
        st.dataframe(
            anomalies[["round", "client_id", anomaly_metric, "z_score", "status"]],
            use_container_width=True,
        )

st.divider()

st.subheader("Raw data preview")
st.dataframe(clients_df, use_container_width=True, hide_index=True)