Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions triton_kernel_agent/opt_worker_component/searching/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Searching infrastructure for Optimization Kernel."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""History module for tracking optimization attempts.

Provides persistent storage for kernel optimization attempts, enabling:
- Resume: Continue runs after interruption
- History: Track what was tried and outcomes
- Learning: Use past attempts to guide exploration
"""

from .records import AttemptRecord, Outcome
from .store import AttemptStore, JsonAttemptStore

__all__ = ["AttemptRecord", "Outcome", "AttemptStore", "JsonAttemptStore"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data records for tracking optimization attempts."""

from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum


class Outcome(Enum):
"""Result of an optimization attempt."""

IMPROVED = "improved"
REGRESSED = "regressed"
FAILED = "failed"


@dataclass
class AttemptRecord:
"""A single optimization attempt."""

id: str
kernel_code: str
time_ms: float
outcome: Outcome
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
parent_id: str | None = None

def __repr__(self) -> str:
return f"AttemptRecord(id={self.id}, time_ms={self.time_ms:.4f}, outcome={self.outcome.value})"

def to_dict(self) -> dict:
"""Serialize to dictionary for JSON storage."""
return {
"id": self.id,
"kernel_code": self.kernel_code,
"time_ms": self.time_ms,
"outcome": self.outcome.value,
"created_at": self.created_at.isoformat(),
"parent_id": self.parent_id,
}

@staticmethod
def from_dict(data: dict) -> "AttemptRecord":
"""Deserialize from dictionary."""
return AttemptRecord(
id=data["id"],
kernel_code=data["kernel_code"],
time_ms=data["time_ms"],
outcome=Outcome(data["outcome"]),
created_at=datetime.fromisoformat(data["created_at"]),
parent_id=data.get("parent_id"),
)
126 changes: 126 additions & 0 deletions triton_kernel_agent/opt_worker_component/searching/history/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Storage interface and implementations for optimization attempts.

The attempt store provides persistent storage for kernel optimization
attempts discovered during the search process. This enables:

- Resume: Continue optimization runs after interruption
- History: Track what was tried and what worked/failed
- Learning: Use past attempts to guide future exploration
- Analysis: Understand optimization trajectories post-hoc

Thread/process safety:
- Only the main optimization loop should write to the store
- Workers return results via queue; manager calls add()
"""

import json
from pathlib import Path
from typing import Protocol

from .records import AttemptRecord, Outcome


class AttemptStore(Protocol):
"""Interface for storing and querying optimization attempts.

Implementations must provide:
- add(): Store a new attempt
- get_recent(): Get recent attempts for history context
- get_top_k(): Get best performers for parent selection
- get_best(): Get single best attempt
- count(): Count total attempts
"""

def add(self, attempt: AttemptRecord) -> None:
"""Store an attempt."""
...

def get_recent(self, n: int) -> list[AttemptRecord]:
"""Get the n most recent attempts (oldest first)."""
...

def get_top_k(self, k: int) -> list[AttemptRecord]:
"""Get the k best attempts by time_ms (fastest first)."""
...

def get_best(self) -> AttemptRecord | None:
"""Get the attempt with the lowest time_ms."""
...

def count(self) -> int:
"""Count total attempts in the store."""
...


class JsonAttemptStore:
"""JSON file-based implementation of AttemptStore."""

def __init__(self, path: Path | str) -> None:
self.path = Path(path)
self._attempts: list[AttemptRecord] = []
self._load()

def _load(self) -> None:
"""Load attempts from JSON file if it exists.

Falls back to empty store if the file is corrupted (e.g., partial write).
"""
if self.path.exists():
try:
with open(self.path) as f:
data = json.load(f)
self._attempts = [AttemptRecord.from_dict(d) for d in data]
except (json.JSONDecodeError, KeyError) as e:
import warnings

warnings.warn(f"Corrupted store at {self.path}, starting fresh: {e}")
self._attempts = []

def _save(self) -> None:
"""Save attempts to JSON file."""
self.path.parent.mkdir(parents=True, exist_ok=True)
with open(self.path, "w") as f:
json.dump([a.to_dict() for a in self._attempts], f, indent=2)

def add(self, attempt: AttemptRecord) -> None:
"""Store an attempt and persist to disk."""
self._attempts.append(attempt)
self._save()

def get_recent(self, n: int) -> list[AttemptRecord]:
"""Get the n most recent attempts (oldest first)."""
return self._attempts[-n:]

def get_top_k(self, k: int) -> list[AttemptRecord]:
"""Get the k best attempts by time_ms (fastest first).

