Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,62 @@
__pycache__
data
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Unit test / coverage reports
.pytest_cache/
.coverage
htmlcov/
coverage.xml
.tox/
.nox/
.coverage.*
.cache
nosetests.xml

# Virtual environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Claude
.claude/*

# Project specific
data/
batch_file.sh
salloc_*
18 changes: 9 additions & 9 deletions fsdp_optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import kron
import kron_mars
import muon
import soap
import utils
from . import kron
from . import kron_mars
from . import muon
from . import soap
from . import utils

from kron import Kron
from kron_mars import KronMars
from muon import Muon
from soap import SOAP
from .kron import Kron
from .kron_mars import KronMars
from .muon import Muon
from .soap import SOAP
2 changes: 1 addition & 1 deletion fsdp_optimizers/kron.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import string
import torch
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from utils import to_local, to_dist
from .utils import to_local, to_dist

# adapted from https://github.com/ClashLuke/kron_torch/tree/main

Expand Down
2 changes: 1 addition & 1 deletion fsdp_optimizers/kron_mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
from torch.distributed.tensor import DTensor, distribute_tensor

from utils import to_dist, to_local
from .utils import to_dist, to_local

# torch._dynamo.config.cache_size_limit = 1_000_000

Expand Down
2 changes: 1 addition & 1 deletion fsdp_optimizers/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.distributed.tensor import DTensor
from torch.utils._pytree import tree_map, tree_flatten
from typing import Generator
from utils import to_local, to_dist
from .utils import to_local, to_dist

# @torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
Expand Down
2 changes: 1 addition & 1 deletion fsdp_optimizers/soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from itertools import chain
from torch.distributed.tensor import distribute_tensor, DTensor, Replicate, Shard
from utils import to_local, to_dist
from .utils import to_local, to_dist

# adapted from https://github.com/ClashLuke/SOAP

Expand Down
880 changes: 880 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
[tool.poetry]
name = "fsdp_optimizers"
version = "0.0.1"
description = "supporting pytorch FSDP for optimizers"
authors = ["Ethan Smith <98723285+ethansmith2000@users.noreply.github.com>"]
license = "Apache"
readme = "README.md"
homepage = "https://github.com/ethansmith2000/fsdp_optimizers"
repository = "https://github.com/ethansmith2000/fsdp_optimizers"
classifiers = [
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Intended Audience :: Developers",
]

[tool.poetry.dependencies]
python = "^3.9"
torch = "*"
numpy = "*"

[tool.poetry.group.test.dependencies]
pytest = "^7.0.0"
pytest-cov = "^4.0.0"
pytest-mock = "^3.10.0"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
"--strict-markers",
"--strict-config",
"--cov=fsdp_optimizers",
"--cov-report=term-missing",
"--cov-report=html",
"--cov-report=xml",
"--cov-fail-under=80",
"-v"
]
markers = [
"unit: marks tests as unit tests (deselect with '-m \"not unit\"')",
"integration: marks tests as integration tests (deselect with '-m \"not integration\"')",
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
]

[tool.coverage.run]
source = ["fsdp_optimizers"]
omit = [
"*/tests/*",
"*/test_*",
"*/__pycache__/*",
"*/site-packages/*",
"*/build/*",
"*/dist/*"
]

[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if self.debug:",
"if settings.DEBUG",
"raise AssertionError",
"raise NotImplementedError",
"if 0:",
"if __name__ == .__main__.:"
]
show_missing = true
skip_covered = false

[tool.coverage.html]
directory = "htmlcov"

[tool.coverage.xml]
output = "coverage.xml"
Empty file added tests/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import MagicMock
import torch
import numpy as np


@pytest.fixture
def temp_dir():
"""Create a temporary directory that gets cleaned up after the test."""
temp_path = tempfile.mkdtemp()
yield Path(temp_path)
shutil.rmtree(temp_path)


@pytest.fixture
def mock_config():
"""Mock configuration object for testing."""
config = MagicMock()
config.lr = 0.01
config.weight_decay = 0.0
config.eps = 1e-8
config.momentum = 0.9
return config


@pytest.fixture
def sample_tensor():
"""Create a sample tensor for testing."""
return torch.randn(10, 10, requires_grad=True)


@pytest.fixture
def sample_model():
"""Create a simple model for testing."""
model = torch.nn.Sequential(
torch.nn.Linear(10, 20),
torch.nn.ReLU(),
torch.nn.Linear(20, 1)
)
return model


@pytest.fixture
def optimizer_params():
"""Standard optimizer parameters for testing."""
return {
'lr': 0.01,
'weight_decay': 0.0,
'eps': 1e-8,
'momentum': 0.9
}


@pytest.fixture
def random_seed():
"""Set random seed for reproducible tests."""
torch.manual_seed(42)
np.random.seed(42)
yield
# Cleanup is automatic


@pytest.fixture
def mock_fsdp_model():
"""Mock FSDP model for testing."""
mock_model = MagicMock()
mock_model.parameters.return_value = [torch.randn(10, 10, requires_grad=True)]
return mock_model
Empty file added tests/integration/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions tests/test_setup_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import pytest
import torch
import numpy as np
from fsdp_optimizers import Kron, KronMars, Muon, SOAP


class TestSetupValidation:
"""Test suite to validate the testing infrastructure setup."""

@pytest.mark.unit
def test_imports(self):
"""Test that all main modules can be imported."""
assert Kron is not None
assert KronMars is not None
assert Muon is not None
assert SOAP is not None

@pytest.mark.unit
def test_torch_available(self):
"""Test that PyTorch is available and working."""
tensor = torch.tensor([1.0, 2.0, 3.0])
assert tensor.sum().item() == 6.0

@pytest.mark.unit
def test_numpy_available(self):
"""Test that NumPy is available and working."""
arr = np.array([1, 2, 3])
assert arr.sum() == 6

@pytest.mark.unit
def test_fixtures_available(self, temp_dir, mock_config, sample_tensor):
"""Test that pytest fixtures are working correctly."""
assert temp_dir.exists()
assert mock_config is not None
assert sample_tensor.shape == (10, 10)
assert sample_tensor.requires_grad is True

@pytest.mark.integration
def test_optimizer_creation(self, sample_model):
"""Test that optimizers can be created with model parameters."""
params = list(sample_model.parameters())

# Test creating optimizers (this is just a basic instantiation test)
# Use basic lr parameter which should be common
basic_params = {'lr': 0.01}

try:
kron = Kron(params, **basic_params)
assert kron is not None
except Exception as e:
# If the optimizer requires specific parameters or fails, that's fine for validation
# Just ensure we can import and attempt to instantiate
assert isinstance(e, (TypeError, ValueError, RuntimeError))

@pytest.mark.slow
def test_performance_baseline(self):
"""A slow test to validate the slow marker works."""
import time
start = time.time()
# Simulate some work
tensor = torch.randn(1000, 1000)
result = torch.matmul(tensor, tensor)
duration = time.time() - start
assert duration >= 0 # Basic assertion
assert result.shape == (1000, 1000)
Empty file added tests/unit/__init__.py
Empty file.