Function
pipe
v0.1.5
Cohere
Cohere Manifold Pipe
Function ID
cohere
Creator
@justinrahb
Downloads
352+

Function Content
python
"""
title: Cohere Manifold Pipe
author: justinh-rahb
author_url: https://github.com/justinh-rahb
funding_url: https://github.com/open-webui
version: 0.1.5
license: MIT
"""

import os
import json
import requests
from typing import List, Union, Generator, Iterator
from pydantic import BaseModel, Field

# Verbose diagnostics switch: when True, the pipe prints its outgoing requests,
# raw responses and caught errors to stdout. Leave False for normal use.
DEBUG: bool = False


class Pipe:
    """Open WebUI manifold pipe that exposes Cohere chat models.

    ``pipes()`` lists the chat-capable models visible to the configured API
    key; ``pipe()`` forwards a conversation to Cohere's ``/chat`` endpoint
    using the v1 request format (``message`` + ``chat_history`` +
    ``preamble``) and either returns the whole reply or yields streamed
    text chunks.
    """

    # Incoming model ids may carry this prefix; it is stripped before the
    # id is sent to Cohere.
    _MODEL_PREFIX = "cohere."

    # Upper bound on paginated /models requests, so a server that keeps
    # returning a next_page_token cannot make listing loop forever.
    _MAX_MODEL_PAGES = 20

    class Valves(BaseModel):
        COHERE_API_BASE_URL: str = Field(default="https://api.cohere.com/v1")
        COHERE_API_KEY: str = Field(default="")
        # Seconds to wait for Cohere to connect or to send more data. Without
        # a timeout, `requests` can block the calling worker indefinitely.
        REQUEST_TIMEOUT: int = Field(default=60)

    def __init__(self):
        self.type = "manifold"
        self.id = "cohere"
        self.name = "cohere/"
        # Seed the API key from the COHERE_API_KEY environment variable
        # (empty string when unset).
        self.valves = self.Valves(
            **{"COHERE_API_KEY": os.getenv("COHERE_API_KEY", "")}
        )

    # ------------------------------------------------------------------ #
    # HTTP helpers
    # ------------------------------------------------------------------ #

    def _headers(self) -> dict:
        """Return the bearer-auth JSON headers used for every Cohere call."""
        return {
            "Authorization": f"Bearer {self.valves.COHERE_API_KEY}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    @staticmethod
    def _redact(headers: dict) -> dict:
        """Return a copy of *headers* that is safe to print (API key masked)."""
        redacted = dict(headers)
        if "Authorization" in redacted:
            redacted["Authorization"] = "Bearer ***"
        return redacted

    @staticmethod
    def _error_text(error: Exception) -> str:
        """Format *error* for the chat UI as ``"Error: ..."``.

        When the exception carries an HTTP response, its body is appended:
        a failed request's body usually holds the API's explanation, which
        the exception message from ``raise_for_status`` does not include.
        """
        message = f"Error: {error}"
        response = getattr(error, "response", None)
        if response is None:
            return message
        try:
            body = response.text.strip()
        except Exception:  # body unreadable (e.g. stream already consumed)
            body = ""
        # Cap the appended body so an HTML error page cannot flood the chat.
        return f"{message} ({body[:500]})" if body else message

    # ------------------------------------------------------------------ #
    # Model listing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _supports_chat(model: dict) -> bool:
        """Return whether *model* should be offered for ``/chat`` requests.

        The /models listing can include non-chat models (e.g. embedding or
        rerank models). When a model advertises an ``endpoints`` list, keep
        it only if ``"chat"`` is among them; models without that field are
        kept rather than hidden.
        """
        endpoints = model.get("endpoints")
        return not isinstance(endpoints, list) or "chat" in endpoints

    def get_cohere_models(self) -> List[dict]:
        """Return ``{"id", "name"}`` entries for the available chat models.

        Returns an empty list when no API key is configured. On any failure
        it returns a single placeholder entry whose name carries the error,
        so the problem shows up in the model picker instead of raising.
        """
        if not self.valves.COHERE_API_KEY:
            if DEBUG:
                print("COHERE_API_KEY is not set")
            return []

        url = f"{self.valves.COHERE_API_BASE_URL}/models"
        headers = self._headers()
        if DEBUG:
            print(f"Fetching models from: {url}")
            print(f"Headers: {self._redact(headers)}")

        try:
            models: List[dict] = []
            params: dict = {}
            for _ in range(self._MAX_MODEL_PAGES):
                r = requests.get(
                    url,
                    headers=headers,
                    params=params,
                    timeout=self.valves.REQUEST_TIMEOUT,
                )
                r.raise_for_status()
                page = r.json()
                if DEBUG:
                    print(f"Models response: {page}")
                models.extend(page["models"])
                # Follow pagination until the server stops returning a token.
                next_token = page.get("next_page_token")
                if not next_token:
                    break
                params = {"page_token": next_token}
            return [
                {"id": model["name"], "name": model["name"]}
                for model in models
                if self._supports_chat(model)
            ]
        except Exception as e:  # best-effort: never break the model picker
            if DEBUG:
                print(f"Error fetching Cohere models: {e}")
            return [
                {
                    "id": "cohere",
                    "name": f"Could not fetch models from Cohere: {str(e)}",
                },
            ]

    def pipes(self) -> List[dict]:
        """Return the sub-models this manifold exposes (see get_cohere_models)."""
        return self.get_cohere_models()

    # ------------------------------------------------------------------ #
    # Chat
    # ------------------------------------------------------------------ #

    def pipe(self, body: dict) -> Union[str, Generator, Iterator]:
        """Dispatch a chat request to Cohere.

        *body* must contain ``model`` and ``messages``; ``stream`` (default
        False) selects a generator of text chunks over a single string.
        Errors raised while preparing the call are returned as
        ``"Error: ..."``.
        """
        try:
            model = body["model"]
            messages = body["messages"]
            stream = body.get("stream", False)

            if DEBUG:
                print("Incoming body:", json.dumps(body, indent=2))

            if model.startswith(self._MODEL_PREFIX):
                model = model[len(self._MODEL_PREFIX):]

            if stream:
                return self.stream_response(model, messages)
            return self.get_completion(model, messages)
        except Exception as e:
            if DEBUG:
                print(f"Error in pipe method: {e}")
            return f"Error: {e}"

    def stream_response(self, model: str, messages: List[dict]) -> Generator:
        """Yield reply text chunks from a streamed Cohere ``/chat`` call.

        Failures are yielded as a final ``"Error: ..."`` chunk. This is a
        generator, so none of it runs until the caller iterates; errors must
        therefore be handled here rather than in ``pipe``.
        """
        url = f"{self.valves.COHERE_API_BASE_URL}/chat"
        headers = self._headers()
        r = None
        try:
            payload = self._build_payload(model, messages, stream=True)
            if DEBUG:
                print("Stream request:")
                print(f"URL: {url}")
                print(f"Headers: {self._redact(headers)}")
                print(f"Payload: {json.dumps(payload, indent=2)}")

            r = requests.post(
                url=url,
                json=payload,
                headers=headers,
                stream=True,
                timeout=self.valves.REQUEST_TIMEOUT,
            )
            r.raise_for_status()

            # Each non-empty line of the stream is decoded as one JSON event.
            for line in r.iter_lines():
                if not line:
                    continue
                try:
                    event = json.loads(line.decode("utf-8"))
                except ValueError:  # JSONDecodeError or UnicodeDecodeError
                    if DEBUG:
                        print(f"Failed to decode JSON: {line}")
                    continue
                if DEBUG:
                    print(f"Received event: {event}")
                if not isinstance(event, dict):
                    continue
                event_type = event.get("event_type")
                if event_type == "text-generation":
                    yield event.get("text", "")
                elif event_type == "stream-end":
                    break
        except (requests.RequestException, ValueError) as e:
            if DEBUG:
                print(f"Request exception in stream_response: {e}")
            yield self._error_text(e)
        finally:
            # Release the connection even if the consumer stops iterating early.
            if r is not None:
                r.close()

    def get_completion(self, model: str, messages: List[dict]) -> str:
        """Return Cohere's full reply text for a non-streamed ``/chat`` call.

        Returns ``"No response from Cohere."`` when the reply has no
        ``text`` field, and ``"Error: ..."`` on request/decoding failures.
        """
        url = f"{self.valves.COHERE_API_BASE_URL}/chat"
        headers = self._headers()
        try:
            payload = self._build_payload(model, messages, stream=False)
            if DEBUG:
                print("Completion request:")
                print(f"URL: {url}")
                print(f"Headers: {self._redact(headers)}")
                print(f"Payload: {json.dumps(payload, indent=2)}")

            r = requests.post(
                url=url,
                json=payload,
                headers=headers,
                timeout=self.valves.REQUEST_TIMEOUT,
            )
            r.raise_for_status()
            data = r.json()
            if DEBUG:
                print(f"Completion response: {json.dumps(data, indent=2)}")
            return data.get("text", "No response from Cohere.")
        except (requests.RequestException, ValueError) as e:
            if DEBUG:
                print(f"Request exception in get_completion: {e}")
            return self._error_text(e)

    # ------------------------------------------------------------------ #
    # Payload construction
    # ------------------------------------------------------------------ #

    def _build_payload(
        self, model: str, messages: List[dict], stream: bool
    ) -> dict:
        """Build a Cohere v1 ``/chat`` request body from chat *messages*.

        The last message becomes ``message``, earlier non-system messages
        become ``chat_history``, and the first system message (if non-empty)
        is sent as ``preamble`` -- the v1 field for the system prompt.

        Raises:
            ValueError: if *messages* is empty.
        """
        if not messages:
            raise ValueError("No messages to send to Cohere.")
        payload = {
            "model": model,
            "chat_history": self._format_chat_history(messages[:-1]),
            "message": self._content_to_text(messages[-1].get("content")),
        }
        if stream:
            payload["stream"] = True
        system_message = self._get_system_message(messages)
        if system_message:
            payload["preamble"] = system_message
        return payload

    @staticmethod
    def _content_to_text(content) -> str:
        """Reduce a message ``content`` value to plain text.

        Content may arrive as a list of typed parts (assumed OpenAI-style
        ``{"type": "text", "text": ...}`` dicts, e.g. alongside images);
        only the text parts are kept, joined by newlines. ``None`` becomes
        ``""``; any other value is converted with ``str``.
        """
        if isinstance(content, list):
            return "\n".join(
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return "" if content is None else str(content)

    def _format_chat_history(self, messages: List[dict]) -> List[dict]:
        """Convert prior messages to Cohere ``chat_history`` entries.

        System messages are skipped; ``user`` maps to ``USER`` and every
        other role maps to ``CHATBOT``.
        """
        return [
            {
                "role": "USER" if message.get("role") == "user" else "CHATBOT",
                "message": self._content_to_text(message.get("content")),
            }
            for message in messages
            if message.get("role") != "system"
        ]

    def _get_system_message(self, messages: List[dict]) -> Union[str, None]:
        """Return the text of the first system message, or None if absent."""
        for message in messages:
            if message.get("role") == "system":
                return self._content_to_text(message.get("content"))
        return None