"""
title: Grok Manifold Pipe
author: Siraj Florida
author_url: https://github.com/sirajperson
funding_url: https://www.kinothen.com
version: 0.1.0
license: GNUv3
"""
import os
import json
import requests
from typing import List, Union, Generator, Iterator, Optional
from pydantic import BaseModel, Field
from utils.misc import pop_system_message
class Pipe:
class Valves(BaseModel):
"""Configuration for the Grok API interactions."""
GROK_API_KEY: str = Field(default="", description="API key for Grok services.")
GROK_API_BASE_URL: str = Field(
default="https://api.x.ai/v1",
description="Base URL for Grok API endpoints.",
)
MAX_TOKENS: int = Field(
default=4096, description="Maximum number of tokens to generate."
)
TEMPERATURE: float = Field(default=0.8, description="Sampling temperature.")
TOP_P: float = Field(default=0.9, description="Nucleus sampling top_p value.")
STREAM: bool = Field(default=False, description="Whether to stream responses.")
def __init__(self):
self.valves = self.Valves(
GROK_API_KEY=os.getenv("GROK_API_KEY", ""),
GROK_API_BASE_URL=os.getenv("GROK_API_BASE_URL", "https://api.x.ai/v1"),
)
self.type = "manifold"
self.id = "grok"
self.name = "grok/"
def get_model_id(self, model_name: str) -> str:
"""Extract just the base model name from any format"""
# Split on both / and . to handle any format
parts = model_name.replace(".", "/").split("/")
# Return only the actual model name (e.g. "grok-beta")
return parts[-1]
def get_grok_models(self):
headers = {
"Authorization": f"Bearer {self.valves.GROK_API_KEY}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{self.valves.GROK_API_BASE_URL}/models", headers=headers
)
response.raise_for_status()
models_data = response.json()
return [
{"id": model["id"], "name": model["id"]}
for model in models_data.get("data", [])
]
except Exception as e:
print(f"Error fetching models: {e}")
return []
def pipes(self) -> List[dict]:
return self.get_grok_models()
def process_image(self, image_data):
if image_data["image_url"]["url"].startswith("data:image"):
mime_type, base64_data = image_data["image_url"]["url"].split(",", 1)
media_type = mime_type.split(":")[1].split(";")[0]
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data,
},
}
else:
return {
"type": "image",
"source": {"type": "url", "url": image_data["image_url"]["url"]},
}
def pipe(self, body: dict) -> Union[str, Generator, Iterator]:
system_message, messages = pop_system_message(body.get("messages", []))
processed_messages = []
for message in messages:
if isinstance(message.get("content"), list):
for item in message["content"]:
if item["type"] == "text":
processed_messages.append(
{"role": message["role"], "content": item["text"]}
)
elif item["type"] == "image_url":
processed_image = self.process_image(item)
processed_messages.append(processed_image)
else:
processed_messages.append(
{"role": message["role"], "content": message.get("content", "")}
)
# Include system message if present
if system_message:
processed_messages.insert(
0, {"role": "system", "content": str(system_message)}
)
# Extract just the base model name
model_id = self.get_model_id(body["model"])
# Structure payload according to API spec
payload = {
"model": model_id,
"messages": processed_messages,
"stream": body.get("stream", self.valves.STREAM),
"temperature": body.get("temperature", self.valves.TEMPERATURE),
"max_tokens": body.get("max_tokens", self.valves.MAX_TOKENS),
"top_p": body.get("top_p", self.valves.TOP_P),
"frequency_penalty": body.get("frequency_penalty", 0),
"presence_penalty": body.get("presence_penalty", 0),
"stop": body.get("stop", []),
"user": body.get("user", ""),
"n": body.get("n", 1),
"logprobs": body.get("logprobs", False),
"top_logprobs": body.get("top_logprobs", 0),
}
headers = {
"Authorization": f"Bearer {self.valves.GROK_API_KEY}",
"Content-Type": "application/json",
}
url = f"{self.valves.GROK_API_BASE_URL}/chat/completions"
try:
if payload["stream"]:
return self.stream_response(url, headers, payload)
else:
return self.non_stream_response(url, headers, payload)
except Exception as e:
print(f"Error in pipe method: {e}")
return f"Error: {e}"
def stream_response(self, url, headers, payload):
with requests.post(url, headers=headers, json=payload, stream=True) as response:
if response.status_code != 200:
raise Exception(f"HTTP Error {response.status_code}: {response.text}")
for line in response.iter_lines():
if line:
line = line.decode("utf-8")
if line.startswith("data: "):
try:
data = json.loads(line[6:])
if "choices" in data and len(data["choices"]) > 0:
delta = data["choices"][0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
print(f"Failed to parse JSON: {line}")
except KeyError as e:
print(f"Unexpected data structure: {e}")
print(f"Full data: {data}")
def non_stream_response(self, url, headers, payload):
response = requests.post(url, headers=headers, json=payload)
if response.status_code != 200:
raise Exception(f"HTTP Error {response.status_code}: {response.text}")
res = response.json()
return res["choices"][0]["message"]["content"] if res.get("choices") else ""