Skip to content

Commit 20f195d

Browse files
committed
Add IF-like behaviour
1 parent 23d47ef commit 20f195d

File tree

1 file changed

+37
-2
lines changed

1 file changed

+37
-2
lines changed

distclassipy/anomaly.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import pandas as pd
32
from sklearn.base import BaseEstimator, OutlierMixin
43
from sklearn.utils.validation import check_is_fitted, check_array
54
from sklearn.preprocessing import minmax_scale
@@ -24,10 +23,11 @@ class DistanceAnomaly(OutlierMixin, BaseEstimator):
2423
A list of distance metrics to use for the ensemble. If None, uses a
2524
predefined list of 16 stable metrics from the package.
2625
27-
cluster_agg : {'min', 'median'}, default='min'
26+
cluster_agg : {'min', 'mean', 'median'}, default='min'
2827
The aggregation method for distances to different class centroids for a
2928
single metric.
3029
- 'min': An object's distance is its distance to the *nearest* known class.
30+
- 'mean': An object's distance is its mean distance to all known classes.
3131
- 'median': A more robust measure of an object's typical distance to all classes.
3232
3333
metric_agg : {'median', 'mean', 'min', 'percentile_25'}, default='median'
@@ -108,6 +108,10 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> "DistanceAnomaly":
108108
else:
109109
self.metrics_ = self.metrics
110110

111+
# Calculate anomaly threshold based on train scores
112+
train_scores = self.decision_function(X)
113+
self.offset_ = np.quantile(train_scores, 1.0 - self.contamination)
114+
111115
return self
112116

113117
def decision_function(self, X: np.ndarray) -> np.ndarray:
@@ -140,6 +144,8 @@ def decision_function(self, X: np.ndarray) -> np.ndarray:
140144
score_for_metric = dist_df.min(axis=1).values
141145
elif self.cluster_agg == "median":
142146
score_for_metric = dist_df.median(axis=1).values
147+
elif self.cluster_agg == "mean":
148+
score_for_metric = dist_df.mean(axis=1).values
143149
else:
144150
raise ValueError(f"Unknown cluster_agg method: {self.cluster_agg}")
145151

@@ -171,6 +177,35 @@ def decision_function(self, X: np.ndarray) -> np.ndarray:
171177

172178
return scores
173179

180+
def score_samples(self, X: np.ndarray) -> np.ndarray:
    """
    Return the negated anomaly score for each sample.

    ``decision_function`` in this class yields larger values for more
    anomalous samples. This method flips the sign so that larger values
    mean "more normal", which is the orientation used by scikit-learn
    outlier detectors such as IsolationForest.

    Parameters
    ----------
    X : np.ndarray
        The input samples, passed unchanged to ``decision_function``.

    Returns
    -------
    np.ndarray
        The element-wise negation of ``decision_function(X)``.
    """
    anomaly_scores = self.decision_function(X)
    return -anomaly_scores
188+
189+
def predict(self, X: np.ndarray) -> np.ndarray:
    """
    Predict whether each sample is an inlier (1) or an outlier (-1).

    A sample is labelled an outlier when its ``decision_function`` score
    is greater than or equal to the fitted threshold ``offset_``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The input samples.

    Returns
    -------
    is_outlier : ndarray of shape (n_samples,)
        Returns -1 for outliers and 1 for inliers.

    Raises
    ------
    sklearn.exceptions.NotFittedError
        If the estimator has not been fitted (``offset_`` is missing).
    """
    # predict() reads offset_, so require that attribute explicitly.
    check_is_fitted(self, "offset_")
    scores = np.asarray(self.decision_function(X))
    # Derive the labels from the scores rather than from X.shape so that
    # array-like inputs without a .shape attribute (e.g. lists) also work.
    return np.where(scores >= self.offset_, -1, 1)
208+
174209
# def predict(self, X: np.ndarray) -> np.ndarray:
175210
# NOTE: UNCOMMENT AFTER FIXING ABOVE offset_ DATA LEAKAGE CONCERN
176211
# """

0 commit comments

Comments
 (0)