Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "on-chain/lib/forge-std"]
path = on-chain/lib/forge-std
url = https://github.com/foundry-rs/forge-std
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ uv venv model_eval --python 3.12

4. Install dependencies
```shell
uv pip install requests transformers nltk numpy pandas scikit-learn textstat datasets evaluate torch torchvision torchaudio sentencepiece confluent-kafka fastapi uvicorn
uv pip install requests transformers nltk numpy pandas scikit-learn textstat datasets evaluate torch torchvision torchaudio sentencepiece confluent-kafka fastapi uvicorn web3 python-dotenv
```

### Running the app
Expand Down
165 changes: 165 additions & 0 deletions eval-node/consumer_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import json
import time
import asyncio
import os
from typing import Dict
from threading import Thread, Event
from logger_config import setup_logging
from dataset_manager import DatasetManager
from prediction_generator import PredictionGenerator
from model_evaluator import ModelEvaluator
from status_manager import TaskStatus
from kafka_connector import KafkaConnector
from dotenv import load_dotenv
from web3_client import Web3Client
import load_contract_abi

logger = setup_logging("consumer_results")
load_dotenv()

class ConsumerResults:
def __init__(
self,
kafka_connector: KafkaConnector,
result_topic: str,
):
node_url = os.getenv('ETHEREUM_NODE_URL', 'http://127.0.0.1:8545'),
contract_address = os.getenv('CONTRACT_ADDRESS'),
private_key = os.getenv('PRIVATE_KEY'),
sender_address = os.getenv('SENDER_ADDRESS')

if not all([node_url, contract_address, private_key, sender_address]):
raise ValueError("Missing required environment variables")

self.kafka_connector = kafka_connector
self.result_topic = result_topic
self.consumer = self.kafka_connector.create_consumer(group_id="worker_group", topic=self.result_topic)
self.producer = self.kafka_connector.create_producer()
self._stop_flag = Event()
# Hold reference to the background thread
self._thread = None

self.web3_client = Web3Client(
node_url=node_url,
contract_address=contract_address,
contract_abi=load_contract_abi(),
private_key=private_key,
sender_address=sender_address
)

logger.info("ConsumerResults initialized")

async def stop(self):
"""
Signal the worker to stop and close the consumer.
Called from FastAPI lifespan shutdown event.
"""
logger.info("Stopping consumer_results worker...")

# Signal the thread loop to exit
self._stop_flag.set()
if self._thread and self._thread.is_alive():
self._thread.join()
self.consumer.close()

def _poll_loop(self):
"""
Blocking loop that polls Kafka messages until _stop_flag is set.
"""
logger.info("Worker poll loop started (thread).")

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

while not self._stop_flag.is_set():
msg = self.consumer.poll(1.0)
if msg is None:
continue
if msg.error():
logger.error(f"Consumer error: {msg.error()}")
continue

try:
task = json.loads(msg.value().decode("utf-8"))
logger.info(f"Task received in poll loop: {task}")

# Run the coroutine til completion in this thread's event loop
loop.run_until_complete(self.process_task(task))

except Exception as e:
logger.error(f"Error processing task: {e}", exc_info=True)

logger.info("Worker poll loop exited.")
loop.close()

async def evaluate_task(self, task_id: str, model_name: str, dataset: str) -> Dict:
try:
logger.info(f"Starting evaluation for task {task_id}")

evaluator = ModelEvaluator()
logger.info("Initializing evaluator...")
await evaluator.initialize()
logger.info("Evaluator initialized")

# Load dataset
logger.info(f"Loading dataset {dataset}")
dataset_mgr = DatasetManager(dataset, split=f"validation[:{QUANTITY_EXAMPLES}]")
await dataset_mgr.load_dataset()
examples = dataset_mgr.get_examples()
references = dataset_mgr.get_references()
logger.info(f"Dataset loaded, got {len(examples)} examples")

# Generate predictions
logger.info("Generating predictions...")
predictor = PredictionGenerator(model_name=model_name, api_url="http://localhost:11434/api/generate")
predictions = await predictor.generate_predictions(examples)
logger.info(f"Generated {len(predictions)} predictions")

# Evaluate
logger.info("Evaluating predictions...")
eval_results = await evaluator.evaluate_predictions(predictions, references)
logger.info("Evaluation complete")

return {
"task_id": task_id,
"status": "completed",
"model_name": model_name,
"dataset": dataset,
"metrics": eval_results,
"timestamp": time.time()
}
except Exception as e:
logger.error(f"Evaluation failed: {e}", exc_info=True)
raise

async def start(self):
"""
Creates a background thread for polling.
"""
logger.info("Worker is starting (thread).")
self._stop_flag.clear()
self._thread = Thread(target=self._poll_loop, daemon=True)
self._thread.start()

async def process_task(self, task: Dict) -> Dict:
"""
Update status, run evaluation and produce to Kafka.
"""
try:
logger.info(f"Processing task: {task}")
task_id = task["task_id"]
# model_name = task["model_name"]
# dataset = task["dataset"]
# metrics = task["metrics"]
# timestamp = task["timestamp"]

logger.info(f"Processing task: {task_id}")

# Evaluate the task
result = await self.evaluate_task(task)

logger.info(f"Task {task_id} completed with status: {result['status']}")
return result
except Exception as e:
logger.error(f"Failed to process message: {e}", exc_info=True)
return {}
Loading