"""
title: Replicate.com custom image model pipe for OpenWebUI.
author: demodomain.dev
author_url: https://github.com/yupguv
version: 1.0
license: MIT
You can request a particular aspect ratio and number of images when sending a prompt, e.g.:
man sitting in a chair: uses the default max width/height and the default number of images
16:9 man sitting in a chair: returns 16:9 images within the set max width/height and the default number of images
x4 man sitting in a chair: returns 4 images at the default aspect ratio
16:9 x4 man sitting in a chair: returns 4 images at 16:9 within the set max width/height
OpenWebUI has a maximum total megapixel (MP) limit, calculated as the number of output images x the MP of each image.
Outputting 2 images at 1440 x 1440 totals under 4.7 MP, so OpenWebUI can handle it.
But if you request 3 or 4 images with max width and height set to 1440, the combined MP of all images will exceed 4.7,
so all images are scaled down to keep the total under the 4.7 MP limit.
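For example, with the defaults: x4 at 1440 x 1440 is 4 x 2.07 MP = 8.29 MP in total,
so each image is scaled by sqrt(4.7 / 8.29) ~ 0.75, to roughly 1083 x 1083.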
This pipe uses asynchronous API calls because a synchronous request can time out when generating larger images or more than one image at a time.
The pipe polls the Replicate API until the job is done and then returns the results.
"""
from typing import (
Dict,
Generator,
Iterator,
Union,
Optional,
Literal,
Tuple,
List,
AsyncIterator,
)
from pydantic import BaseModel, Field
import os
import base64
import aiohttp
import asyncio
import json
import uuid
import time
OutputFormatType = Literal["jpg", "png"]
NumberofOutputImages = Literal[1, 2, 3, 4]
class Pipe:
"""A pipe that generates images from your custom-trained model, using Replicate.com."""
class Valves(BaseModel):
REPLICATE_API_TOKEN: str = Field(
default="", description="Your Replicate API token"
)
REPLICATE_MODEL_VERSION: str = Field(
default="", description="The version of your custom image model."
)
        OWUI_TOTAL_MEGAPIXEL_LIMIT: float = Field(
            default=4.7,
            description="The maximum total megapixels OpenWebUI can handle, calculated as the number of output images x the MP of each image. Images are scaled down to keep the combined total under this limit.",
        )
        FLUX_MAX_WIDTH: int = Field(
            default=1440, description="Maximum output image width. Max = 1440."
        )
        FLUX_MAX_HEIGHT: int = Field(
            default=1440, description="Maximum output image height. Max = 1440."
        )
FLUX_NUMBER_IMAGES_PER_RESPONSE: NumberofOutputImages = Field(
default=1, description="The number of images you want per response."
)
FLUX_OUTPUT_FORMAT: OutputFormatType = Field(
default="jpg", description="Output image format"
)
def __init__(self):
self.type = "pipe"
self.id = "Replicate custom trained model"
self.name = "Replicate custom trained model"
self.MODEL_URL = "https://api.replicate.com/v1/predictions"
self.valves = self.Valves(
REPLICATE_API_TOKEN=os.getenv("REPLICATE_API_TOKEN", ""),
FLUX_MAX_WIDTH=os.getenv("FLUX_MAX_WIDTH", 1440),
FLUX_MAX_HEIGHT=os.getenv("FLUX_MAX_HEIGHT", 1440),
REPLICATE_MODEL_VERSION=os.getenv("REPLICATE_MODEL_VERSION", ""),
OWUI_TOTAL_MEGAPIXEL_LIMIT=os.getenv("OWUI_TOTAL_MEGAPIXEL_LIMIT", 4.7),
FLUX_NUMBER_IMAGES_PER_RESPONSE=os.getenv(
"FLUX_NUMBER_IMAGES_PER_RESPONSE", 1
),
FLUX_OUTPUT_FORMAT=os.getenv("FLUX_OUTPUT_FORMAT", "jpg"),
)
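    # Example environment configuration (illustrative values; the model version is
    # the version hash shown on your model's page at replicate.com):
    #   REPLICATE_API_TOKEN=r8_...
    #   REPLICATE_MODEL_VERSION=<64-character version hash>
    #   FLUX_MAX_WIDTH=1440
    #   FLUX_MAX_HEIGHT=1440
    #   FLUX_NUMBER_IMAGES_PER_RESPONSE=2
    #   FLUX_OUTPUT_FORMAT=png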
async def _process_image(
self,
urls_or_data: Union[str, List[str]],
prompt: str,
params: Dict,
stream: bool = True,
    ) -> str:
"""Process image data and return HTML for grid layout."""
# Convert single URL to list for consistent handling
if isinstance(urls_or_data, str):
urls_or_data = [urls_or_data]
# Calculate grid layout class based on number of images
num_images = len(urls_or_data)
if num_images == 1:
grid_class = "w-full" # Full width for single image
elif num_images == 2:
grid_class = "w-1/2" # Two columns
else: # 3 or 4 images
grid_class = "w-1/2" # 2x2 grid
# Build HTML string directly
html = ['<div class="generated-images-grid flex flex-wrap gap-2">']
# Add each image
for url in urls_or_data:
html.append(
f'<div class="{grid_class} p-1">'
f'<img src="{url}" alt="Generated Image" '
f'style="width: 100%; height: auto; border-radius: 8px;" />'
f"</div>"
)
# Close grid container
html.append("</div>")
# Add metadata
html.append(
f'<div class="image-metadata" style="font-size: 0.9em; color: var(--text-gray-600); dark:color: var(--text-gray-400);">'
f'<div style="padding: 8px; margin-top: 4px; background: var(--bg-gray-50); dark:background: var(--bg-gray-800); border-radius: 6px;">'
f"<p><strong>Prompt:</strong> {prompt}</p>"
f"<p><strong>Parameters:</strong></p>"
f'<ul style="list-style-type: none; padding-left: 12px;">'
f'<li>• Format: {params.get("output_format", self.valves.FLUX_OUTPUT_FORMAT)}</li>'
f"<li>• Number of Images: {num_images}</li>"
f"</ul></div></div>"
)
return "".join(html)
def _create_sse_chunk(
self,
content: Union[str, Dict],
content_type: str = "text/html",
finish_reason: Optional[str] = None,
) -> str:
"""Create a Server-Sent Events chunk."""
chunk_data = {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": "dev",
"choices": [
{
"delta": (
{}
if finish_reason
else {
"role": "assistant",
"content": content,
"content_type": content_type,
}
),
"index": 0,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(chunk_data)}\n\n"
    async def _wait_for_completion(
        self, prediction_url: str, __event_emitter__=None
    ) -> Dict:
        """Poll the Replicate prediction URL until the job reaches a terminal state."""
        headers = {
            "Authorization": f"Token {self.valves.REPLICATE_API_TOKEN}",
            "Accept": "application/json",
            "Prefer": "wait=30",
        }
        result = {}
        async with aiohttp.ClientSession() as session:
            # Poll up to three times, waiting briefly between attempts
            for delay in (2, 3, 3):
                await asyncio.sleep(delay)
                async with session.get(
                    prediction_url, headers=headers, timeout=35
                ) as response:
                    response.raise_for_status()
                    result = await response.json()
                if result.get("status") in ("succeeded", "failed", "canceled"):
                    return result
        raise Exception(
            f"Generation incomplete; last status: {result.get('status')}"
        )
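    # A terminal prediction response from Replicate looks roughly like this
    # (illustrative, trimmed to the fields this pipe reads):
    #   {
    #       "status": "succeeded",
    #       "output": ["https://replicate.delivery/.../out-0.jpg"],
    #       "urls": {"get": "https://api.replicate.com/v1/predictions/<id>"},
    #       "error": null
    #   }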
def parse_settings_stepwise(self, text):
"""
Parse both aspect ratio and num outputs in a more flexible order.
Returns (aspect_ratio, num_outputs, actual_prompt)
"""
if not text:
return None, self.valves.FLUX_NUMBER_IMAGES_PER_RESPONSE, text
# Try parsing num_outputs first
num_outputs = self.valves.FLUX_NUMBER_IMAGES_PER_RESPONSE
if text.strip().startswith(("x", "×")):
num_outputs, remaining_text = self.parse_num_outputs(text)
# Then try for aspect ratios
aspect_ratio, actual_prompt = self.parse_aspect_ratio(remaining_text)
else:
# Try aspect ratio first
aspect_ratio, remaining_text = self.parse_aspect_ratio(text)
            # Then try for num_outputs (parse_aspect_ratio returns the original
            # text unchanged when no ratio is found)
            num_outputs, actual_prompt = self.parse_num_outputs(remaining_text)
return aspect_ratio, num_outputs, actual_prompt
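    # Illustrative parses (assuming the default FLUX_NUMBER_IMAGES_PER_RESPONSE of 1):
    #   parse_settings_stepwise("16:9 x4 man sitting") -> ((16, 9), 4, "man sitting")
    #   parse_settings_stepwise("x2 man sitting")      -> (None, 2, "man sitting")
    #   parse_settings_stepwise("man sitting")         -> (None, 1, "man sitting")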
def parse_aspect_ratio(self, text):
"""
Parse the first word of the text to check if it's a valid aspect ratio.
Returns (aspect_ratio, remaining_prompt) if valid, (None, original_text) if not.
"""
if not text:
return None, text
parts = text.strip().split(maxsplit=1)
if len(parts) < 1:
return None, text
potential_ratio = parts[0]
remaining_text = parts[1] if len(parts) > 1 else ""
# Check if it matches the aspect ratio format (e.g., "16:9")
if ":" in potential_ratio:
try:
width, height = map(int, potential_ratio.split(":"))
if width > 0 and height > 0:
return (width, height), remaining_text
except (ValueError, TypeError):
pass
return None, text
    def parse_num_outputs(self, text):
        """
        Parse x1-x4 (or ×1-×4) from the start of text.
        Returns (num_outputs, remaining_prompt), falling back to the configured default.
        """
        default = self.valves.FLUX_NUMBER_IMAGES_PER_RESPONSE
        if not text:
            return default, text
        parts = text.strip().split(maxsplit=1)
        if len(parts) < 1:
            return default, text
# Handle both 'x' and '×' characters
if parts[0].startswith(("x", "×")):
try:
requested_outputs = int(
parts[0][1:]
) # Remove first character and convert to int
num_outputs = max(1, min(4, requested_outputs))
return num_outputs, parts[1] if len(parts) > 1 else ""
except ValueError:
pass
        return default, text
def calculate_dimensions(
self,
aspect_ratio,
num_outputs,
max_total_mp=None,
base_dimension=1440,
):
"""
Calculate image dimensions based on aspect ratio and number of outputs,
ensuring total megapixels stays under max_total_mp.
Args:
aspect_ratio: Tuple of (width, height) ratio
num_outputs: Number of images to generate
max_total_mp: Maximum total megapixels allowed across all images
base_dimension: Base dimension to use for calculations
"""
# Use the valve value if no max_total_mp provided
if max_total_mp is None:
max_total_mp = self.valves.OWUI_TOTAL_MEGAPIXEL_LIMIT
width_ratio, height_ratio = aspect_ratio
# Calculate initial dimensions maintaining aspect ratio
if width_ratio > height_ratio:
# Landscape orientation
initial_width = base_dimension
initial_height = int(base_dimension / width_ratio * height_ratio)
else:
# Portrait or square orientation
initial_height = base_dimension
initial_width = int(base_dimension / height_ratio * width_ratio)
# Calculate megapixels for these dimensions
mp_per_image = (initial_width * initial_height) / 1_000_000
total_mp = mp_per_image * num_outputs
# If we're under the limit, use these dimensions
if total_mp <= max_total_mp:
return initial_width, initial_height
# If we're over the limit, scale down
scale_factor = (max_total_mp / num_outputs / mp_per_image) ** 0.5
# Apply scale factor to maintain aspect ratio while hitting MP target
new_width = int(initial_width * scale_factor)
new_height = int(initial_height * scale_factor)
return new_width, new_height
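    # Worked example (with the default 4.7 MP limit and base_dimension=1440):
    #   calculate_dimensions((16, 9), 4) starts at 1440 x 810 (~1.17 MP each,
    #   ~4.67 MP total), which is under the limit, so no scaling is applied.
    #   calculate_dimensions((1, 1), 4) starts at 1440 x 1440 (~8.29 MP total),
    #   so each image is scaled down to about 1083 x 1083.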
async def pipe(self, body: Dict, __event_emitter__=None) -> AsyncIterator[str]:
if not self.valves.REPLICATE_API_TOKEN:
yield "Error: REPLICATE_API_TOKEN is required"
return
try:
prompt = (body.get("messages", [{}])[-1].get("content", "") or "").strip()
if not prompt:
yield "Error: No prompt provided"
return
# Parse both settings in a flexible order
aspect_ratio, num_outputs, actual_prompt = self.parse_settings_stepwise(
prompt
)
            # Calculate dimensions from the aspect ratio, image count, and the
            # configured max width/height (the min of the two keeps both limits satisfied)
            width, height = self.calculate_dimensions(
                aspect_ratio or (1, 1),
                num_outputs,
                base_dimension=min(
                    self.valves.FLUX_MAX_WIDTH, self.valves.FLUX_MAX_HEIGHT
                ),
            )
input_params = {
"model": "dev",
"width": width,
"height": height,
"aspect_ratio": "custom",
"go_fast": False,
"lora_scale": 1,
"megapixels": "1",
"num_outputs": num_outputs,
"output_format": self.valves.FLUX_OUTPUT_FORMAT,
"guidance_scale": 3,
"output_quality": 100,
"prompt_strength": 0.8,
"extra_lora_scale": 1,
"num_inference_steps": 28,
"disable_safety_checker": True,
"prompt": "TOK "
+ actual_prompt, # Use the prompt without the aspect ratio
}
if __event_emitter__:
await __event_emitter__(
{
"type": "status",
"data": {
"description": "Starting Wincing Man generation...",
"done": False,
},
}
)
async with aiohttp.ClientSession() as session:
async with session.post(
self.MODEL_URL,
headers={
"Authorization": f"Token {self.valves.REPLICATE_API_TOKEN}",
"Content-Type": "application/json",
"Prefer": "wait=30",
},
json={
"version": self.valves.REPLICATE_MODEL_VERSION,
"input": input_params,
},
timeout=35,
) as response:
response.raise_for_status()
prediction = await response.json()
result = await self._wait_for_completion(
prediction["urls"]["get"], __event_emitter__
)
if result.get("status") != "succeeded":
raise Exception(
f"Generation failed: {result.get('error', 'Unknown error')}"
)
output_urls = result.get("output")
if not output_urls:
raise Exception("No valid output URLs in prediction result")
if result.get("status") != "succeeded":
raise Exception(
f"Generation failed: {result.get('error', 'Unknown error')}"
)
output_urls = result.get("output")
if not output_urls:
raise Exception("No valid output URLs in prediction result")
if __event_emitter__:
# Display each image using markdown
for url in output_urls:
await __event_emitter__(
{
"type": "message",
"data": {
"content": f"",
"content_type": "text/markdown",
},
}
)
# Then emit the metadata
await __event_emitter__(
{
"type": "message",
"data": {
"content": f"""
- **Prompt:** {actual_prompt}
- **Aspect Ratio:** {input_params["aspect_ratio"]}
- **Format:** {input_params["output_format"]}
- **Number of Images:** {num_outputs}
""",
"content_type": "text/markdown",
},
}
)
await __event_emitter__(
{
"type": "status",
"data": {
"description": "Images generated successfully!",
"done": True,
},
}
)
yield ""
except Exception as e:
error_msg = f"Error: {str(e)}"
if __event_emitter__:
await __event_emitter__(
{"type": "status", "data": {"description": error_msg, "done": True}}
)
yield error_msg