@codeflash-ai codeflash-ai bot commented Nov 13, 2025

📄 10% (0.10x) speedup for `HFEncoder.forward` in `invokeai/backend/flux/modules/conditioner.py`

⏱️ Runtime: 2.69 milliseconds → 2.46 milliseconds (best of 11 runs)

📝 Explanation and details

The optimization caches the result of `TorchDevice.choose_torch_device()` in the `HFEncoder` constructor, eliminating redundant device selection calls during inference.

**Key Changes:**

- Added `self.device = TorchDevice.choose_torch_device()` in `__init__`
- Changed `TorchDevice.choose_torch_device()` to `self.device` in the `forward()` method (a minimal sketch of the change follows below)

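A minimal, self-contained sketch of the pattern. The class and helper names here are invented for illustration and simplified stand-ins for the actual InvokeAI implementation, which calls `TorchDevice.choose_torch_device()` and selects between `last_hidden_state` and `pooler_output` depending on the encoder type:

```python
import torch
from torch import nn


def choose_torch_device() -> torch.device:
    # Simplified stand-in for TorchDevice.choose_torch_device().
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CachedDeviceEncoder(nn.Module):
    def __init__(self, encoder: nn.Module, tokenizer, max_length: int):
        super().__init__()
        self.encoder = encoder.eval().requires_grad_(False)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # The optimization: resolve the device once, at construction time...
        self.device = choose_torch_device()

    def forward(self, text: list[str]) -> torch.Tensor:
        batch = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # ...and reuse it on every call. Before the change, the device-selection
        # helper was invoked here on each forward pass.
        outputs = self.encoder(input_ids=batch["input_ids"].to(self.device))
        return outputs["last_hidden_state"]
```
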
**Why This Improves Performance:**
The line profiler shows that `TorchDevice.choose_torch_device()` takes significant time (457μs per call in the original vs 73μs for the cached device access). This method involves expensive operations like `torch.cuda.is_available()` and device normalization that don't change during the model's lifetime.

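The effect can be approximated with a quick micro-benchmark (illustrative only; the real helper also performs device normalization, which this stand-in skips):

```python
import timeit

import torch


def choose_device() -> torch.device:
    # Re-detects hardware on every call, as the pre-optimization forward() did.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


cached = choose_device()  # resolved once, as the optimized __init__ now does

n = 10_000
redetect_us = timeit.timeit(choose_device, number=n) / n * 1e6
cached_us = timeit.timeit(lambda: cached, number=n) / n * 1e6
print(f"re-detect: {redetect_us:.2f} µs/call, cached lookup: {cached_us:.2f} µs/call")
```
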
**Impact Analysis:**

- **9% overall speedup** with particularly strong gains on smaller batches (29-55% faster on basic test cases)
- The optimization is most effective for workloads with repeated `forward()` calls on the same model instance
- Larger batches see smaller relative improvements (1-3%) since tokenization dominates runtime, but still benefit from reduced device selection overhead
- The cached device remains valid for the model's lifetime since device configuration is typically static during inference

This optimization is especially valuable for text encoding pipelines where the same `HFEncoder` instance processes multiple text batches, as each forward pass previously triggered unnecessary device detection logic.

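As an illustrative usage pattern (reusing the hypothetical `CachedDeviceEncoder` stand-in from the sketch above together with the `DummyEncoder`/`DummyTokenizer` mocks defined in the generated tests below), the device lookup now happens once at construction rather than once per batch:

```python
# Construct once: device detection runs a single time, in __init__.
text_encoder = CachedDeviceEncoder(DummyEncoder(), DummyTokenizer(), max_length=8)

# Every subsequent batch skips the device-selection call entirely.
for prompts in (["hello world"], ["foo bar"], ["hello world", "foo bar"]):
    embeddings = text_encoder(prompts)
```
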
Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 30 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import pytest  # used for our unit tests
import torch
from invokeai.backend.flux.modules.conditioner import HFEncoder
from torch import nn


# Minimal mock classes to simulate HFEncoder dependencies
class DummyEncoder(nn.Module):
    def __init__(self, output_key="last_hidden_state", hidden_size=4, pooler_size=2, max_length=8):
        super().__init__()
        self.output_key = output_key
        self.hidden_size = hidden_size
        self.pooler_size = pooler_size
        self.max_length = max_length
        self.eval_called = False
        self.requires_grad_set = False

    def eval(self):
        self.eval_called = True
        return self

    def requires_grad_(self, val):
        self.requires_grad_set = val
        return self

    def forward(self, input_ids, attention_mask=None, output_hidden_states=False):
        batch_size = input_ids.shape[0]
        # Simulate output dict with required keys
        return {
            "last_hidden_state": torch.ones((batch_size, self.max_length, self.hidden_size)),
            "pooler_output": torch.ones((batch_size, self.pooler_size))
        }

class DummyTokenizer:
    def __init__(self, vocab=None, max_length=8):
        self.vocab = vocab or {"hello":1, "world":2, "foo":3, "bar":4}
        self.max_length = max_length

    def __call__(self, text, truncation, max_length, return_length, return_overflowing_tokens, padding, return_tensors):
        # Simulate tokenizer output
        batch_size = len(text)
        input_ids = []
        for t in text:
            # Tokenize by splitting and mapping to vocab, pad/truncate as needed
            tokens = [self.vocab.get(w, 0) for w in t.split()]
            if truncation:
                tokens = tokens[:max_length]
            # Pad to max_length
            tokens += [0] * (max_length - len(tokens))
            input_ids.append(tokens)
        input_ids_tensor = torch.tensor(input_ids)
        return {"input_ids": input_ids_tensor}

# ------------------------- UNIT TESTS -------------------------

# BASIC TEST CASES


def test_forward_multiple_sentences_last_hidden_state():
    """Test multiple sentences, non-CLIP encoder."""
    encoder = DummyEncoder(output_key="last_hidden_state", hidden_size=4, pooler_size=2, max_length=8)
    tokenizer = DummyTokenizer(max_length=8)
    model = HFEncoder(encoder, tokenizer, is_clip=False, max_length=8)
    text = ["hello world", "foo bar"]
    codeflash_output = model.forward(text); output = codeflash_output # 68.8μs -> 53.3μs (29.0% faster)

def test_forward_single_sentence_clip_pooler_output():
    """Test basic functionality: single sentence, CLIP encoder, output shape and values."""
    encoder = DummyEncoder(output_key="pooler_output", hidden_size=4, pooler_size=2, max_length=8)
    tokenizer = DummyTokenizer(max_length=8)
    model = HFEncoder(encoder, tokenizer, is_clip=True, max_length=8)
    text = ["hello world"]
    codeflash_output = model.forward(text); output = codeflash_output # 61.4μs -> 45.8μs (34.1% faster)

def test_forward_multiple_sentences_clip_pooler_output():
    """Test multiple sentences, CLIP encoder."""
    encoder = DummyEncoder(output_key="pooler_output", hidden_size=4, pooler_size=2, max_length=8)
    tokenizer = DummyTokenizer(max_length=8)
    model = HFEncoder(encoder, tokenizer, is_clip=True, max_length=8)
    text = ["hello world", "foo bar"]
    codeflash_output = model.forward(text); output = codeflash_output # 61.8μs -> 41.5μs (48.8% faster)

# EDGE TEST CASES

def test_forward_max_length_1():
    """Test with max_length=1, should output shape (batch, 1, hidden_size)."""
    encoder = DummyEncoder(hidden_size=4, pooler_size=2, max_length=1)
    tokenizer = DummyTokenizer(max_length=1)
    model = HFEncoder(encoder, tokenizer, is_clip=False, max_length=1)
    text = ["hello world"]
    codeflash_output = model.forward(text); output = codeflash_output # 57.7μs -> 37.3μs (54.6% faster)

def test_forward_non_string_input_raises():
    """Test with non-string input; should raise TypeError or fail gracefully."""
    encoder = DummyEncoder(hidden_size=4, pooler_size=2, max_length=8)
    tokenizer = DummyTokenizer(max_length=8)
    model = HFEncoder(encoder, tokenizer, is_clip=False, max_length=8)
    # Pass int instead of string
    text = [123, 456]
    with pytest.raises(Exception):
        model.forward(text) # 3.68μs -> 3.56μs (3.51% faster)

# LARGE SCALE TEST CASES

def test_forward_large_batch_size():
    """Test with large batch size (e.g., 512 sentences)."""
    batch_size = 512
    encoder = DummyEncoder(hidden_size=4, pooler_size=2, max_length=8)
    tokenizer = DummyTokenizer(max_length=8)
    model = HFEncoder(encoder, tokenizer, is_clip=False, max_length=8)
    text = ["hello world"] * batch_size
    codeflash_output = model.forward(text); output = codeflash_output # 700μs -> 688μs (1.76% faster)


def test_forward_large_batch_and_max_length():
    """Test with large batch and large max_length, but <100MB tensor."""
    batch_size = 64
    max_length = 128
    hidden_size = 16
    encoder = DummyEncoder(hidden_size=hidden_size, pooler_size=2, max_length=max_length)
    tokenizer = DummyTokenizer(max_length=max_length)
    model = HFEncoder(encoder, tokenizer, is_clip=False, max_length=max_length)
    text = ["hello world"] * batch_size
    codeflash_output = model.forward(text); output = codeflash_output # 777μs -> 760μs (2.17% faster)

def test_forward_clip_large_batch():
    """Test CLIP mode with large batch size."""
    batch_size = 256
    encoder = DummyEncoder(hidden_size=4, pooler_size=2, max_length=8)
    tokenizer = DummyTokenizer(max_length=8)
    model = HFEncoder(encoder, tokenizer, is_clip=True, max_length=8)
    text = ["hello world"] * batch_size
    codeflash_output = model.forward(text); output = codeflash_output # 379μs -> 367μs (3.14% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
import torch
from invokeai.backend.flux.modules.conditioner import HFEncoder
from torch import nn

# --- Dummy classes for testing ---

class DummyTokenizer:
    # Minimal stand-in for a tokenizer. A plain class is used because
    # PreTrainedTokenizer.__init__ requires vocab/special-token plumbing
    # that this dummy does not provide.
    def __init__(self, vocab_size=10, max_length=8):
        self.vocab_size = vocab_size
        self.max_length = max_length

    def __call__(
        self,
        text,
        truncation=True,
        max_length=None,
        return_length=False,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    ):
        # Simulate tokenization: each word becomes an integer, pad/truncate to max_length
        max_length = max_length if max_length is not None else self.max_length
        batch = []
        for sentence in text:
            # Tokenize: split by space, map to int (simulate), limit vocab
            tokens = [min(abs(hash(w)) % self.vocab_size, self.vocab_size - 1) for w in sentence.split()]
            # Truncate/pad
            tokens = tokens[:max_length] + [0] * max(0, max_length - len(tokens))
            batch.append(tokens)
        return {"input_ids": torch.tensor(batch)}

class DummyEncoder(nn.Module):
    # Minimal implementation for testing
    def __init__(self, output_key="last_hidden_state", embedding_dim=4, max_length=8):
        super().__init__()
        self.output_key = output_key
        self.embedding_dim = embedding_dim
        self.max_length = max_length

    def forward(self, input_ids, attention_mask=None, output_hidden_states=False):
        # Simulate encoder output
        batch_size = input_ids.size(0)
        # "last_hidden_state": (batch_size, max_length, embedding_dim)
        # "pooler_output": (batch_size, embedding_dim)
        if self.output_key == "last_hidden_state":
            return {"last_hidden_state": torch.ones(batch_size, self.max_length, self.embedding_dim)}
        else:
            return {"pooler_output": torch.ones(batch_size, self.embedding_dim)}

# --- Dummy TorchDevice for testing ---

class DummyTorchDevice:
    @classmethod
    def choose_torch_device(cls):
        # Always return CPU for tests
        return torch.device("cpu")

# --- Unit Tests ---

# ----------- Basic Test Cases -----------
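
# Hypothetical example (not part of the generated suite) showing how the otherwise
# unused DummyTorchDevice above could be wired in. It assumes the conditioner module
# exposes a module-level `TorchDevice` name that monkeypatch can replace; adjust the
# target string if the actual import path differs.
def test_forward_uses_device_cached_at_construction(monkeypatch):
    """With the optimization, the device is chosen in __init__, so patching the
    selector before construction pins the device for all later forward() calls."""
    monkeypatch.setattr(
        "invokeai.backend.flux.modules.conditioner.TorchDevice",
        DummyTorchDevice,
        raising=False,
    )
    encoder = DummyEncoder(output_key="last_hidden_state", embedding_dim=4, max_length=8)
    tokenizer = DummyTokenizer(max_length=8)
    model = HFEncoder(encoder, tokenizer, is_clip=False, max_length=8)
    output = model.forward(["hello world"])
    # Expected shape if forward() returns the encoder's last_hidden_state unchanged.
    assert output.shape == (1, 8, 4)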

To edit these changes, run `git checkout codeflash/optimize-HFEncoder.forward-mhwve1o1` and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 13, 2025 03:29
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels Nov 13, 2025