Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Currently targets the ONNX backend. The pipeline is designed to support addition
| [`mask-object-segnext-b2hq`](models/mask-object-segnext-b2hq/README.md) | mask | SegNext ViT-B SAx2 HQ for semantic masking |
| [`rawdenoise-nind`](models/rawdenoise-nind/README.md) | rawdenoise | UtNet2 raw denoiser trained on RawNIND (Bayer + linear Rec.2020 variants) |
| [`upscale-bsrgan`](models/upscale-bsrgan/README.md) | upscale | BSRGAN 2x and 4x blind super-resolution |
| [`upscale-realplksr`](models/upscale-realplksr/README.md) | upscale | RealPLKSR 2x and 4x super-resolution (real-world photos)|

## Repository structure

Expand Down
4 changes: 2 additions & 2 deletions models/upscale-bsrgan/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ id: upscale-bsrgan
name: "upscale bsrgan"
description: "BSRGAN 2x and 4x blind super-resolution"
task: upscale
version: "0.1"
version: "0.2"
backend: onnx
tiling: true
type: multi
Expand All @@ -17,7 +17,7 @@ attributes:
input_sizes: [256]

model_card:
long_description: "BSRGAN blind image super-resolution using practical degradation model; includes both 2x and 4x upscaling variants with RRDBNet architecture"
long_description: "BSRGAN (Blind Super-Resolution GAN). 2x or 4x photo upscaling that also cleans up noise, JPEG artefacts, and mild blur. Best for high-ISO captures, scanned negatives or prints, and other less-than-pristine source photos"
scope: "image upscaling (2x and 4x blind super-resolution)"
author: "Kai Zhang (ETH Zurich)"
source: "https://github.com/cszn/BSRGAN"
Expand Down
89 changes: 89 additions & 0 deletions models/upscale-realplksr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# RealPLKSR

Real-world variant of PLKSR (Partial Large Kernel CNN for Super-Resolution),
robust to typical photo artefacts (noise, blur, JPEG/WebP compression).
Lightweight pure-CNN architecture – faster than transformer-based upscalers.

Includes both 2x and 4x upscaling variants.

## Source

