Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions delphi/scorers/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def __init__(
client: Client,
verbose: bool,
n_examples_shown: int,
log_prob: bool,
seed: int = 42,
log_prob: bool = False,
**generation_kwargs,
):
"""
Expand Down Expand Up @@ -143,7 +143,8 @@ def _parse(
match = re.search(pattern, string)
if match is None:
raise ValueError("No match found in string")
predictions: list[bool | Literal[0, 1]] = json.loads(match.group(0))
raw_predictions: list[bool | Literal[0, 1]] = json.loads(match.group(0))
predictions = [bool(prediction) for prediction in raw_predictions]
assert len(predictions) == self.n_examples_shown
probabilities = (
self._parse_logprobs(logprobs)
Expand Down
1 change: 1 addition & 0 deletions delphi/scorers/classifier/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
temperature=temperature,
**generation_kwargs,
)
self.log_prob = log_prob

def prompt(self, examples: str, explanation: str) -> list[dict]:
return detection_prompt(examples, explanation)
Expand Down
2 changes: 1 addition & 1 deletion delphi/scorers/classifier/fuzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
temperature=temperature,
**generation_kwargs,
)

self.log_prob = log_prob
self.threshold = threshold
self.fuzz_type = fuzz_type

Expand Down
20 changes: 14 additions & 6 deletions delphi/scorers/classifier/intruder.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
active_examples = self.rng.sample(all_active_examples, num_active_examples)

# highlights the active tokens with <<>> markers
majority_examples = []
formatted_examples = []
chosen_examples = []
num_active_tokens = 0
for example in active_examples:
text, _str_tokens = _prepare_text(
example, n_incorrect=0, threshold=0.3, highlighted=True
)
majority_examples.append(text)
formatted_examples.append(text)
chosen_examples.append(example)
num_active_tokens += (example.activations > 0).sum().item()

avg_active_tokens_per_example = num_active_tokens // len(active_examples)
Expand All @@ -193,6 +195,7 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
threshold=0.3,
highlighted=True,
)

elif self.type == "internal":
# randomly select a quantile to be the intruder, make sure it's not
# the same as the source quantile
Expand Down Expand Up @@ -224,18 +227,23 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:

# select a random index to insert the intruder sentence
intruder_index = self.rng.randint(0, num_active_examples)
examples = (
majority_examples[:intruder_index]
formatted_examples = (
formatted_examples[:intruder_index]
+ [intruder_sentence]
+ majority_examples[intruder_index:]
+ formatted_examples[intruder_index:]
)
examples = (
chosen_examples[:intruder_index]
+ [intruder]
+ chosen_examples[intruder_index:]
)

example_activations = [example.activations.tolist() for example in examples]
example_tokens = [example.str_tokens for example in examples]

batches.append(
IntruderSentence(
examples=examples,
examples=formatted_examples,
intruder_index=intruder_index,
chosen_quantile=active_quantile,
activations=example_activations,
Expand Down
3 changes: 2 additions & 1 deletion delphi/scorers/classifier/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _prepare_text(
if n_incorrect == 0:

def is_above_activation_threshold(i: int) -> bool:
return example.activations[i] >= abs_threshold
return bool((example.activations[i] >= abs_threshold).item())

return _highlight(str_toks, is_above_activation_threshold), str_toks

Expand All @@ -137,6 +137,7 @@ def is_above_activation_threshold(i: int) -> bool:
# The activating token is always ctx_len - ctx_len//4
# so we always highlight this one, and if num_tokens_to_highlight > 1
# we highlight num_tokens_to_highlight - 1 random ones
# TODO: This is wrong
token_pos = len(str_toks) - len(str_toks) // 4
if token_pos in tokens_below_threshold:
random_indices = [token_pos]
Expand Down
31 changes: 20 additions & 11 deletions delphi/scorers/embedding/example_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@ def compute_batch_deltas(self, batch: Batch) -> tuple[float, float]:
# Split the embeddings back into their components
n_neg = len(batch.negative_examples)
n_pos = len(batch.positive_examples)
negative_examples_embeddings = all_embeddings[:n_neg]
positive_examples_embeddings = all_embeddings[n_neg : n_neg + n_pos]
positive_query_embedding = all_embeddings[-2].unsqueeze(0)
negative_query_embedding = all_embeddings[-1].unsqueeze(0)
negative_examples_embeddings = torch.tensor(all_embeddings[:n_neg])
positive_examples_embeddings = torch.tensor(
all_embeddings[n_neg : n_neg + n_pos]
)
positive_query_embedding = torch.tensor(all_embeddings[-2]).unsqueeze(0)
negative_query_embedding = torch.tensor(all_embeddings[-1]).unsqueeze(0)

# Compute the similarity between the query and the examples
negative_similarities = self.model.similarity(
Expand Down Expand Up @@ -165,9 +167,11 @@ def _create_batches(
# which are going to be used as "explanations"
positive_train_examples = record.train

number_samples = min(len(positive_train_examples), len(record.not_active))

# Sample from the not_active examples
not_active_index = self.random.sample(
range(len(record.not_active)), len(positive_train_examples)
range(len(record.not_active)), number_samples
)
negative_train_examples = [record.not_active[i] for i in not_active_index]

Expand All @@ -192,6 +196,7 @@ def _create_batches(
positive_query_str, _ = _prepare_text(
positive_query, n_incorrect=0, threshold=0.3, highlighted=True
)

# Prepare the negative query
if self.method == "default":
# In the default method, we just sample a random negative example
Expand All @@ -205,6 +210,7 @@ def _create_batches(
threshold=0.3,
highlighted=True,
)

elif self.method == "internal":
# In the internal method, we sample a negative example
# that has a different quantile as the positive query
Expand All @@ -216,13 +222,13 @@ def _create_batches(
range(len(positive_test_examples)), 1
)[0]
negative_query_temp = positive_test_examples[negative_query_idx]
negative_query_quantile = negative_query.distance
negative_query_quantile = negative_query_temp.quantile

negative_query = NonActivatingExample(
str_tokens=negative_query_temp.str_tokens,
tokens=negative_query_temp.tokens,
activations=negative_query_temp.activations,
distance=negative_query_temp.quantile,
distance=float(negative_query_temp.quantile),
)
# Because it is a converted activating example, it will highlight
# the activating tokens
Expand All @@ -234,15 +240,18 @@ def _create_batches(
# that have the same quantile as the positive_query
positive_examples = [
e
for e in positive_train_examples
for e in positive_test_examples
if e.quantile == positive_query.quantile
]
if len(positive_examples) > 10:
positive_examples = self.random.sample(positive_examples, 10)
positive_examples = self.random.sample(positive_examples, 11)
positive_examples_str = [
_prepare_text(e, n_incorrect=0, threshold=0.3, highlighted=True)[0]
for e in positive_examples
]
# if one example is the same as the positive query, remove it
if positive_query_str in positive_examples_str:
positive_examples_str.remove(positive_query_str)

# negative examples
if self.method == "default":
Expand All @@ -259,7 +268,7 @@ def _create_batches(
# that has the same quantile as the negative_query
negative_examples = [
e
for e in positive_train_examples
for e in positive_test_examples
if e.quantile == negative_query.distance
]
if len(negative_examples) > 10:
Expand All @@ -275,7 +284,7 @@ def _create_batches(
positive_query=positive_query_str,
negative_query=negative_query_str,
quantile_positive_query=positive_query.quantile,
distance_negative_query=negative_query.distance,
distance_negative_query=float(negative_query.distance),
)
batches.append(batch)
return batches
2 changes: 1 addition & 1 deletion delphi/scorers/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ class ScorerResult(NamedTuple):

class Scorer(ABC):
@abstractmethod
def __call__(self, record: LatentRecord) -> ScorerResult:
async def __call__(self, record: LatentRecord) -> ScorerResult:
pass
86 changes: 86 additions & 0 deletions tests/test_scorers/test_classifier_contracts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
import torch

from delphi.clients.client import Client, Response
from delphi.latents import ActivatingExample, Latent, LatentRecord, NonActivatingExample
from delphi.scorers import DetectionScorer, FuzzingScorer
from delphi.scorers.scorer import ScorerResult


class ConstantResponseClient(Client):
    """Test double for ``Client`` that always yields the same completion text."""

    def __init__(self, text: str):
        super().__init__(model="dummy")
        # The canned completion returned by every generate() call.
        self.text = text

    async def generate(self, prompt, **kwargs):
        """Ignore *prompt* and *kwargs*; return the canned response."""
        canned = self.text
        return Response(text=canned)


def _activating_example() -> ActivatingExample:
    """Build a minimal three-token activating example (only the middle token fires)."""
    token_ids = torch.arange(1, 4, dtype=torch.int64)  # tensor([1, 2, 3])
    acts = torch.zeros(3, dtype=torch.float32)
    acts[1] = 1.0
    return ActivatingExample(
        tokens=token_ids,
        activations=acts,
        str_tokens=["a", "b", "c"],
        quantile=1,
    )


def _non_activating_example() -> NonActivatingExample:
    """Build a minimal three-token example with zero activation everywhere."""
    return NonActivatingExample(
        tokens=torch.arange(1, 4, dtype=torch.int64),  # tensor([1, 2, 3])
        activations=torch.zeros(3, dtype=torch.float32),
        str_tokens=["x", "y", "z"],
        distance=0.0,
    )


def _record() -> LatentRecord:
    """Assemble a one-activating / one-non-activating LatentRecord fixture."""
    latent = Latent(module_name="layers.0", latent_index=0)
    return LatentRecord(
        latent=latent,
        test=[_activating_example()],
        not_active=[_non_activating_example()],
        explanation="test explanation",
    )


@pytest.mark.asyncio
async def test_detection_scorer_async_contract_returns_scorer_result():
    """Awaiting the scorer yields a ScorerResult that carries the record's explanation."""
    client = ConstantResponseClient("[1]")
    scorer = DetectionScorer(client=client, n_examples_shown=1, verbose=False)

    scored = await scorer(_record())

    # Contract: async __call__ resolves to a ScorerResult with a non-empty score.
    assert isinstance(scored, ScorerResult)
    assert scored.record.explanation == "test explanation"
    assert len(scored.score) > 0


def test_detection_parse_casts_binary_ints_to_bool():
    """``_parse`` turns JSON 0/1 integers into real booleans, with null probabilities."""
    scorer = DetectionScorer(
        client=ConstantResponseClient("[0, 1]"),
        n_examples_shown=2,
        verbose=False,
    )

    parsed, probs = scorer._parse("[0, 1]")

    # Integers must be coerced to bool; no logprobs requested -> None per example.
    assert parsed == [False, True]
    assert probs == [None, None]


def test_fuzzing_call_sync_contract_and_log_prob_flag():
    """``call_sync`` returns a ScorerResult and the ``log_prob`` kwarg sticks on the scorer."""
    fuzz_scorer = FuzzingScorer(
        client=ConstantResponseClient("[1]"),
        n_examples_shown=1,
        verbose=False,
        log_prob=True,
    )

    outcome = fuzz_scorer.call_sync(_record())

    # The constructor must persist log_prob, and the sync wrapper must
    # still produce a populated ScorerResult.
    assert fuzz_scorer.log_prob is True
    assert isinstance(outcome, ScorerResult)
    assert len(outcome.score) > 0
Loading
Loading