diff --git a/unstructured/metrics/utils.py b/unstructured/metrics/utils.py index c490aa752b..2db44c64fd 100644 --- a/unstructured/metrics/utils.py +++ b/unstructured/metrics/utils.py @@ -6,7 +6,9 @@ from typing import List, Optional, Union import click +import numpy as np import pandas as pd +from numba import njit from unstructured.staging.base import elements_from_json, elements_to_text @@ -214,12 +216,15 @@ def _pstdev(scores: List[Optional[float]], rounding: Optional[int] = 3) -> Union Args: rounding (int): optional argument that allows user to define decimal points. Default at 3. """ - scores = [score for score in scores if score is not None] - if len(scores) <= 1: + # Convert to float array, filtering None (preserve original behavior) + clean_scores = [score for score in scores if score is not None] + if len(clean_scores) <= 1: return None + arr = np.array(clean_scores, dtype=np.float64) + std = _numba_pstdev(arr) if not rounding: - return statistics.pstdev(scores) - return round(statistics.pstdev(scores), rounding) + return std + return round(std, rounding) def _count(scores: List[Optional[float]]) -> float: @@ -244,3 +249,21 @@ def _read_text_file(path): except OSError as e: # Handle other I/O related errors raise IOError(f"An error occurred when reading the file at {path}: {e}") + + +@njit(cache=True, fastmath=True) +def _numba_pstdev(arr: np.ndarray) -> float: + """ + Numba-accelerated population standard deviation (stddev, not sample) for a 1D float64 numpy array, + matching the semantics of statistics.pstdev. + """ + n = arr.size + mean = 0.0 + for i in range(n): + mean += arr[i] + mean /= n + ssd = 0.0 + for i in range(n): + diff = arr[i] - mean + ssd += diff * diff + return (ssd / n) ** 0.5