"""
title: Cohere Manifold Pipe
author: justinh-rahb
author_url: https://github.com/justinh-rahb
funding_url: https://github.com/open-webui
version: 0.1.5
license: MIT
"""
import os
import json
import requests
from typing import List, Union, Generator, Iterator
from pydantic import BaseModel, Field
# Set DEBUG to True to enable detailed logging.
# When enabled, the pipe print()s request details (URLs, payloads containing
# the chat messages) and raw API responses to stdout.
DEBUG: bool = False
class Pipe:
    """Open WebUI "manifold" pipe that exposes Cohere chat models.

    ``pipes()`` lists the models reachable with the configured API key, and
    ``pipe()`` forwards a chat request to Cohere's v1 ``/chat`` endpoint,
    returning either the full reply text or a generator of text chunks.
    """

    # (connect, read) timeout in seconds for every HTTP call. Without a
    # timeout ``requests`` waits forever on an unresponsive server. For
    # streamed replies the read timeout bounds the gap between chunks, not
    # the total generation time.
    REQUEST_TIMEOUT = (10, 300)

    # User-configurable settings for this pipe.
    class Valves(BaseModel):
        COHERE_API_BASE_URL: str = Field(default="https://api.cohere.com/v1")
        COHERE_API_KEY: str = Field(default="")

    def __init__(self):
        self.type = "manifold"
        self.id = "cohere"
        self.name = "cohere/"
        # Seed the API key from the environment; every request reads it
        # back from self.valves.
        self.valves = self.Valves(
            COHERE_API_KEY=os.getenv("COHERE_API_KEY", "")
        )

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    def _url(self, path: str) -> str:
        """Join ``path`` onto the configured base URL (tolerates a trailing '/')."""
        return f"{self.valves.COHERE_API_BASE_URL.rstrip('/')}/{path}"

    def _headers(self) -> dict:
        """Return the JSON + bearer-token headers sent with every request."""
        return {
            "Authorization": f"Bearer {self.valves.COHERE_API_KEY}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    @staticmethod
    def _redacted(headers: dict) -> dict:
        """Return a copy of ``headers`` that is safe to print (API key masked)."""
        return {**headers, "Authorization": "Bearer ***"}

    @staticmethod
    def _response_body(r) -> str:
        """Best-effort body of response ``r`` for debug output.

        Returns ``"No response"`` when the request never produced a response.
        """
        if r is None:
            return "No response"
        try:
            return str(r.content)
        except Exception:
            # A partially consumed stream cannot be re-read.
            return f"<body unavailable, status {r.status_code}>"

    # ------------------------------------------------------------------
    # Model listing
    # ------------------------------------------------------------------

    def get_cohere_models(self) -> List[dict]:
        """Return ``{"id", "name"}`` entries for the chat-capable Cohere models.

        Returns an empty list when no API key is configured. If the request or
        response parsing fails, returns a single placeholder entry (id
        ``"cohere"``) whose name carries the error message.
        """
        if not self.valves.COHERE_API_KEY:
            if DEBUG:
                print("COHERE_API_KEY is not set")
            return []

        url = self._url("models")
        headers = self._headers()
        if DEBUG:
            print(f"Fetching models from: {url}")
            print(f"Headers: {self._redacted(headers)}")
        try:
            r = requests.get(url, headers=headers, timeout=self.REQUEST_TIMEOUT)
            r.raise_for_status()
            models = r.json()
            if DEBUG:
                print(f"Models response: {models}")
            return [
                {"id": model["name"], "name": model["name"]}
                for model in models["models"]
                # Hide embed/rerank-only models. A model that does not report
                # its compatible endpoints is kept.
                if "chat" in (model.get("endpoints") or ["chat"])
            ]
        except Exception as e:
            # Deliberately broad: surface any failure as a placeholder entry.
            if DEBUG:
                print(f"Error fetching Cohere models: {e}")
            return [
                {
                    "id": "cohere",
                    "name": f"Could not fetch models from Cohere: {str(e)}",
                },
            ]

    def pipes(self) -> List[dict]:
        """Return the model entries advertised by this manifold."""
        return self.get_cohere_models()

    # ------------------------------------------------------------------
    # Chat
    # ------------------------------------------------------------------

    def pipe(self, body: dict) -> Union[str, Generator, Iterator]:
        """Handle one chat request.

        ``body`` must contain ``"model"`` and ``"messages"``; ``"stream"`` is
        optional. Returns the reply text, or a generator of text chunks when
        streaming. Errors raised while preparing the request are returned as
        ``"Error: ..."`` strings rather than propagated.
        """
        try:
            model = body["model"]
            messages = body["messages"]
            stream = body.get("stream", False)
            if DEBUG:
                print("Incoming body:", json.dumps(body, indent=2))
            if not messages:
                return "Error: no messages to send to Cohere."
            # Strip "cohere." prefix from model name
            if model.startswith("cohere."):
                model = model[len("cohere."):]
            if stream:
                return self.stream_response(model, messages)
            return self.get_completion(model, messages)
        except Exception as e:
            if DEBUG:
                print(f"Error in pipe method: {e}")
            return f"Error: {e}"

    def stream_response(self, model: str, messages: List[dict]) -> Generator:
        """Start a streamed chat request and return a generator of text chunks.

        The payload is built eagerly, before the generator is returned, so
        malformed messages raise here (inside ``pipe``'s error handling)
        rather than on first iteration.
        """
        url = self._url("chat")
        headers = self._headers()
        payload = self._build_payload(model, messages, stream=True)
        if DEBUG:
            print("Stream request:")
            print(f"URL: {url}")
            print(f"Headers: {self._redacted(headers)}")
            print(f"Payload: {json.dumps(payload, indent=2)}")
        return self._stream_chat(url, headers, payload)

    def _stream_chat(self, url: str, headers: dict, payload: dict) -> Generator:
        """POST ``payload`` with streaming enabled and yield generated text.

        Each non-empty response line is parsed as a JSON event:
        ``text-generation`` events yield their ``text``, and ``stream-end``
        stops iteration. Undecodable lines are skipped. A request failure is
        yielded once as ``"Error: ..."``. The response is closed in
        ``finally``, so the connection is released even if the consumer
        stops early.
        """
        r = None
        try:
            r = requests.post(
                url=url,
                json=payload,
                headers=headers,
                stream=True,
                timeout=self.REQUEST_TIMEOUT,
            )
            r.raise_for_status()
            for line in r.iter_lines():
                if not line:
                    continue
                try:
                    event = json.loads(line.decode("utf-8"))
                except ValueError:  # JSONDecodeError or UnicodeDecodeError
                    if DEBUG:
                        print(f"Failed to decode JSON: {line}")
                    continue
                if DEBUG:
                    print(f"Received event: {event}")
                if not isinstance(event, dict):
                    continue
                event_type = event.get("event_type")
                if event_type == "text-generation":
                    text = event.get("text")
                    if text:
                        yield text
                elif event_type == "stream-end":
                    break
        except requests.RequestException as e:
            if DEBUG:
                print(f"Request exception in stream_response: {e}")
                print(f"Response content: {self._response_body(r)}")
            yield f"Error: {str(e)}"
        finally:
            if r is not None:
                r.close()

    def get_completion(self, model: str, messages: List[dict]) -> str:
        """Send a non-streamed chat request and return the reply text.

        Returns ``"No response from Cohere."`` if the reply has no ``text``
        field, and ``"Error: ..."`` if the HTTP request fails.
        """
        url = self._url("chat")
        headers = self._headers()
        payload = self._build_payload(model, messages, stream=False)
        if DEBUG:
            print("Completion request:")
            print(f"URL: {url}")
            print(f"Headers: {self._redacted(headers)}")
            print(f"Payload: {json.dumps(payload, indent=2)}")
        r = None
        try:
            r = requests.post(
                url=url,
                json=payload,
                headers=headers,
                timeout=self.REQUEST_TIMEOUT,
            )
            r.raise_for_status()
            data = r.json()
            if DEBUG:
                print(f"Completion response: {json.dumps(data, indent=2)}")
            return data.get("text", "No response from Cohere.")
        except requests.RequestException as e:
            if DEBUG:
                print(f"Request exception in get_completion: {e}")
                print(f"Response content: {self._response_body(r)}")
            return f"Error: {str(e)}"

    # ------------------------------------------------------------------
    # Payload construction
    # ------------------------------------------------------------------

    def _build_payload(
        self, model: str, messages: List[dict], stream: bool = False
    ) -> dict:
        """Build the Cohere v1 ``/chat`` request body.

        The last message becomes ``message``, earlier non-system messages
        become ``chat_history``, and the first system message (if any) is sent
        as ``preamble``. Raises ``IndexError`` if ``messages`` is empty.
        """
        payload = {
            "model": model,
            "chat_history": self._format_chat_history(messages[:-1]),
            "message": self._content_text(messages[-1]["content"]),
        }
        if stream:
            payload["stream"] = True
        system_message = self._get_system_message(messages)
        if system_message:
            # The v1 chat API takes the system prompt as `preamble`; a
            # `system` field is not part of that API.
            payload["preamble"] = system_message
        return payload

    @staticmethod
    def _content_text(content) -> str:
        """Coerce a message's ``content`` to the plain string Cohere expects.

        Strings pass through unchanged and ``None`` becomes ``""``. A list is
        treated as content parts of the form ``{"type": "text", "text": ...}``
        (NOTE(review): presumably what clients send for multimodal turns;
        confirm). Its text parts are joined with newlines and other parts
        are dropped. Anything else goes through ``str()``.
        """
        if isinstance(content, str):
            return content
        if content is None:
            return ""
        if isinstance(content, list):
            return "\n".join(
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return str(content)

    def _format_chat_history(self, messages: List[dict]) -> List[dict]:
        """Convert prior turns into Cohere ``chat_history`` entries.

        System messages are excluded. Role ``user`` maps to ``USER`` and every
        other role to ``CHATBOT``.
        """
        return [
            {
                "role": "USER" if message["role"] == "user" else "CHATBOT",
                "message": self._content_text(message["content"]),
            }
            for message in messages
            if message["role"] != "system"
        ]

    def _get_system_message(self, messages: List[dict]) -> Union[str, None]:
        """Return the text of the first system message, or ``None`` if there is none."""
        for message in messages:
            if message["role"] == "system":
                return self._content_text(message["content"])
        return None