NOTICE
Open WebUI Community is currently undergoing a major revamp to improve user experience and performance ✨

Function
pipe
v0.2
Flux 1.1 Pro
A function to use flux.1.1.pro model to generate vivid images, with together.ai API key
Function ID
flux_1_1_pro
Creator
@tonyhub
Downloads
644+

Function Content
python
"""
title: FLUX.1.1 Pro Manifold Function for Black Forest Lab Image Generation Models
author: mobilestack, credit to bgeneto
author_url: https://github.com/mobilestack/open-webui-flux-1.1-pro
            forked from https://github.com/bgeneto/open-webui-flux-image-gen
funding_url: https://github.com/open-webui
version: 0.2
license: MIT
requirements: pydantic, requests
environment_variables: FLUX_PRO_API_BASE_URL, FLUX_PRO_API_KEY
supported providers: together.ai
https://api.together.xyz/v1/images/generations
"""

import base64
import os
from typing import Any, Dict, Generator, Iterator, List, Union

import requests
from open_webui.utils.misc import get_last_user_message
from pydantic import BaseModel, Field


class Pipe:
    """
    Class representing the FLUX.1.1-pro Manifold Function.
    """

    class Valves(BaseModel):
        """
        Pydantic model for storing API keys and base URLs.
        """

        FLUX_PRO_API_KEY: str = Field(
            default="", description="Your API Key for Flux.1.1 Pro"
        )
        FLUX_PRO_API_BASE_URL: str = Field(
            default="https://api.together.xyz/v1/images/generations",
            description="Base URL for the API",
        )

    def __init__(self):
        """
        Initialize the Pipe class with default values and environment variables.
        """
        self.type = "manifold"
        self.id = "FLUX_1_1_PRO"
        self.name = "FLUX.1.1-pro: "
        self.valves = self.Valves(
            FLUX_PRO_API_KEY=os.getenv("FLUX_PRO_API_KEY", ""),
            FLUX_PRO_API_BASE_URL=os.getenv(
                "FLUX_PRO_API_BASE_URL",
                "https://api.together.xyz/v1/images/generations",
            ),
        )

    def url_to_img_data(self, url: str) -> str:
        """
        Convert a URL to base64-encoded image data.

        Args:
            url (str): The URL of the image.

        Returns:
            str: Base64-encoded image data.
        """
        headers = {"Authorization": f"Bearer {self.valves.FLUX_PRO_API_KEY}"}
        response = requests.get(url, headers=headers)
        response.raise_for_status()

        content_type = response.headers.get("Content-Type", "application/octet-stream")
        encoded_content = base64.b64encode(response.content).decode("utf-8")
        return f"data:{content_type};base64,{encoded_content}"

    def stream_response(
        self, headers: Dict[str, str], payload: Dict[str, Any]
    ) -> Generator[str, None, None]:

        yield self.non_stream_response(headers, payload)

    def get_img_extension(self, img_data: str) -> Union[str, None]:
        """
        Get the image extension based on the base64-encoded data.

        Args:
            img_data (str): Base64-encoded image data.

        Returns:
            Union[str, None]: The image extension or None if unsupported.
        """
        if img_data.startswith("/9j/"):
            return "jpeg"
        elif img_data.startswith("iVBOR"):
            return "png"
        elif img_data.startswith("R0lG"):
            return "gif"
        elif img_data.startswith("UklGR"):
            return "webp"
        return None

    def handle_json_response(self, response: requests.Response) -> str:
        """
        Handle JSON response from the API.

        Args:
            response (requests.Response): The response object.

        Returns:
            str: The formatted image data or an error message.
        """
        resp = response.json()
        if "output" in resp:
            img_data = resp["output"][0]
        elif "data" in resp and "b64_json" in resp["data"][0]:
            img_data = resp["data"][0]["b64_json"]
        else:
            return "Error: Unexpected response format for the image provider!"

        # split ;base64, from img_data
        try:
            img_data = img_data.split(";base64,")[1]
        except IndexError:
            pass

        img_ext = self.get_img_extension(img_data[:9])
        if not img_ext:
            return "Error: Unsupported image format!"

        # rebuild img_data with proper format
        img_data = f"data:image/{img_ext};base64,{img_data}"
        return f"![Image]({img_data})\n`GeneratedImage.{img_ext}`"

    def handle_image_response(self, response: requests.Response) -> str:
        """
        Handle image response from the API.

        Args:
            response (requests.Response): The response object.

        Returns:
            str: The formatted image data.
        """
        content_type = response.headers.get("Content-Type", "")
        # check image type in the content type
        img_ext = "png"
        if "image/" in content_type:
            img_ext = content_type.split("/")[-1]
        image_base64 = base64.b64encode(response.content).decode("utf-8")
        return f"![Image](data:{content_type};base64,{image_base64})\n`GeneratedImage.{img_ext}`"

    def non_stream_response(
        self, headers: Dict[str, str], payload: Dict[str, Any]
    ) -> str:
        """
        Get a non-streaming response from the API.

        Args:
            headers (Dict[str, str]): The headers for the request.
            payload (Dict[str, Any]): The payload for the request.

        Returns:
            str: The response from the API.
        """
        try:
            response = requests.post(
                url=self.valves.FLUX_PRO_API_BASE_URL,
                headers=headers,
                json=payload,
                stream=False,
                timeout=(3.05, 60),
            )
            response.raise_for_status()

            content_type = response.headers.get("Content-Type", "")
            if "application/json" in content_type:
                return self.handle_json_response(response)
            elif "image/" in content_type:
                return self.handle_image_response(response)
            else:
                return f"Error: Unsupported content type {content_type}"

        except requests.exceptions.RequestException as e:
            return f"Error: Request failed: {e}"
        except Exception as e:
            return f"Error: {e}"

    def pipes(self) -> List[Dict[str, str]]:
        """
        Get the list of available pipes.

        Returns:
            List[Dict[str, str]]: The list of pipes.
        """
        return [{"id": "flux_1_1_pro", "name": "Flux 1.1 PRO"}]

    def pipe(
        self, body: Dict[str, Any]
    ) -> Union[str, Generator[str, None, None], Iterator[str]]:
        """
        Process the pipe request.

        Args:
            body (Dict[str, Any]): The request body.

        Returns:
            Union[str, Generator[str, None, None], Iterator[str]]: The response from the API.
        """
        headers = {
            "Authorization": f"Bearer {self.valves.FLUX_PRO_API_KEY}",
            "Content-Type": "application/json",
        }

        body["stream"] = False
        prompt = get_last_user_message(body["messages"])

        headers_map = {
            "huggingface.co": {"x-wait-for-model": "true"},
            "replicate.com": {"Prefer": "wait"},
            "together.xyz": {},
        }

        payload_map = {
            "huggingface.co": {"inputs": prompt},
            "replicate.com": {
                "input": {
                    "prompt": prompt,
                    "go_fast": True,
                    "num_outputs": 1,
                    "aspect_ratio": "1:1",
                    "output_format": "webp",
                    "output_quality": 90,
                }
            },
            "together.xyz": {
                "model": "black-forest-labs/FLUX.1.1-pro",
                "prompt": prompt,
                "width": 1024,
                "height": 1024,
                "steps": 4,
                "n": 1,
                "response_format": "b64_json",
            },
        }

        payload = None
        for key in payload_map:
            if key in self.valves.FLUX_PRO_API_BASE_URL:
                payload = payload_map[key]
                headers.update(headers_map.get(key, {}))
                break

        if payload is None:
            return "Error: Unsupported API base URL! Remember, that's the beauty of open-source: you can add your own..."

        try:
            if body.get("stream", False):
                return self.stream_response(headers, payload)
            else:
                return self.non_stream_response(headers, payload)
        except requests.exceptions.RequestException as e:
            return f"Error: Request failed: {e}"
        except Exception as e:
            return f"Error: {e}"