Ties are broken by created_at (oldest first) for deterministic ordering.
"""
valid = [a for a in self._attempts if a.outcome != Outcome.FAILED]
sorted_by_time = sorted(valid, key=lambda a: (a.time_ms, a.created_at))
return sorted_by_time[:k]

def get_best(self) -> AttemptRecord | None:
"""Get the attempt with the lowest time_ms (excluding failed)."""
valid = [a for a in self._attempts if a.outcome != Outcome.FAILED]
if not valid:
return None
return min(valid, key=lambda a: a.time_ms)

def count(self) -> int:
"""Count total attempts in the store."""
return len(self._attempts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mutation module for building kernel optimization prompts."""

from .mutator import Mutator, SimpleMutator

__all__ = ["Mutator", "SimpleMutator"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mutation strategies for generating kernel optimization prompts.

Mutators build prompts for the LLM to optimize kernels, including:
- The parent kernel to improve
- History of what was tried before
- Any additional context (bottleneck analysis, inspirations, etc.)
"""

from typing import Protocol

from ..history import AttemptRecord, AttemptStore


class Mutator(Protocol):
"""Interface for building optimization prompts."""

def build_prompt(self, parent: AttemptRecord) -> str:
"""Build a prompt for the LLM to optimize the kernel.

Args:
parent: The kernel to optimize.
"""
...


class SimpleMutator:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JsonAttemptStore.get_recent() returns newest first (reversed(self._attempts[-n:])).

SimpleMutator.build_prompt() prints history using history[-3:].

If callers pass store.get_recent(...) directly into the mutator, then history[-3:] will actually take the oldest of the “recent” slice (because it’s already newest-first). That’s subtle and will confuse prompt context.

Suggestion: either define a clear convention (history always oldest→newest), or have SimpleMutator treat history as newest-first and use history[:3], or sort by created_at inside the mutator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Update it to ordered as oldest-first

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to just have the mutator take in the store?

We can be opinionated on how the mutator chooses history to start. Customization can be added later (e.g. n-history)

"""Minimal mutator: basic prompt with kernel and history."""

def __init__(self, store: AttemptStore) -> None:
self.store = store

def build_prompt(self, parent: AttemptRecord) -> str:
lines = [
"# Optimize this Triton kernel\n",
f"Current performance: {parent.time_ms:.4f}ms\n",
]

history = self.store.get_recent(3)
if history:
lines.append("\n## Recent attempts:\n")
for a in history:
lines.append(f"- [{a.outcome.value}] {a.time_ms:.4f}ms\n")

lines.append("\n## Kernel:\n```python\n")
lines.append(parent.kernel_code)
lines.append("\n```\n")

return "".join(lines)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sampling module for selecting parents from optimization history."""

from .sampler import BestSampler, Sampler

__all__ = ["Sampler", "BestSampler"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sampling strategies for selecting parents and inspirations from history.

Samplers control how we select:
- Parents: Which kernel to optimize next
- Inspirations: Which kernels to show as few-shot examples to the LLM
"""

from typing import Protocol

from ..history import AttemptRecord, AttemptStore


class Sampler(Protocol):
"""Interface for sampling from optimization history."""

def sample_parent(self) -> AttemptRecord | None:
"""Select a parent for the next optimization attempt."""
...

def get_top_inspirations(
self,
n: int,
) -> list[AttemptRecord]:
"""Get top-performing attempts for few-shot prompting."""
...


class BestSampler:
"""Sampler that always returns the best parent and top-k inspirations."""

def __init__(self, store: AttemptStore) -> None:
self.store = store

def sample_parent(self) -> AttemptRecord | None:
return self.store.get_best()

def get_top_inspirations(
self,
n: int,
) -> list[AttemptRecord]:
return self.store.get_top_k(n)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ensure AttemptStore.get_top_k is deterministic on ties (e.g., break ties by created_at or id). If that’s already true, document it in the store.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update to use (time_ms, created_at) for deterministic tie-breaking

Loading