Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions unstructured/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from typing import List, Optional, Union

import click
import numpy as np
import pandas as pd
from numba import njit

from unstructured.staging.base import elements_from_json, elements_to_text

Expand Down Expand Up @@ -214,12 +216,15 @@ def _pstdev(scores: List[Optional[float]], rounding: Optional[int] = 3) -> Union
Args:
rounding (int): optional argument that allows user to define decimal points. Default at 3.
"""
scores = [score for score in scores if score is not None]
if len(scores) <= 1:
# Convert to float array, filtering None (preserve original behavior)
clean_scores = [score for score in scores if score is not None]
if len(clean_scores) <= 1:
return None
arr = np.array(clean_scores, dtype=np.float64)
std = _numba_pstdev(arr)
if not rounding:
return statistics.pstdev(scores)
return round(statistics.pstdev(scores), rounding)
return std
return round(std, rounding)


def _count(scores: List[Optional[float]]) -> float:
Expand All @@ -244,3 +249,21 @@ def _read_text_file(path):
except OSError as e:
# Handle other I/O related errors
raise IOError(f"An error occurred when reading the file at {path}: {e}")


@njit(cache=True, fastmath=True)
def _numba_pstdev(arr: np.ndarray) -> float:
    """Return the population standard deviation of a 1D float64 array.

    Two-pass computation: the first pass accumulates the arithmetic mean,
    the second accumulates the squared deviations from that mean. The sum of
    squared deviations is divided by the element count itself (population
    form), not by ``count - 1`` (sample form).

    Args:
        arr: values to summarize. The mean is obtained by dividing by
            ``arr.size``, so the array must be non-empty.

    Returns:
        The square root of the mean squared deviation.

    NOTE(review): ``fastmath=True`` relaxes strict IEEE 754 behaviour in the
    compiled code -- confirm that NaN/inf inputs are not expected here.
    """
    count = arr.size

    # Pass 1: arithmetic mean.
    total = 0.0
    for value in arr:
        total += value
    center = total / count

    # Pass 2: sum of squared deviations about the mean.
    squared_dev_sum = 0.0
    for value in arr:
        delta = value - center
        squared_dev_sum += delta * delta

    return (squared_dev_sum / count) ** 0.5