diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index 0c7a85c4..a5a28024 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -21,8 +21,8 @@ def __init__( client: Client, verbose: bool, n_examples_shown: int, - log_prob: bool, seed: int = 42, + log_prob: bool = False, **generation_kwargs, ): """ @@ -143,7 +143,8 @@ def _parse( match = re.search(pattern, string) if match is None: raise ValueError("No match found in string") - predictions: list[bool | Literal[0, 1]] = json.loads(match.group(0)) + raw_predictions: list[bool | Literal[0, 1]] = json.loads(match.group(0)) + predictions = [bool(prediction) for prediction in raw_predictions] assert len(predictions) == self.n_examples_shown probabilities = ( self._parse_logprobs(logprobs) diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index 9dcfe743..b464f235 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -39,6 +39,7 @@ def __init__( temperature=temperature, **generation_kwargs, ) + self.log_prob = log_prob def prompt(self, examples: str, explanation: str) -> list[dict]: return detection_prompt(examples, explanation) diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py index 4d43d071..5007b706 100644 --- a/delphi/scorers/classifier/fuzz.py +++ b/delphi/scorers/classifier/fuzz.py @@ -51,7 +51,7 @@ def __init__( temperature=temperature, **generation_kwargs, ) - + self.log_prob = log_prob self.threshold = threshold self.fuzz_type = fuzz_type diff --git a/delphi/scorers/classifier/intruder.py b/delphi/scorers/classifier/intruder.py index 045a1192..81ce5e90 100644 --- a/delphi/scorers/classifier/intruder.py +++ b/delphi/scorers/classifier/intruder.py @@ -170,13 +170,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: active_examples = self.rng.sample(all_active_examples, num_active_examples) # highlights the active tokens with <<>> markers - majority_examples = [] + formatted_examples = [] + chosen_examples = [] num_active_tokens = 0 for example in active_examples: text, _str_tokens = _prepare_text( example, n_incorrect=0, threshold=0.3, highlighted=True ) - majority_examples.append(text) + formatted_examples.append(text) + chosen_examples.append(example) num_active_tokens += (example.activations > 0).sum().item() avg_active_tokens_per_example = num_active_tokens // len(active_examples) @@ -193,6 +195,7 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: threshold=0.3, highlighted=True, ) + elif self.type == "internal": # randomly select a quantile to be the intruder, make sure it's not # the same as the source quantile @@ -224,10 +227,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: # select a random index to insert the intruder sentence intruder_index = self.rng.randint(0, num_active_examples) - examples = ( - majority_examples[:intruder_index] + formatted_examples = ( + formatted_examples[:intruder_index] + [intruder_sentence] - + majority_examples[intruder_index:] + + formatted_examples[intruder_index:] + ) + examples = ( + chosen_examples[:intruder_index] + + [intruder] + + chosen_examples[intruder_index:] ) example_activations = [example.activations.tolist() for example in examples] @@ -235,7 +243,7 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]: batches.append( IntruderSentence( - examples=examples, + examples=formatted_examples, intruder_index=intruder_index, chosen_quantile=active_quantile, activations=example_activations, diff --git a/delphi/scorers/classifier/sample.py b/delphi/scorers/classifier/sample.py index d3fb014a..dfe79759 100644 --- a/delphi/scorers/classifier/sample.py +++ b/delphi/scorers/classifier/sample.py @@ -112,7 +112,7 @@ def _prepare_text( if n_incorrect == 0: def is_above_activation_threshold(i: int) -> bool: - return example.activations[i] >= abs_threshold + return bool((example.activations[i] >= abs_threshold).item()) return _highlight(str_toks, is_above_activation_threshold), str_toks @@ -137,6 +137,7 @@ def is_above_activation_threshold(i: int) -> bool: # The activating token is always ctx_len - ctx_len//4 # so we always highlight this one, and if num_tokens_to_highlight > 1 # we highlight num_tokens_to_highlight - 1 random ones + # TODO: This is wrong token_pos = len(str_toks) - len(str_toks) // 4 if token_pos in tokens_below_threshold: random_indices = [token_pos] diff --git a/delphi/scorers/embedding/example_embedding.py b/delphi/scorers/embedding/example_embedding.py index 38b6f17b..828478d8 100644 --- a/delphi/scorers/embedding/example_embedding.py +++ b/delphi/scorers/embedding/example_embedding.py @@ -118,10 +118,12 @@ def compute_batch_deltas(self, batch: Batch) -> tuple[float, float]: # Split the embeddings back into their components n_neg = len(batch.negative_examples) n_pos = len(batch.positive_examples) - negative_examples_embeddings = all_embeddings[:n_neg] - positive_examples_embeddings = all_embeddings[n_neg : n_neg + n_pos] - positive_query_embedding = all_embeddings[-2].unsqueeze(0) - negative_query_embedding = all_embeddings[-1].unsqueeze(0) + negative_examples_embeddings = torch.tensor(all_embeddings[:n_neg]) + positive_examples_embeddings = torch.tensor( + all_embeddings[n_neg : n_neg + n_pos] + ) + positive_query_embedding = torch.tensor(all_embeddings[-2]).unsqueeze(0) + negative_query_embedding = torch.tensor(all_embeddings[-1]).unsqueeze(0) # Compute the similarity between the query and the examples negative_similarities = self.model.similarity( @@ -165,9 +167,11 @@ def _create_batches( # which are going to be used as "explanations" positive_train_examples = record.train + number_samples = min(len(positive_train_examples), len(record.not_active)) + # Sample from the not_active examples not_active_index = self.random.sample( - range(len(record.not_active)), len(positive_train_examples) + range(len(record.not_active)), number_samples ) negative_train_examples = [record.not_active[i] for i in not_active_index] @@ -192,6 +196,7 @@ def _create_batches( positive_query_str, _ = _prepare_text( positive_query, n_incorrect=0, threshold=0.3, highlighted=True ) + # Prepare the negative query if self.method == "default": # In the default method, we just sample a random negative example @@ -205,6 +210,7 @@ def _create_batches( threshold=0.3, highlighted=True, ) + elif self.method == "internal": # In the internal method, we sample a negative example # that has a different quantile as the positive query @@ -216,13 +222,13 @@ def _create_batches( range(len(positive_test_examples)), 1 )[0] negative_query_temp = positive_test_examples[negative_query_idx] - negative_query_quantile = negative_query.distance + negative_query_quantile = negative_query_temp.quantile negative_query = NonActivatingExample( str_tokens=negative_query_temp.str_tokens, tokens=negative_query_temp.tokens, activations=negative_query_temp.activations, - distance=negative_query_temp.quantile, + distance=float(negative_query_temp.quantile), ) # Because it is a converted activating example, it will highlight # the activating tokens @@ -234,15 +240,18 @@ def _create_batches( # that have the same quantile as the positive_query positive_examples = [ e - for e in positive_train_examples + for e in positive_test_examples if e.quantile == positive_query.quantile ] if len(positive_examples) > 10: - positive_examples = self.random.sample(positive_examples, 10) + positive_examples = self.random.sample(positive_examples, 11) positive_examples_str = [ _prepare_text(e, n_incorrect=0, threshold=0.3, highlighted=True)[0] for e in positive_examples ] + # if one example is the same as the positive query, remove it + if positive_query_str in positive_examples_str: + positive_examples_str.remove(positive_query_str) # negative examples if self.method == "default": @@ -259,7 +268,7 @@ def _create_batches( # that has the same quantile as the negative_query negative_examples = [ e - for e in positive_train_examples + for e in positive_test_examples if e.quantile == negative_query.distance ] if len(negative_examples) > 10: @@ -275,7 +284,7 @@ def _create_batches( positive_query=positive_query_str, negative_query=negative_query_str, quantile_positive_query=positive_query.quantile, - distance_negative_query=negative_query.distance, + distance_negative_query=float(negative_query.distance), ) batches.append(batch) return batches diff --git a/delphi/scorers/scorer.py b/delphi/scorers/scorer.py index fa5a0ae5..9cb43061 100644 --- a/delphi/scorers/scorer.py +++ b/delphi/scorers/scorer.py @@ -14,5 +14,5 @@ class ScorerResult(NamedTuple): class Scorer(ABC): @abstractmethod - def __call__(self, record: LatentRecord) -> ScorerResult: + async def __call__(self, record: LatentRecord) -> ScorerResult: pass diff --git a/tests/test_scorers/test_classifier_contracts.py b/tests/test_scorers/test_classifier_contracts.py new file mode 100644 index 00000000..3382c621 --- /dev/null +++ b/tests/test_scorers/test_classifier_contracts.py @@ -0,0 +1,86 @@ +import pytest +import torch + +from delphi.clients.client import Client, Response +from delphi.latents import ActivatingExample, Latent, LatentRecord, NonActivatingExample +from delphi.scorers import DetectionScorer, FuzzingScorer +from delphi.scorers.scorer import ScorerResult + + +class ConstantResponseClient(Client): + def __init__(self, text: str): + super().__init__(model="dummy") + self.text = text + + async def generate(self, prompt, **kwargs): + return Response(text=self.text) + + +def _activating_example() -> ActivatingExample: + return ActivatingExample( + tokens=torch.tensor([1, 2, 3], dtype=torch.int64), + activations=torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32), + str_tokens=["a", "b", "c"], + quantile=1, + ) + + +def _non_activating_example() -> NonActivatingExample: + return NonActivatingExample( + tokens=torch.tensor([1, 2, 3], dtype=torch.int64), + activations=torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32), + str_tokens=["x", "y", "z"], + distance=0.0, + ) + + +def _record() -> LatentRecord: + return LatentRecord( + latent=Latent(module_name="layers.0", latent_index=0), + test=[_activating_example()], + not_active=[_non_activating_example()], + explanation="test explanation", + ) + + +@pytest.mark.asyncio +async def test_detection_scorer_async_contract_returns_scorer_result(): + scorer = DetectionScorer( + client=ConstantResponseClient("[1]"), + n_examples_shown=1, + verbose=False, + ) + + result = await scorer(_record()) + + assert isinstance(result, ScorerResult) + assert result.record.explanation == "test explanation" + assert len(result.score) > 0 + + +def test_detection_parse_casts_binary_ints_to_bool(): + scorer = DetectionScorer( + client=ConstantResponseClient("[0, 1]"), + n_examples_shown=2, + verbose=False, + ) + + predictions, probabilities = scorer._parse("[0, 1]") + + assert predictions == [False, True] + assert probabilities == [None, None] + + +def test_fuzzing_call_sync_contract_and_log_prob_flag(): + scorer = FuzzingScorer( + client=ConstantResponseClient("[1]"), + n_examples_shown=1, + verbose=False, + log_prob=True, + ) + + result = scorer.call_sync(_record()) + + assert scorer.log_prob is True + assert isinstance(result, ScorerResult) + assert len(result.score) > 0 diff --git a/tests/test_scorers/test_intruder_example_embedding.py b/tests/test_scorers/test_intruder_example_embedding.py new file mode 100644 index 00000000..9e100491 --- /dev/null +++ b/tests/test_scorers/test_intruder_example_embedding.py @@ -0,0 +1,107 @@ +import torch + +from delphi.clients.client import Client, Response +from delphi.latents import ActivatingExample, Latent, LatentRecord, NonActivatingExample +from delphi.scorers.classifier.intruder import IntruderScorer +from delphi.scorers.embedding.example_embedding import ExampleEmbeddingScorer + + +class DummyClient(Client): + def __init__(self): + super().__init__(model="dummy") + + async def generate(self, prompt, **kwargs): + return Response(text="[RESPONSE]: 0") + + +def _make_activating_example(token: str, quantile: int) -> ActivatingExample: + tokens = torch.tensor([1, 2, 3], dtype=torch.int64) + activations = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32) + return ActivatingExample( + tokens=tokens, + activations=activations, + str_tokens=[token, token, token], + quantile=quantile, + ) + + +def _make_non_activating_example(token: str, distance: float) -> NonActivatingExample: + tokens = torch.tensor([1, 2, 3], dtype=torch.int64) + activations = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) + return NonActivatingExample( + tokens=tokens, + activations=activations, + str_tokens=[token, token, token], + distance=distance, + ) + + +def test_intruder_prepare_and_batch_returns_string_examples(monkeypatch): + def fake_prepare_text(example, n_incorrect, threshold, highlighted): + return f"fmt-{example.str_tokens[0]}-{n_incorrect}", example.str_tokens + + monkeypatch.setattr( + "delphi.scorers.classifier.intruder._prepare_text", fake_prepare_text + ) + + record = LatentRecord( + latent=Latent(module_name="layers.0", latent_index=0), + test=[ + _make_activating_example("A", quantile=0), + _make_activating_example("B", quantile=1), + ], + not_active=[ + _make_non_activating_example("N0", distance=0.1), + _make_non_activating_example("N1", distance=0.2), + ], + ) + + scorer = IntruderScorer(DummyClient(), n_examples_shown=3, type="default", seed=0) + batches = scorer._prepare_and_batch(record) + + assert len(batches) == 2 + for batch in batches: + assert all(isinstance(example, str) for example in batch.examples) + assert len(batch.activations) == len(batch.examples) + assert len(batch.tokens) == len(batch.examples) + assert 0 <= batch.intruder_index < len(batch.examples) + + +def test_example_embedding_internal_batch_creation(monkeypatch): + def fake_prepare_text(example, n_incorrect, threshold, highlighted): + return f"fmt-{example.str_tokens[0]}-{n_incorrect}", example.str_tokens + + monkeypatch.setattr( + "delphi.scorers.embedding.example_embedding._prepare_text", fake_prepare_text + ) + + record = LatentRecord( + latent=Latent(module_name="layers.0", latent_index=1), + train=[ + _make_activating_example("T0", quantile=0), + _make_activating_example("T1", quantile=1), + ], + test=[ + _make_activating_example("Q0", quantile=0), + _make_activating_example("Q1", quantile=1), + _make_activating_example("Q2", quantile=2), + ], + not_active=[ + _make_non_activating_example("N0", distance=0.1), + ], + ) + + scorer = ExampleEmbeddingScorer( + model=object(), + method="internal", + number_batches=2, + seed=0, + ) + batches = scorer._create_batches(record, number_batches=2) + + assert len(batches) == 2 + for batch in batches: + assert isinstance(batch.distance_negative_query, (int, float)) + assert batch.distance_negative_query in {0, 1, 2} + assert len(batch.negative_examples) > 0 + assert isinstance(batch.positive_examples, list) diff --git a/tests/test_scorers/test_sample_highlighting.py b/tests/test_scorers/test_sample_highlighting.py new file mode 100644 index 00000000..68dd1616 --- /dev/null +++ b/tests/test_scorers/test_sample_highlighting.py @@ -0,0 +1,40 @@ +import torch + +from delphi.latents import ActivatingExample, NonActivatingExample +from delphi.scorers.classifier.sample import _prepare_text + + +def _activating_example() -> ActivatingExample: + return ActivatingExample( + tokens=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int64), + activations=torch.tensor([0.0, 0.1, 0.2, 0.0, 0.4, 0.0, 0.9, 0.0]), + str_tokens=["a", "b", "c", "d", "e", "f", "g", "h"], + quantile=1, + ) + + +def _non_activating_example() -> NonActivatingExample: + return NonActivatingExample( + tokens=torch.tensor([1, 2, 3, 4], dtype=torch.int64), + activations=torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float32), + str_tokens=["x", "y", "z", "w"], + distance=0.0, + ) + + +def test_prepare_text_highlighted_correct_example_returns_markers(): + text, str_toks = _prepare_text( + _activating_example(), n_incorrect=0, threshold=0.3, highlighted=True + ) + + assert str_toks == ["a", "b", "c", "d", "e", "f", "g", "h"] + assert "<<" in text and ">>" in text + + +def test_prepare_text_false_positive_forces_activating_token_position(): + example = _non_activating_example() + text, _ = _prepare_text(example, n_incorrect=1, threshold=0.3, highlighted=True) + + token_pos = len(example.str_tokens) - len(example.str_tokens) // 4 + expected_token = example.str_tokens[token_pos] + assert f"<<{expected_token}>>" in text