diff --git a/src/workers/upscaler.worker.ts b/src/workers/upscaler.worker.ts index dd1263d8..1eafa58d 100644 --- a/src/workers/upscaler.worker.ts +++ b/src/workers/upscaler.worker.ts @@ -1,4 +1,59 @@ import Upscaler from "upscaler"; +import * as tf from "@tensorflow/tfjs"; + +// Register custom layer used by ESRGAN medium/thick models (loaded from CDN) +// Without this, TF.js throws "Unknown layer: MultiplyBeta" when loading these models +class MultiplyBeta extends tf.layers.Layer { + static className = "MultiplyBeta"; + private beta: number; + + constructor(config: Record = {}) { + super(config); + this.beta = (config.beta as number) ?? 0.2; + } + + call(inputs: tf.Tensor | tf.Tensor[]): tf.Tensor { + const input = Array.isArray(inputs) ? inputs[0] : inputs; + return tf.mul(input, tf.scalar(this.beta)); + } + + getConfig() { + return { ...super.getConfig(), beta: this.beta }; + } +} +tf.serialization.registerClass(MultiplyBeta); + +// PixelShuffle layer used by ESRGAN thick models — does depth-to-space rearrangement +function createPixelShuffleClass(scale: number) { + class PixelShuffle extends tf.layers.Layer { + static className = `PixelShuffle${scale}x`; + private scale: number; + + constructor(config: Record = {}) { + super(config); + this.scale = scale; + } + + computeOutputShape(inputShape: Array): Array { + return [inputShape[0], inputShape[1], inputShape[2], 3]; + } + + call(inputs: tf.Tensor | tf.Tensor[]): tf.Tensor { + const input = Array.isArray(inputs) ? inputs[0] : inputs; + return tf.depthToSpace(input as tf.Tensor4D, this.scale, "NHWC"); + } + + getConfig() { + return { ...super.getConfig(), scale: this.scale }; + } + } + return PixelShuffle; +} + +// Register PixelShuffle for all supported scales +[2, 3, 4].forEach((s) => { + tf.serialization.registerClass(createPixelShuffleClass(s)); +}); type ModelType = "slim" | "medium" | "thick"; type ScaleType = "2x" | "3x" | "4x";