Conversation

@rkuester
Contributor

@rkuester rkuester commented Feb 3, 2026

This is a draft PR for running CI, gathering review, and seeing the commits in
context. The commits along this branch will be submitted individually for
merge.

This obsoletes the original feat-decode branch (#3257), which has been
reworked to address review feedback.

See the linked issue for a description of the change.

BUG=implements #3256

Implement a unified module for creating, reading, and modifying TFLite
models with a clean API. The module eliminates manual index tracking
and buffer management through automatic bookkeeping, and supports both
declarative and imperative construction styles.

Wrapper classes (Tensor, Operator, Subgraph, Model) hold the underlying
flatbuffer T objects as backing storage rather than copying fields into
dataclasses. This ensures all schema fields are preserved during
read-modify-write cycles, even fields not explicitly handled by
model_editor. Future schema additions will be preserved automatically.

Add comprehensive test coverage including field preservation tests that
verify unhandled schema fields survive read-modify-write.

BUG=implements tensorflow#3256
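
To make the declarative style concrete, here is a minimal sketch. Only Operator(opcode=..., inputs=..., outputs=...) and the subgraph.operators list appear in code quoted later in this review; the Tensor, Subgraph, and Model constructors and the build() call are assumptions about the API, not its actual surface.

# Hypothetical sketch of declarative construction with model_editor.
# Constructor arguments other than Operator's are illustrative only;
# the model_editor and tflite (schema) import paths are assumed.
import numpy as np

weights = model_editor.Tensor(shape=(2, 2), dtype=np.int8,
                              data=np.ones((2, 2), dtype=np.int8))
input_ = model_editor.Tensor(shape=(1, 2), dtype=np.int8)
output = model_editor.Tensor(shape=(1, 2), dtype=np.int8)

fc = model_editor.Operator(
    opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
    inputs=[input_, weights],
    outputs=[output],
)

subgraph = model_editor.Subgraph(operators=[fc], inputs=[input_],
                                 outputs=[output])
model = model_editor.Model(subgraphs=[subgraph])
flatbuffer_bytes = model.build()  # indices and buffers assigned automatically
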
…_editor

Replace model_facade with model_editor in compress.py and tests.
model_editor provides a cleaner API with better buffer and metadata
handling.

Update BUILD dependencies accordingly.

BUG=implements tensorflow#3256
Remove model_facade module and its tests, now superseded by
model_editor.

BUG=implements tensorflow#3256
…ess_test

Replace dictionary-based test_models.build() with model_editor's
declarative API for building test models.

BUG=implements tensorflow#3256
Remove test_models module and its tests, now superseded by
model_editor.

BUG=implements tensorflow#3256
Add decode module with DecodeType constants and DecodeCommonMetadata,
per the TFLM DECODE Operator Design document.

BUG=implements tensorflow#3256
Define the plugin interface for compression methods. Each compressor
implements the Compressor protocol with a compress() method that returns
encoded data and ancillary data.

BUG=implements tensorflow#3256
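
A sketch of the shape such a protocol could take, assuming typing.Protocol (explicit Protocol inheritance is discussed in a later commit); the parameter list and the tuple return type are illustrative, since the commit states only that compress() returns encoded data and ancillary data.

# Illustrative only: the commit says compress() returns encoded data and
# ancillary data; the parameter names and return type here are assumptions.
from typing import Protocol

import numpy as np

class Compressor(Protocol):

  def compress(self, values: np.ndarray) -> tuple[bytes, bytes]:
    """Return (encoded_data, ancillary_data) for one tensor's values."""
    ...
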
Implement LutCompressor using the Compressor protocol. Lookup table
compression replaces tensor values with indices into a table of unique
values, producing packed indices and ancillary data in the format
expected by the TFLM DECODE kernel.

Supports per-tensor and per-channel compression, sizes value tables to
the actual number of unique values, and handles unquantized tensors.

BUG=implements tensorflow#3256
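
The core of the lookup-table transform can be shown independently of the plugin plumbing. This is a simplified, per-tensor illustration in NumPy, not the module's code; it omits index bit-packing and the ancillary-data layout expected by the DECODE kernel.

# Simplified per-tensor LUT transform (illustrative, not the plugin itself).
import numpy as np

def lut_transform(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
  # Value table holds each unique value once; indices reference the table.
  table, indices = np.unique(values, return_inverse=True)
  bitwidth = max(1, int(np.ceil(np.log2(len(table)))))  # bits per index
  assert indices.max() < 2 ** bitwidth
  return table, indices.astype(np.uint8)

table, indices = lut_transform(np.array([3, 3, 7, 3, 7], dtype=np.int8))
# table   -> [3 7]
# indices -> [0 0 1 0 1], so table[indices] reconstructs the original values
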
Add spec types, YAML parser support, and plugin stubs for Huffman and
Pruning compression methods. The plugins raise CompressionError when
invoked, to be replaced with working implementations later.

BUG=implements tensorflow#3256
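
A stub in that style might look like the following; only the CompressionError name comes from the commit message, the class name and method shape are assumed.

# Illustrative stub; raises until a real implementation lands.
class HuffmanCompressor:

  def compress(self, values):
    raise CompressionError("huffman compression is not yet implemented")
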
Add alt_decompression_memory_size parameter to the Python interpreter
API. When non-zero, allocates a separate memory region for DECODE
operator outputs and calls SetDecompressionMemory before AllocateTensors.

BUG=implements tensorflow#3256
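
From the Python side, usage might look roughly like this; only the alt_decompression_memory_size name is taken from this commit, while the runtime import path and the from_bytes constructor are assumptions.

# Illustrative usage; import path and constructor are assumed, only the
# alt_decompression_memory_size keyword is named by this commit.
from tflite_micro.python.tflite_micro import runtime

interpreter = runtime.Interpreter.from_bytes(
    compressed_model_bytes,
    alt_decompression_memory_size=16 * 1024,  # separate region for DECODE outputs
)
# Set inputs as usual, then invoke; DECODE outputs land in the alternate region.
interpreter.invoke()
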
Insert DECODE operators before consumers of compressed tensors. Each
consumer gets its own DECODE operator to support alternate decompression
memory, which resets allocations between DECODE invocations.

After insertion, compressed tensors are rewritten to hold the encoded data
as UINT8, with shape equal to the byte count.

BUG=implements tensorflow#3256
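
The rewrite step described above amounts to something like the following; the attribute names on the model_editor Tensor wrapper are assumptions.

# Illustrative rewrite of a compressed tensor; attribute names are assumed.
encoded = bytes(encoded_data)
tensor.dtype = tflite.TensorType.UINT8  # payload is raw encoded bytes
tensor.shape = (len(encoded),)          # shape matches the byte count
tensor.data = encoded
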
Replace monolithic compression logic with a dispatch table that routes
compression requests to plugin modules based on the spec's compression
method type. After compressing tensors, insert DECODE operators into the
model graph.

Warn when compression expands data, helping users identify tensors that
don't benefit from compression.

BUG=implements tensorflow#3256
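
The dispatch can be pictured as a mapping from spec method types to plugin modules; all names below are placeholders rather than the module's actual identifiers.

# Illustrative dispatch; spec classes and plugin module names are placeholders.
import warnings

_COMPRESSORS = {
    spec.LookUpTableCompression: lut_compressor,
    spec.HuffmanCompression: huffman_compressor,
    spec.PruningCompression: pruning_compressor,
}

def _compress_tensor(method, values):
  try:
    plugin = _COMPRESSORS[type(method)]
  except KeyError:
    raise CompressionError(f"unsupported compression method: {method!r}")
  encoded, ancillary = plugin.compress(values)
  if len(encoded) + len(ancillary) >= values.nbytes:
    warnings.warn("compression expanded the tensor's data")
  return encoded, ancillary
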
Add tests that compress models with LUT compression, run them through
the TFLM Python interpreter, and verify outputs match uncompressed
originals. Cover per-tensor and per-channel quantization, various index
bitwidths, unquantized weights, and alternate decompression memory.

BUG=implements tensorflow#3256
Add a manual test for verifying compression on proprietary models that
can't be checked into the repository. See the module docstring for usage
instructions.

BUG=implements tensorflow#3256
Comment on lines 43 to 45
#ifdef USE_TFLM_COMPRESSION
AddDecode();
#endif
Member

AddDecode shouldn't be dependent on USE_TFLM_COMPRESSION (none of the DECODE code is conditionally compiled)

Contributor Author

Addressed in commit d2ac3ce. The #ifdef USE_TFLM_COMPRESSION guard around AddDecode() is removed, so DECODE is registered unconditionally. The compression and proprietary integration tests also drop their with_compression_enabled gating since DECODE-based models no longer require the flag.

Explicit inheritance from Protocol enables static type checking at
definition time and makes the interface self-documenting.

BUG=implements tensorflow#3256
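
For example, with explicit inheritance a signature mismatch in compress() is reported where the class is defined rather than at the call site (names as in the earlier sketch, still assumptions):

# Explicit inheritance from the Protocol; illustrative names.
class LutCompressor(Compressor):

  def compress(self, values: np.ndarray) -> tuple[bytes, bytes]:
    ...
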
Comment on lines 251 to 261
# Create DECODE operator
decode_op = model_editor.Operator(
opcode=tflite.BuiltinOperator.CUSTOM,
custom_code=DECODE_CUSTOM_OP_NAME,
inputs=[info.tensor, ancillary_tensor],
outputs=[output_tensor],
)

# Insert DECODE immediately before this consumer
insert_pos = subgraph.operators.index(consumer)
subgraph.operators.insert(insert_pos, decode_op)
Member

This being located here does not allow for a single DECODE operator to have multiple encoded inputs and ancillary tensors. The example would be CONCATENATION which takes multiple inputs, where several might be encoded tensors.

Contributor Author

The compressor design currently uses one DECODE operator per compressed tensor. Looking at the C++ kernel, I see it already supports multiple input/output pairs. Were you expecting us to batch them into a single DECODE for cases like CONCATENATION? Would that run into the same alt decompression memory problems as reusing the output of a single DECODE? Does using multiple DECODEs also run into those issues?

Member

If a single operator (for example CONCATENATION) has multiple compressed tensors, then those tensors and their DCMs should be batched into a single DECODE operator. The DECODE kernel already handles this (as you observed) and handles the alternate decompression memory correctly in this case.

Or to put it more succinctly: multiple encoded tensors for a single operator, MUST be passed as multiple inputs to a single DECODE operator.

Contributor Author

Addressed. Multiple compressed tensor inputs to the same operator are now batched into a single DECODE. The grouping is per-consumer, so a tensor shared across different consumers still gets a separate DECODE before each one to avoid clobbering the alternate decompression memory.

Member

Perhaps add a test where the compression spec is empty? The original model and the no-spec "compressed" model should give the same results.

Contributor Author

Rather than pass through an empty spec silently, the compressor now rejects an empty spec as an error, since it's almost certainly a mistake. There's a corresponding test.

Member

Will need a test where the simple model has an operator with two inputs that are compressed (FULLY_CONNECTED weight + bias, or two inputs of CONCATENATION). Each encoded tensor could have different bit-width, thus generating different DCM for each. Or perhaps extend an existing test?

Contributor Author

Addressed. test_multiple_compressed_inputs_batched tests a CONCATENATION with two compressed tensor inputs at different bitwidths, verifying a single DECODE with 4 inputs and 2 outputs where each ancillary tensor carries its own distinct data. test_mixed_compressed_and_uncompressed_inputs covers the case where only one of two CONCATENATION inputs is compressed.

The DECODE kernel and its dependencies are already compiled
unconditionally -- none are guarded by USE_TFLM_COMPRESSION. Remove the
#ifdef around AddDecode() in PythonOpsResolver so DECODE-based
compressed models work in a default Python build.

Remove the with_compression_enabled gating from compression and
proprietary integration tests, since they use DECODE-based models that
no longer require the flag.
Now that DECODE is always registered, compress() produces models that
load successfully, making the old test wrong. Rewrite to inject raw
COMPRESSION_METADATA into the flatbuffer metadata via model_editor,
directly exercising the HasCompressionMetadata() detection path for
legacy-compressed models.
Add test_multiple_compressed_inputs_batched: a CONCATENATION with two
compressed tensor inputs, each with a different bitwidth, should
produce a single DECODE with 4 inputs and 2 outputs, each ancillary
tensor carrying its own distinct data. Marked expectedFailure until
the implementation lands.

Add test_mixed_compressed_and_uncompressed_inputs: a CONCATENATION with
one compressed and one plain input leaves the plain input untouched.
This already passes with the current code.
When a single operator (e.g., CONCATENATION) has multiple compressed
tensor inputs, group them into one DECODE instead of creating a separate
DECODE for each.

Grouping is per-consumer, so a tensor shared across different consumers
still gets a separate DECODE before each one to avoid clobbering the
alternate decompression memory.
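
A sketch of the grouping, reusing the Operator constructor and subgraph.operators from the snippet quoted earlier; the per-tensor bookkeeping fields (consumers, ancillary, decoded_output) are hypothetical names.

# Illustrative per-consumer grouping; bookkeeping field names are assumed.
from collections import defaultdict

by_consumer = defaultdict(list)
for info in compressed_tensors:
  for consumer in info.consumers:
    by_consumer[consumer].append(info)

for consumer, infos in by_consumer.items():
  decode_op = model_editor.Operator(
      opcode=tflite.BuiltinOperator.CUSTOM,
      custom_code=DECODE_CUSTOM_OP_NAME,
      # One encoded/ancillary input pair and one output per compressed tensor.
      inputs=[t for info in infos for t in (info.tensor, info.ancillary)],
      outputs=[info.decoded_output for info in infos],
  )
  subgraph.operators.insert(subgraph.operators.index(consumer), decode_op)
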
An empty spec list passed to compress() previously returned an
unmodified model silently. Fail early with a clear error instead,
since an empty spec is almost certainly a mistake.
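
The guard itself is small; CompressionError appears elsewhere in the PR, while the compress() signature shown here is assumed.

# Illustrative early check; the signature of compress() is assumed.
def compress(model_bytes, specs):
  if not specs:
    raise CompressionError("empty compression spec: nothing to compress")
  ...
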
Comment on lines 91 to 92
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
Member

Should add a comment on the meaning of these environment vars

Contributor Author

Addressed in d40a84e.
