diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index c5553958..18e0d8d6 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -4373,3 +4373,5 @@ b [CaH2] [NH3] [OH2] +[TlH2+] +[SbH6+3] diff --git a/chebai/result/classification.py b/chebai/result/classification.py index 725be248..d3c04233 100644 --- a/chebai/result/classification.py +++ b/chebai/result/classification.py @@ -7,15 +7,18 @@ import torch from torch import Tensor from torchmetrics.classification import ( - MultilabelF1Score, - MultilabelPrecision, - MultilabelRecall, - MultilabelAUROC, - BinaryF1Score, BinaryAUROC, BinaryAveragePrecision, + BinaryF1Score, + BinaryRecall, + MultilabelAUROC, MultilabelAveragePrecision, + MultilabelF1Score, + MultilabelPrecision, + MultilabelRecall, + MultilabelSpecificity, ) +from torchmetrics.functional import specificity from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1 @@ -130,13 +133,39 @@ def metrics_classification_multilabel( f1_micro = MacroF1(preds.shape[1]).to(device=device) my_auc_roc = MultilabelAUROC(preds.shape[1]).to(device=device) my_av_prec = MultilabelAveragePrecision(preds.shape[1]).to(device=device) + my_macro_specificity = MultilabelSpecificity(preds.shape[1], average="macro").to( + device=device + ) + my_micro_specificity = MultilabelSpecificity(preds.shape[1], average="micro").to( + device=device + ) + my_macro_sensitivity = MultilabelRecall(preds.shape[1], average="macro").to( + device=device + ) + my_micro_sensitivity = MultilabelRecall(preds.shape[1], average="micro").to( + device=device + ) macro_f1 = my_f1_macro(preds, labels).cpu().numpy() micro_f1 = f1_micro(preds, labels).cpu().numpy() auc_roc = my_auc_roc(preds, labels).cpu().numpy() prc_auc = my_av_prec(preds, labels).cpu().numpy() - - return auc_roc, macro_f1, micro_f1, bal_acc, prc_auc + specificity_macro = my_macro_specificity(preds, labels).cpu().numpy() + specificity_micro = my_micro_specificity(preds, labels).cpu().numpy() + sensitivity_macro = my_macro_sensitivity(preds, labels).cpu().numpy() + sensitivity_micro = my_micro_sensitivity(preds, labels).cpu().numpy() + + return ( + auc_roc, + macro_f1, + micro_f1, + bal_acc, + prc_auc, + sensitivity_macro, + sensitivity_micro, + specificity_macro, + specificity_micro, + ) def metrics_classification_binary( @@ -151,6 +180,7 @@ def metrics_classification_binary( my_f1 = BinaryF1Score().to(device=device) my_av_prec = BinaryAveragePrecision().to(device=device) my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device) + my_sensitivity = BinaryRecall().to(device=device) bal_acc = my_bal_acc(preds, labels).cpu().numpy() auc_roc = my_auc_roc(preds, labels).cpu().numpy() @@ -158,5 +188,7 @@ def metrics_classification_binary( # auc_roc = my_auc_roc.compute().numpy() f1_score = my_f1(preds, labels).cpu().numpy() prc_auc = my_av_prec(preds, labels).cpu().numpy() + sensitivity = my_sensitivity(preds, labels).cpu().numpy() + specificity_result = specificity(preds, labels, task="binary").cpu().numpy() - return auc_roc, f1_score, bal_acc, prc_auc + return auc_roc, f1_score, bal_acc, prc_auc, sensitivity, specificity_result