diff --git a/ai_image_gen/ai_image_gen/backend/generation.py b/ai_image_gen/ai_image_gen/backend/generation.py index dbcb393..f9067da 100644 --- a/ai_image_gen/ai_image_gen/backend/generation.py +++ b/ai_image_gen/ai_image_gen/backend/generation.py @@ -1,6 +1,7 @@ import asyncio import datetime import os +from collections.abc import Sequence from enum import Enum import reflex as rx @@ -11,6 +12,14 @@ DEFAULT_IMAGE = "/default.webp" API_TOKEN_ENV_VAR = "REPLICATE_API_TOKEN" +MODEL_ENV_VAR = "REPLICATE_MODEL" +UPSCALE_MODEL_ENV_VAR = "REPLICATE_UPSCALE_MODEL" + +DEFAULT_MODEL = "google/imagen-4-fast" +# philz1337x /clarity-upscaler:029d48aa +DEFAULT_UPSCALE_MODEL = ( + "029d48aa21712d6769d7a46729c1edf0e4d41919c70b270785f10abb82989ba5" +) CopyLocalState = rx._x.client_state(default=False, var_name="copying") @@ -67,7 +76,7 @@ async def generate_image(self): # Await the output from the replicate API response = await replicate.predictions.async_create( - "5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f", + os.environ.get(MODEL_ENV_VAR, DEFAULT_MODEL), input=input, ) @@ -99,8 +108,12 @@ async def generate_image(self): await asyncio.sleep(0.15) async with self: self.upscaled_image = "" - self.output_image = response.output[0] - self.output_list = [] if len(response.output) == 1 else response.output + if isinstance(response.output, str): + self.output_image = response.output + self.output_list = [] + elif isinstance(response.output, Sequence): + self.output_image = response.output[0] + self.output_list = list(response.output[1:]) self._reset_state() except Exception as e: @@ -152,7 +165,7 @@ async def upscale_image(self): # Await the output from the replicate API response = await replicate.predictions.async_create( - "029d48aa21712d6769d7a46729c1edf0e4d41919c70b270785f10abb82989ba5", + os.environ.get(UPSCALE_MODEL_ENV_VAR, DEFAULT_UPSCALE_MODEL), input=input, )