NOTICE
Open WebUI Community is currently undergoing a major revamp to improve user experience and performance ✨

Function
action
v0.2
Mixture of Agents OpenAI API
Mixture of Agents with OpenAI API. Modified from MaxKerkula's Mixture of Agents: https://openwebui.com/f/maxkerkula/mixture_of_agents
Function ID
mixture_of_agents_openai_api
Creator
@mochgolf
Downloads
121+

Function Content
python
"""
title: Mixture of Agents OpenAI API
author: mochgolf
version: 0.2
required_open_webui_version: 0.3.9
"""

from pydantic import BaseModel, Field
from typing import Optional, List, Callable, Awaitable
import aiohttp
import random
import asyncio
import time


class Action:
    """Open WebUI action that regenerates a reply with a Mixture-of-Agents pipeline.

    Proposer models (``valves.models``) are queried through an OpenAI-compatible
    ``/chat/completions`` endpoint in ``valves.num_layers`` layers. Layer 1
    answers the prompt directly; every later layer answers it again with the
    previous layer's answers attached. ``valves.aggregator_model`` then
    synthesizes all collected answers into the final response, which replaces
    the content of the last message in the request body.
    """

    class Valves(BaseModel):
        temperature: Optional[float] = Field(
            default=None, description="Sampling temperature to use, between 0 and 2."
        )
        top_p: Optional[float] = Field(
            default=None,
            description="Probability mass for nucleus sampling, between 0 and 1.",
        )
        models: List[str] = Field(
            default=[], description="List of models to use in the MoA architecture."
        )
        aggregator_model: str = Field(
            default="", description="Model to use for aggregation tasks."
        )
        openai_api_base: str = Field(
            default="https://api.openai.com/v1", description="Base URL for OpenAI API."
        )
        openai_api_key: str = Field(
            default="", description="OpenAI API key for authentication."
        )
        num_layers: int = Field(default=1, description="Number of MoA layers.")
        num_agents_per_layer: int = Field(
            default=3, description="Number of agents to use in each layer."
        )
        emit_interval: float = Field(
            default=1.0, description="Interval in seconds between status emissions"
        )
        enable_status_indicator: bool = Field(
            default=True, description="Enable or disable status indicator emissions"
        )

    def __init__(self):
        self.valves = self.Valves()
        # Timestamp of the last emitted status; ``emit_status`` uses it to
        # throttle non-final updates to one per ``valves.emit_interval`` seconds.
        self.last_emit_time = 0

    async def query_openai(
        self,
        model: str,
        prompt: str,
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]] = None,
    ) -> str:
        """Send ``prompt`` as a single user message to ``model``; return the reply.

        Failures never raise: HTTP, network, malformed-response and unexpected
        errors are reported through ``emit_status`` and returned as a string
        starting with ``"Error:"``, which every caller checks for.
        """
        # rstrip avoids a "//chat/completions" path when the base URL ends in "/".
        url = f"{self.valves.openai_api_base.rstrip('/')}/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.valves.openai_api_key}",
        }
        data = {"model": model, "messages": [{"role": "user", "content": prompt}]}
        # Sampling parameters are only sent when the user configured them.
        if self.valves.temperature is not None:
            data["temperature"] = self.valves.temperature
        if self.valves.top_p is not None:
            data["top_p"] = self.valves.top_p

        try:
            await self.emit_status(
                __event_emitter__,
                "info",
                f"Sending API request to model: {model} with prompt: {prompt}",
                False,
            )

            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    response.raise_for_status()
                    result = await response.json()

            content = result["choices"][0]["message"].get("content")
            if not isinstance(content, str):
                # A null/missing content field would otherwise be returned as
                # None and crash callers that call ``.startswith`` on the result.
                error_message = (
                    f"OpenAI API returned no text content for model {model}"
                )
                await self.emit_status(__event_emitter__, "error", error_message, True)
                return f"Error: Model {model} returned no text content"

            await self.emit_status(
                __event_emitter__,
                "info",
                f"Received API response from model: {model} with content: {content}",
                False,
            )
            return content
        except aiohttp.ClientResponseError as e:
            error_message = f"HTTP error querying OpenAI API for model {model}: {e.status}, {e.message}"
            await self.emit_status(__event_emitter__, "error", error_message, True)
            return f"Error: Unable to query model {model} due to HTTP error {e.status}"
        except aiohttp.ClientError as e:
            error_message = (
                f"Network error querying OpenAI API for model {model}: {str(e)}"
            )
            await self.emit_status(__event_emitter__, "error", error_message, True)
            return f"Error: Unable to query model {model} due to network error"
        except Exception as e:
            error_message = (
                f"Unexpected error querying OpenAI API for model {model}: {str(e)}"
            )
            await self.emit_status(__event_emitter__, "error", error_message, True)
            return f"Error: Unable to query model {model} due to unexpected error"

    async def action(
        self,
        body: dict,
        __user__: Optional[dict] = None,
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]] = None,
        __event_call__: Optional[Callable[[dict], Awaitable[dict]]] = None,
    ) -> Optional[dict]:
        """Action entry point: replace the last message's content with a MoA answer.

        ``__user__`` and ``__event_call__`` are unused here; they are kept so the
        signature stays unchanged for the host that calls this action.

        Returns:
            The mutated ``body`` on success, or ``{"error": message}`` on failure.
        """
        await self.emit_status(
            __event_emitter__, "info", "Starting Mixture of Agents process", False
        )

        # Check the request first: it is free, whereas model validation costs
        # one API call per configured model.
        messages = body.get("messages", [])
        if not messages:
            error_msg = "No messages found in the request body"
            await self.emit_status(__event_emitter__, "error", error_msg, True)
            return {"error": error_msg}

        try:
            models = await self.validate_models(__event_emitter__)
        except ValueError as e:
            await self.emit_status(__event_emitter__, "error", str(e), True)
            return {"error": str(e)}

        if len(messages) > 2:
            # Later turn: give the agents the latest user message together with
            # the response that is being regenerated.
            # NOTE(review): assumes messages[-2] is the user turn and
            # messages[-1] the assistant turn — confirm against the caller.
            user_message = messages[-2]["content"]
            previous_agent_response = messages[-1]["content"]
            combined_input = (
                f"User message: {user_message}\n"
                f"Previous agent response: {previous_agent_response}\n"
                f"Generate a new response based on the conversation context."
            )
        else:
            combined_input = messages[0]["content"]

        moa_response = await self.moa_process(
            combined_input, __event_emitter__, models=models
        )

        if moa_response.startswith("Error:"):
            await self.emit_status(__event_emitter__, "error", moa_response, True)
            return {"error": moa_response}

        body["messages"][-1]["content"] = moa_response
        await self.emit_status(
            __event_emitter__,
            "info",
            f"Mixture of Agents process completed with final response: {moa_response}",
            True,
        )
        return body

    async def validate_models(
        self, __event_emitter__: Optional[Callable[[dict], Awaitable[None]]] = None
    ) -> List[str]:
        """Probe every configured model concurrently; return the ones that answer.

        ``valves.models`` is left untouched, so a transient failure does not
        remove a model from the user's configuration.

        Raises:
            ValueError: If none of the configured models responds successfully.
        """
        await self.emit_status(__event_emitter__, "info", "Validating models", False)
        candidates = list(self.valves.models)
        # query_openai never raises for request failures, so a plain gather is
        # enough; responses come back in the same order as ``candidates``.
        responses = await asyncio.gather(
            *(
                self.query_openai(model, "Test prompt, return Hello.", __event_emitter__)
                for model in candidates
            )
        )

        valid_models = []
        for model, response in zip(candidates, responses):
            if not response.startswith("Error:"):
                valid_models.append(model)
            else:
                await self.emit_status(
                    __event_emitter__, "error", f"Model '{model}' is not valid", False
                )

        if not valid_models:
            error_msg = (
                "No valid models available. Please check your model configurations."
            )
            await self.emit_status(__event_emitter__, "error", error_msg, True)
            raise ValueError(error_msg)

        await self.emit_status(
            __event_emitter__, "info", f"Validated {len(valid_models)} models", False
        )
        return valid_models

    async def moa_process(
        self,
        prompt: str,
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]] = None,
        models: Optional[List[str]] = None,
    ) -> str:
        """Run every MoA layer for ``prompt``, then the final aggregation step.

        Args:
            prompt: Text every agent should answer.
            __event_emitter__: Optional status callback.
            models: Proposer models to sample agents from. Defaults to
                ``valves.models``; ``action`` passes the validated subset.

        Returns:
            The aggregator's final answer, or a string starting with ``"Error:"``.
        """
        available_models = list(self.valves.models if models is None else models)
        num_agents = self.valves.num_agents_per_layer

        if (
            not available_models
            or not self.valves.aggregator_model
            or not self.valves.openai_api_base
        ):
            error_msg = "Configuration error: Models, aggregator model, or API base URL not set."
            await self.emit_status(__event_emitter__, "error", error_msg, True)
            return f"Error: {error_msg}"

        if num_agents < 1:
            # A negative sample size would make random.sample raise ValueError.
            error_msg = f"Configuration error: num_agents_per_layer must be at least 1, got {num_agents}."
            await self.emit_status(__event_emitter__, "error", error_msg, True)
            return f"Error: {error_msg}"

        if len(available_models) < num_agents:
            error_msg = f"Not enough models available. Required: {num_agents}, Available: {len(available_models)}"
            await self.emit_status(__event_emitter__, "error", error_msg, True)
            return f"Error: {error_msg}"

        layer_outputs: List[List[str]] = []
        for layer in range(self.valves.num_layers):
            await self.emit_status(
                __event_emitter__,
                "info",
                f"Processing layer {layer + 1}/{self.valves.num_layers}",
                False,
            )

            # Each layer draws a fresh random subset of distinct models.
            layer_agents = random.sample(available_models, num_agents)

            tasks = [
                self.process_agent(
                    prompt, agent, layer, i, layer_outputs, __event_emitter__
                )
                for i, agent in enumerate(layer_agents)
            ]
            current_layer_outputs = await asyncio.gather(*tasks, return_exceptions=True)

            # Report agents that raised instead of dropping them silently.
            for i, output in enumerate(current_layer_outputs):
                if isinstance(output, BaseException):
                    await self.emit_status(
                        __event_emitter__,
                        "error",
                        f"Agent {i + 1} in layer {layer + 1} failed: {output}",
                        False,
                    )

            valid_outputs = [
                output
                for output in current_layer_outputs
                if isinstance(output, str) and not output.startswith("Error:")
            ]
            if not valid_outputs:
                error_msg = (
                    f"No valid responses received from any agent in layer {layer + 1}"
                )
                await self.emit_status(__event_emitter__, "error", error_msg, True)
                return f"Error: {error_msg}"

            layer_outputs.append(valid_outputs)
            await self.emit_status(
                __event_emitter__,
                "info",
                f"Completed layer {layer + 1}/{self.valves.num_layers} with outputs: {valid_outputs}",
                False,
            )

        await self.emit_status(
            __event_emitter__, "info", "Creating final aggregator prompt", False
        )
        final_prompt = self.create_final_aggregator_prompt(prompt, layer_outputs)

        await self.emit_status(
            __event_emitter__,
            "info",
            f"Generating final response with prompt: {final_prompt}",
            False,
        )
        final_response = await self.query_openai(
            self.valves.aggregator_model, final_prompt, __event_emitter__
        )

        if final_response.startswith("Error:"):
            await self.emit_status(
                __event_emitter__, "error", "Failed to generate final response", True
            )
            return f"Error: Failed to generate final response. Last error: {final_response}"

        return final_response

    async def process_agent(
        self, prompt, agent, layer, agent_index, layer_outputs, __event_emitter__
    ):
        """Query one agent for one layer; return its reply or an ``"Error:"`` string.

        Args:
            prompt: The original prompt.
            agent: Model name assigned to this agent slot.
            layer: Zero-based layer index.
            agent_index: Zero-based position of the agent within its layer.
            layer_outputs: Valid outputs of the layers completed so far.
            __event_emitter__: Optional status callback.
        """
        await self.emit_status(
            __event_emitter__,
            "info",
            f"Querying agent {agent_index + 1} in layer {layer + 1} with prompt: {prompt}",
            False,
        )

        if layer == 0:
            response = await self.query_openai(agent, prompt, __event_emitter__)
        else:
            await self.emit_status(
                __event_emitter__,
                "info",
                f"Creating aggregator prompt for layer {layer + 1}",
                False,
            )
            aggregator_prompt = self.create_aggregator_prompt(prompt, layer_outputs[-1])
            # Each later-layer agent refines the previous layer's answers with
            # its own model. Querying the aggregator model here would ignore
            # ``agent`` and send the same request num_agents_per_layer times.
            # The aggregator model is used for the final synthesis only.
            response = await self.query_openai(
                agent, aggregator_prompt, __event_emitter__
            )

        await self.emit_status(
            __event_emitter__,
            "info",
            f"Received response from agent {agent_index + 1} in layer {layer + 1} with content: {response}",
            False,
        )
        return response

    def create_aggregator_prompt(
        self, original_prompt: str, previous_responses: List[str]
    ) -> str:
        """Build the prompt that shows a layer the previous layer's numbered answers."""
        numbered = "".join(
            f"{i}. {response}\n\n" for i, response in enumerate(previous_responses, 1)
        )
        return (
            f"Original prompt: {original_prompt}\n\nPrevious responses:\n"
            + numbered
            + "Based on the above responses and the original prompt, think step by step to provide an improved and comprehensive answer:"
        )

    def create_final_aggregator_prompt(
        self, original_prompt: str, all_layer_outputs: List[List[str]]
    ) -> str:
        """Build the aggregator's prompt listing every layer's answers, then instructions."""
        parts = [f"Original prompt: {original_prompt}\n\nResponses from all layers:\n"]
        for layer, responses in enumerate(all_layer_outputs, 1):
            parts.append(f"Layer {layer}:\n")
            parts.extend(f" {i}. {response}\n\n" for i, response in enumerate(responses, 1))
        parts.append(
            "Considering all the responses from different layers and the original prompt, provide a final, comprehensive and reliable answer that strictly adheres to the original request:\n"
            "1. Incorporate relevant information from all previous responses seamlessly.\n"
            "2. Avoid referencing or acknowledging previous responses explicitly unless directed by the prompt.\n"
            "3. Provide a complete and detailed reply addressing the original prompt."
        )
        return "".join(parts)

    async def emit_status(
        self,
        __event_emitter__: Optional[Callable[[dict], Awaitable[None]]],
        level: str,
        message: str,
        done: bool,
    ):
        """Send a ``status`` event, throttled to one per ``valves.emit_interval``.

        Final statuses (``done=True``) bypass the throttle. Nothing is sent when
        no emitter is supplied or the status indicator is disabled.
        """
        if not __event_emitter__ or not self.valves.enable_status_indicator:
            return
        current_time = time.time()
        # Intermediate updates inside the throttle window are dropped, not queued.
        if not done and current_time - self.last_emit_time < self.valves.emit_interval:
            return
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "status": "complete" if done else "in_progress",
                    "level": level,
                    "description": message,
                    "done": done,
                },
            }
        )
        self.last_emit_time = current_time

    async def on_start(self):
        print("Mixture of Agents Action started")

    async def on_stop(self):
        print("Mixture of Agents Action stopped")