Skip to content

Commit 544aed8

Browse files
committed
add dummy labels to explain smiles
1 parent 28ed332 commit 544aed8

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

chebifier/prediction_models/electra_predictor.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,23 @@ def __init__(self, model_name: str, ckpt_path: str, **kwargs):
4343
def explain_smiles(self, smiles) -> dict:
4444
from chebai.preprocessing.reader import EMBEDDING_OFFSET
4545

46+
# Add dummy labels because the collate function requires them.
47+
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
48+
# which later causes the `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
49+
# Note: With the new changes from https://github.com/ChEB-AI/python-chebai/pull/130, setting labels to `None` also
50+
# causes problems with `missing_labels` handling. Hence, dummy labels are used instead.
51+
dummy_labels: list = list(range(1, self.predictor._model.out_dim + 1))
52+
4653
token_dict = self.predictor._dm.reader.to_data(
47-
dict(features=smiles, labels=None)
54+
dict(features=smiles, labels=dummy_labels)
4855
)
4956
tokens = np.array(token_dict["features"]).astype(int).tolist()
5057
result = self.calculate_results([token_dict])
5158

5259
token_labels = (
5360
["[CLR]"]
5461
+ [None for _ in range(EMBEDDING_OFFSET - 1)]
55-
+ list(self._predictor._dm.reader.cache.keys())
62+
+ list(self.predictor._dm.reader.cache.keys())
5663
)
5764

5865
graphs = [

0 commit comments

Comments
 (0)