From 78765208c8d605efad91641b4af409bbc4871c07 Mon Sep 17 00:00:00 2001
From: schnamo
Date: Fri, 23 Jan 2026 15:39:27 +0100
Subject: [PATCH 1/3] add sensitivity and specificity to evaluation function

---
 chebai/result/classification.py | 48 +++++++++++++++++++++++++++------
 1 file changed, 40 insertions(+), 8 deletions(-)

diff --git a/chebai/result/classification.py b/chebai/result/classification.py
index ab8b1e2d..fc779625 100644
--- a/chebai/result/classification.py
+++ b/chebai/result/classification.py
@@ -7,15 +7,18 @@
 import torch
 from torch import Tensor
 from torchmetrics.classification import (
-    MultilabelF1Score,
-    MultilabelPrecision,
-    MultilabelRecall,
-    MultilabelAUROC,
-    BinaryF1Score,
     BinaryAUROC,
     BinaryAveragePrecision,
+    BinaryF1Score,
+    BinaryRecall,
+    MultilabelAUROC,
     MultilabelAveragePrecision,
+    MultilabelF1Score,
+    MultilabelPrecision,
+    MultilabelRecall,
+    MultilabelSpecificity,
 )
+from torchmetrics.functional import specificity
 
 from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1
 
@@ -131,13 +134,39 @@ def metrics_classification_multilabel(
     f1_micro = MacroF1(preds.shape[1]).to(device=device)
     my_auc_roc = MultilabelAUROC(preds.shape[1]).to(device=device)
     my_av_prec = MultilabelAveragePrecision(preds.shape[1]).to(device=device)
+    my_macro_specificity = MultilabelSpecificity(preds.shape[1], average="macro").to(
+        device=device
+    )
+    my_micro_specificity = MultilabelSpecificity(preds.shape[1], average="micro").to(
+        device=device
+    )
+    my_macro_sensitivity = MultilabelRecall(preds.shape[1], average="macro").to(
+        device=device
+    )
+    my_micro_sensitivity = MultilabelRecall(preds.shape[1], average="micro").to(
+        device=device
+    )
 
     macro_f1 = my_f1_macro(preds, labels).cpu().numpy()
     micro_f1 = f1_micro(preds, labels).cpu().numpy()
     auc_roc = my_auc_roc(preds, labels).cpu().numpy()
     prc_auc = my_av_prec(preds, labels).cpu().numpy()
-
-    return auc_roc, macro_f1, micro_f1, bal_acc, prc_auc
+    specificity_macro = my_macro_specificity(preds, labels).cpu().numpy()
+    specificity_micro = my_micro_specificity(preds, labels).cpu().numpy()
+    sensitivity_macro = my_macro_sensitivity(preds, labels).cpu().numpy()
+    sensitivity_micro = my_micro_sensitivity(preds, labels).cpu().numpy()
+
+    return (
+        auc_roc,
+        macro_f1,
+        micro_f1,
+        bal_acc,
+        prc_auc,
+        sensitivity_macro,
+        sensitivity_micro,
+        specificity_macro,
+        specificity_micro,
+    )
 
 
 def metrics_classification_binary(
@@ -153,6 +182,7 @@
     my_f1 = BinaryF1Score().to(device=device)
     my_av_prec = BinaryAveragePrecision().to(device=device)
     my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device)
+    my_sensitivity = BinaryRecall().to(device=device)
 
     bal_acc = my_bal_acc(preds, labels).cpu().numpy()
     auc_roc = my_auc_roc(preds, labels).cpu().numpy()
@@ -160,5 +190,7 @@
     # auc_roc = my_auc_roc.compute().numpy()
     f1_score = my_f1(preds, labels).cpu().numpy()
     prc_auc = my_av_prec(preds, labels).cpu().numpy()
+    sensitivity = my_sensitivity(preds, labels).cpu().numpy()
+    specificity_result = specificity(preds, labels, task="binary").cpu().numpy()
 
-    return auc_roc, f1_score, bal_acc, prc_auc
+    return auc_roc, f1_score, bal_acc, prc_auc, sensitivity, specificity_result
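
For context on PATCH 1/3: sensitivity is recall (the true-positive rate) and specificity is the true-negative rate, which is why MultilabelRecall doubles as the sensitivity metric above. Below is a minimal sketch of the torchmetrics calls the patch relies on, assuming random preds/labels tensors and an illustrative n_labels; only MultilabelRecall and MultilabelSpecificity come from the diff.

    import torch
    from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity

    n_samples, n_labels = 8, 4
    preds = torch.rand(n_samples, n_labels)              # per-label probabilities
    labels = torch.randint(0, 2, (n_samples, n_labels))  # binary ground truth

    # "macro" averages the metric per label; "micro" pools TP/TN/FP/FN counts
    # across labels first, mirroring the four metric objects in the patch.
    sensitivity_macro = MultilabelRecall(n_labels, average="macro")(preds, labels)
    specificity_micro = MultilabelSpecificity(n_labels, average="micro")(preds, labels)
    print(float(sensitivity_macro), float(specificity_micro))
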
From 8f760c9d5e23735946e4c497f6a10f452fc324d5 Mon Sep 17 00:00:00 2001
From: schnamo
Date: Thu, 12 Feb 2026 23:30:10 +0100
Subject: [PATCH 2/3] fixes

---
 chebai/loss/bce_weighted.py                      | 7 +++++++
 chebai/preprocessing/bin/smiles_token/tokens.txt | 2 ++
 configs/loss/bce_unweighted.yml                  | 2 +-
 3 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py
index 993d535e..8e03da43 100644
--- a/chebai/loss/bce_weighted.py
+++ b/chebai/loss/bce_weighted.py
@@ -7,6 +7,13 @@
 from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
 
 
+class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
+
+    def forward(self, input, target, **kwargs):
+        # Custom kwargs passed by the caller are not used by BCEWithLogitsLoss, so we can ignore them
+        return super().forward(input, target)
+
+
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt
index c5553958..18e0d8d6 100644
--- a/chebai/preprocessing/bin/smiles_token/tokens.txt
+++ b/chebai/preprocessing/bin/smiles_token/tokens.txt
@@ -4373,3 +4373,5 @@ b
 [CaH2]
 [NH3]
 [OH2]
+[TlH2+]
+[SbH6+3]
diff --git a/configs/loss/bce_unweighted.yml b/configs/loss/bce_unweighted.yml
index d53533c0..ed0a00b6 100644
--- a/configs/loss/bce_unweighted.yml
+++ b/configs/loss/bce_unweighted.yml
@@ -1 +1 @@
-class_path: torch.nn.BCEWithLogitsLoss
+class_path: chebai.loss.bce_weighted.UnWeightedBCEWithLogitsLoss

From 353ff3ce2b0b5e72051ad3ae505d4efe9339813a Mon Sep 17 00:00:00 2001
From: schnamo
Date: Fri, 13 Feb 2026 09:44:54 +0100
Subject: [PATCH 3/3] remove duplicate class

---
 chebai/loss/bce_weighted.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py
index 86c935b6..1f21b04b 100644
--- a/chebai/loss/bce_weighted.py
+++ b/chebai/loss/bce_weighted.py
@@ -7,13 +7,6 @@
 from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
 
 
-class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
-
-    def forward(self, input, target, **kwargs):
-        # Custom kwargs passed by the caller are not used by BCEWithLogitsLoss, so we can ignore them
-        return super().forward(input, target)
-
-
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
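
For context on PATCH 2/3: plain torch.nn.BCEWithLogitsLoss.forward accepts only input and target, so any extra keyword arguments a training loop forwards to its loss would raise a TypeError; the subclass swallows them, which is presumably why the bce_unweighted.yml config now points at it. A minimal sketch under that assumption (extra_kwarg is a hypothetical name, not from the patch):

    import torch

    class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):

        def forward(self, input, target, **kwargs):
            # Custom kwargs passed by the caller are not used by BCEWithLogitsLoss, so we can ignore them
            return super().forward(input, target)

    logits = torch.randn(2, 3)
    target = torch.randint(0, 2, (2, 3)).float()  # BCE targets must be float
    loss_fn = UnWeightedBCEWithLogitsLoss()

    # torch.nn.BCEWithLogitsLoss would raise a TypeError on the extra kwarg;
    # the subclass silently drops it.
    print(loss_fn(logits, target, extra_kwarg="ignored"))
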