- Architecture and weights: [dslisleedh/PLKSR](https://github.com/dslisleedh/PLKSR) (MIT) – Dongheon Lee et al., Machine Intelligence Laboratory, University of Seoul
- Paper: [Partial Large Kernel CNNs for Efficient Super-Resolution](https://arxiv.org/abs/2404.11848) (2024)
- Checkpoints: dslisleedh's MSSIM pretrain release – [see issue #4](https://github.com/dslisleedh/PLKSR/issues/4) (MIT, same as architecture)
- x2: `2x_realplksr_mssim_pretrain.pth`
- x4: `4x_realplksr_mssim_pretrain.pth`
- Trained via the [neosr](https://github.com/neosr-project/neosr) framework with the RealESRGAN degradation pipeline
- Loaded for conversion via [chaiNNer-org/spandrel](https://github.com/chaiNNer-org/spandrel), which auto-detects the RealPLKSR variant from the checkpoint state_dict

## Architecture

PLKSR replaces standard depthwise large-kernel convolutions with *partial*
large kernels – applied only to a subset of channels – reducing FLOPs while
keeping the receptive field that drives SR quality. The Real-world variant
swaps the upsampler for DySample and is trained with stronger augmentation
for robustness to typical photo artefacts (noise, blur, compression).

The shipped weights are the MSSIM-pretrain stage (no GAN finetune) –
faithful, conservative output without the texture hallucination risk that
GAN-trained SR models exhibit.

| Property | Value |
|--------------|------------------------------------------------------|
| Architecture | RealPLKSR |
| Parameters | ~7M |
| Receptive | Large (partial 17×17 kernels) |
| Upsampler | DySample |
| Loss | MSSIM (pretrain stage) |

## ONNX Models

| Property | model_x2.onnx | model_x4.onnx |
|------------|--------------------------------------|----------------------------------------|
| Input | `input` – float32 [1, 3, 512, 512] | `input` – float32 [1, 3, 256, 256] |
| Output | `output` – float32 [1, 3, 1024, 1024]| `output` – float32 [1, 3, 1024, 1024] |
| Resolution | Static, baked at 512×512 | Static, baked at 256×256 |
| Opset | 20 | 20 |
| Normalize | [0, 1] range (divide by 255) | [0, 1] range (divide by 255) |
| Tiling | Yes (`model_x2.input_sizes: [512]`) | Yes (`model_x4.input_sizes: [256]`) |

Both variants produce a 1024×1024 output tile – x2 from a 512×512 input,
x4 from a 256×256 input. Per-stem tile sizes are declared in the manifest
so darktable picks the right size for each variant at runtime:

```yaml
attributes:
model_x2:
input_sizes: [512]
model_x4:
input_sizes: [256]
```

## Notes

- Input and output are RGB images in [0, 1] range.
- Output should be clipped to [0, 1] before converting back to uint8.
- Exported with FP32 precision. FP16 export is supported via `--fp16` in
convert args but off by default.
- Inputs are baked into the graph so JIT-compiling EPs (CoreML,
MIGraphX) only pay the compile cost once. Callers must tile at
exactly the declared size.
- Conversion uses [Spandrel](https://github.com/chaiNNer-org/spandrel)
to auto-detect the RealPLKSR variant from the checkpoint's state_dict,
avoiding the need to clone PLKSR or neosr.

## Selection Criteria

| Property | Value |
|--------------------------|-----------------------------------------------------------------------------------------------------|
| Model license | MIT (weights and architecture) |
| OSAID v1.0 | Open Source AI |
| MOF | Class II (Open Tooling) |
| Training data license | DF2K (DIV2K + Flickr2K) per neosr common practice; Flickr2K without an explicit open-source license |
| Training data provenance | Synthetic real-world degradations applied via the neosr framework |
| Training code | [PLKSR](https://github.com/dslisleedh/PLKSR) (MIT) + [neosr](https://github.com/neosr-project/neosr) (Apache-2.0) |
| Known limitations | MSSIM-pretrain checkpoints only (no GAN finetune) – conservative output, no hallucinated detail |
| Published research | [Partial Large Kernel CNNs for Efficient Super-Resolution](https://arxiv.org/abs/2404.11848) |
| Inference | Local only, no cloud dependencies |
| Scope | Image upscaling (2x and 4x super-resolution, robust to noise/blur/compression artefacts) |
| Reproducibility | Full pipeline (setup, convert, clean, demo) |
151 changes: 151 additions & 0 deletions models/upscale-realplksr/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Export RealPLKSR to ONNX via Spandrel.

Spandrel auto-detects the RealPLKSR variant from the checkpoint's
state_dict and builds the network without needing to clone neosr or
the original PLKSR training framework. For a one-shot ONNX export
this is significantly easier than wiring up the training-time scaffolding.

RealPLKSR is a pure-CNN architecture with no window-attention constraints,
so any input multiple of 4 works. We trace at the deployment dim directly
to keep the captured dims consistent with the static shape — same pattern
the DAT-2 export uses, applied here for simplicity.
"""

import argparse
import os

import torch

try:
import onnxconverter_common
HAS_ONNX_CONVERTER = True
except ImportError:
HAS_ONNX_CONVERTER = False

from spandrel import ImageModelDescriptor, ModelLoader


def _patch_plkconv_for_clean_export():
"""Force PLKConv2d to use its split+cat forward path (which is its
training-mode branch) regardless of mode. The eval-mode branch uses
in-place indexed assignment (`x[:, :idx] = conv(...)`), which
torch.onnx exports as ScatterND — ONNX Runtime then prints a
"may not be deterministic if indices are duplicated" warning for
each block at load time. The split+cat path is numerically
identical for contiguous-from-0 channels and exports as plain
Split+Conv+Concat, with no warning.
"""
from spandrel.architectures.PLKSR.__arch.RealPLKSR import PLKConv2d

def forward_export(self, x: torch.Tensor) -> torch.Tensor:
x1, x2 = torch.split(x, [self.idx, x.size(1) - self.idx], dim=1)
x1 = self.conv(x1)
return torch.cat([x1, x2], dim=1)

PLKConv2d.forward = forward_export


def export_to_onnx(model, output_path, scale, height=256, width=256,
dynamic_shapes=True, opset_version=20, fp16=False):
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

import onnx

if dynamic_shapes:
trace_h, trace_w = 64, 64
dynamic_axes = {
'input': {0: 'batch', 2: 'height', 3: 'width'},
'output': {0: 'batch', 2: 'height', 3: 'width'},
}
else:
trace_h, trace_w = height, width
dynamic_axes = None

dummy_input = torch.randn(1, 3, trace_h, trace_w)

torch.onnx.export(
model,
dummy_input,
output_path,
export_params=True,
opset_version=opset_version,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes=dynamic_axes,
verbose=False,
)
print(f"Exported: {output_path} (traced at {trace_h}x{trace_w})")

onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print(" ONNX verification passed.")

if not dynamic_shapes:
print(f" Static dims baked: "
f"{height}x{width} -> {height * scale}x{width * scale}")

if fp16:
if not HAS_ONNX_CONVERTER:
print("Warning: onnxconverter-common not installed. Skipping FP16 conversion.")
return
print("Converting to FP16...")
from onnxconverter_common import float16
fp16_model = float16.convert_float_to_float16(onnx_model)
onnx.save(fp16_model, output_path)
print(f"FP16 model saved to {output_path}")


def convert(checkpoint, output, scale, height=256, width=256,
dynamic_shapes=True, opset=20, fp16=False, static=False):
"""Entry point for programmatic conversion."""
scale = int(scale)

_patch_plkconv_for_clean_export()

print(f"Loading RealPLKSR model via Spandrel: {checkpoint}")
descriptor = ModelLoader().load_from_file(checkpoint)
if not isinstance(descriptor, ImageModelDescriptor):
raise TypeError(
f"expected ImageModelDescriptor, got {type(descriptor).__name__}")
if descriptor.scale != scale:
raise ValueError(
f"checkpoint scale={descriptor.scale} does not match requested scale={scale}")

model = descriptor.model
model.eval()

param_count = sum(p.numel() for p in model.parameters())
print(f" Architecture: {descriptor.architecture.id}")
print(f" Scale: x{descriptor.scale}")
print(f" Parameters: {param_count:,}")

print("Exporting to ONNX...")
export_to_onnx(model, output, scale,
height=height, width=width,
dynamic_shapes=dynamic_shapes and not static,
opset_version=opset, fp16=fp16)


def main():
parser = argparse.ArgumentParser(description='Export RealPLKSR to ONNX via Spandrel')
parser.add_argument('--checkpoint', required=True)
parser.add_argument('--output', required=True)
parser.add_argument('--scale', type=int, required=True, choices=[2, 3, 4])
parser.add_argument('--height', type=int, default=256)
parser.add_argument('--width', type=int, default=256)
parser.add_argument('--opset', type=int, default=20)
parser.add_argument('--fp16', action='store_true',
help='convert weights to FP16 after export (default: FP32)')
parser.add_argument('--static', action='store_true',
help='bake input height/width into the graph '
'(disables dynamic shape axes)')
args = parser.parse_args()

convert(args.checkpoint, args.output, args.scale,
height=args.height, width=args.width,
opset=args.opset, fp16=args.fp16, static=args.static)


if __name__ == '__main__':
main()
Loading