Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion contest/extract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from tqdm import tqdm
from transformers import AutoTokenizer
from dataclasses import asdict

Expand All @@ -8,7 +9,7 @@

with open("Meta-Llama-3.1-70B-Instruct-Turbo.json", "r") as f:
data = json.load(f)
for _id, row in data.items():
for _id, row in tqdm(data.items()):
claims = extract_and_align_claims(
text=row["output"],
tokens=row["greedy_tokens"],
Expand Down
73 changes: 67 additions & 6 deletions src/reclaim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""ReClaim core package."""

from importlib import metadata
from typing import List
from typing import List, Optional

from .decompose import doc2sentences
from .extract_claims import Claim, ClaimModel, ClaimsExtractor
from .extract_claims import (
Claim,
ClaimModel,
ClaimPostprocessingConfig,
ClaimsExtractor,
)
from .annotate_claims import ClaimsAnnotator
from .openai_client import OpenAIChat

Expand All @@ -20,15 +25,55 @@
"batch_extract_and_align_claims",
"doc2sentences",
"Claim",
"ClaimPostprocessingConfig",
]


def extract_claims(text: str, model: str = "gpt-4o") -> List[Claim]:
def _default_postprocess_config(
override: Optional[ClaimPostprocessingConfig],
enable_defaults: bool,
) -> Optional[ClaimPostprocessingConfig]:
"""
Resolve a postprocessing config:
- honor explicit override;
- if allowed, provide a sensible default bundle;
- otherwise, return None to disable all extras.
"""
if override is not None:
return override
if not enable_defaults:
return None
return ClaimPostprocessingConfig(
rewrite_pronouns=True,
sanitize_with_llm=True,
split_non_atomic=True,
dedupe_with_encoder=True,
dedupe_with_cosine=True,
)


def extract_claims(
text: str,
model: str = "gpt-4.1",
postprocess_config: Optional[ClaimPostprocessingConfig] = None,
enable_default_postprocessing: bool = True,
) -> List[Claim]:
"""
Extract atomic claims from plain text.

By default, enables post-processing (pronoun rewrite, LLM sanitization,
encoder + BoW dedupe). Pass a custom ClaimPostprocessingConfig or set
enable_default_postprocessing=False to skip these steps.
"""
result = doc2sentences(doc=text, mode="atomic_claims", model=model, schema=ClaimModel)
claim_texts = result.claims if isinstance(result, ClaimModel) else result
chat = OpenAIChat(openai_model=model)
config = _default_postprocess_config(postprocess_config, enable_default_postprocessing)
extractor = ClaimsExtractor(
openai_chat=chat,
postprocess_config=config,
)
claim_texts = extractor.postprocess_claims(claim_texts, text)
return [
Claim(
claim_text=claim_text,
Expand All @@ -44,17 +89,25 @@ def extract_and_align_claims(
text,
tokens,
tokenizer,
openai_model: str = "gpt-4o",
openai_model: str = "gpt-4.1",
progress_bar: bool = True,
n_threads: int = 1,
postprocess_config: Optional[ClaimPostprocessingConfig] = None,
enable_default_postprocessing: bool = True,
):
"""
Extract and align claims with token-level provenance from model output tokens.

By default, enables post-processing (pronoun rewrite, LLM sanitization,
encoder + BoW dedupe). Pass a custom ClaimPostprocessingConfig or set
enable_default_postprocessing=False to skip these steps.
"""
config = _default_postprocess_config(postprocess_config, enable_default_postprocessing)
extractor = ClaimsExtractor(
openai_chat=OpenAIChat(openai_model=openai_model),
progress_bar=progress_bar,
n_threads=n_threads,
postprocess_config=config,
)
return extractor.claims_from_text(text, tokens, tokenizer)

Expand All @@ -63,17 +116,25 @@ def batch_extract_and_align_claims(
texts: List[str],
tokens: List[List[int]],
tokenizer,
openai_model: str = "gpt-4o",
openai_model: str = "gpt-4.1",
progress_bar: bool = True,
n_threads: int = 1,
postprocess_config: Optional[ClaimPostprocessingConfig] = None,
enable_default_postprocessing: bool = True,
) -> List[List[Claim]]:
"""
Batch extract and align claims with token-level provenance from model output tokens.

By default, enables post-processing (pronoun rewrite, LLM sanitization,
encoder + BoW dedupe). Pass a custom ClaimPostprocessingConfig or set
enable_default_postprocessing=False to skip these steps.
"""
config = _default_postprocess_config(postprocess_config, enable_default_postprocessing)
extractor = ClaimsExtractor(
openai_chat=OpenAIChat(openai_model=openai_model),
progress_bar=progress_bar,
n_threads=n_threads,
postprocess_config=config,
)

return extractor.batch_claims_from_texts(texts, tokens, tokenizer)
Expand All @@ -82,7 +143,7 @@ def batch_extract_and_align_claims(
def annotate_claims(
claims: List[str],
contexts: List[str],
openai_model: str = "gpt-4o",
openai_model: str = "gpt-4.1",
progress_bar: bool = True,
n_threads: int = 1,
):
Expand Down
8 changes: 7 additions & 1 deletion src/reclaim/claim_level_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

# Lightweight prompt dictionaries keyed by language to mirror original structure.
CLAIM_EXTRACTION_PROMPTS = {
    # One claim-extraction prompt template per language code. The "{sent}"
    # placeholder is presumably replaced with the input sentence by the
    # caller (substitution code is not visible here).
    #
    # NOTE(review): the template also contains literal JSON braces
    # ({"claims": [...]}). If the caller fills "{sent}" with str.format(),
    # those braces are parsed as a replacement field and formatting raises.
    # Confirm the substitution mechanism before relying on this prompt; the
    # string is intentionally left unchanged here.
    "en": (
        "List all atomic, decontextualized claims from the following sentence.\n"
        "- One fact per claim (no conjunctions/enumerations).\n"
        "- Replace pronouns or vague references with the specific entity so the claim stands alone.\n"
        "Return JSON exactly as {\"claims\": [\"...\"]} and nothing else.\n"
        "Sentence: {sent}"
    ),
}

MATCHING_PROMPTS = {
Expand Down
2 changes: 1 addition & 1 deletion src/reclaim/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def doc2sentences(
doc: str,
mode: str = "independent_sentences",
model: str = "gpt-4o",
model: str = "gpt-4.1",
system_role: str = "You are good at decomposing and decontextualizing text.",
num_retries: int = 5,
schema: Optional[BaseModel] = None,
Expand Down
Loading