import json
from typing import List, Union, Generator, Iterator
from pydantic import BaseModel
import requests
import re
from datetime import timedelta
class Pipe:
    """Manifold pipe exposing GitHub-hosted models as chat models.

    NOTE(review): structure (``Valves`` / ``pipes`` / ``pipe``) looks like an
    Open WebUI manifold plugin — confirm against the host application.
    """

    class Valves(BaseModel):
        """User-configurable settings for this pipe."""

        # GitHub personal access token; an empty value disables discovery.
        GITHUB_PAT: str = ""
        # Inference endpoint that serves the chat completions.
        GITHUB_MODELS_BASE_URL: str = "https://models.inference.ai.azure.com"
        # Marketplace listing used to enumerate the available models.
        GITHUB_MARKETPLACE_BASE_URL: str = "https://github.com/marketplace"
def __init__(self):
    """Register identity/type metadata and preload the model list."""
    # Identifier and display-name prefix for this manifold.
    self.id = "github_models"
    self.type = "manifold"
    self.name = "GitHub: "
    # Valves must exist before get_github_models() reads the PAT from them.
    self.valves = self.Valves()
    self.pipelines = self.get_github_models()
def get_github_models(self):
    """Fetch the chat-completion models from the GitHub Marketplace listing.

    Pages through the marketplace ``type=models`` listing and returns a list
    of ``{"id", "name", "description"}`` dicts for every entry whose task is
    ``chat-completion``.

    Returns:
        list[dict]: the discovered models; an empty list when no PAT is
        configured or when the request/parse fails. (The original implicitly
        returned ``None`` when ``GITHUB_PAT`` was empty, breaking callers
        that expect ``List[dict]``.)
    """
    if not self.valves.GITHUB_PAT:
        return []
    try:
        headers = {
            "Authorization": f"Bearer {self.valves.GITHUB_PAT}",
            "Accept": "application/json",
        }
        models = []
        page = 1
        while True:
            r = requests.get(
                f"{self.valves.GITHUB_MARKETPLACE_BASE_URL}?page={page}&type=models",
                headers=headers,
                timeout=30,  # avoid hanging forever on a stalled connection
            )
            response = r.json()
            models.extend(response.get("results", []))
            # totalPages defaults to 1, so a missing field ends the loop.
            if page >= response.get("totalPages", 1):
                break
            page += 1
        return [
            {
                # Prefer the API model id, falling back to the display name.
                # .get() also guards a missing "original_name" key, which
                # previously raised KeyError and dropped the whole listing.
                "id": model.get("original_name") or model["name"],
                "name": model.get("friendly_name", model["name"]),
                "description": model.get("summary", ""),
            }
            for model in models
            if model.get("task") == "chat-completion"
        ]
    except Exception as e:
        print(f"Error: {e}")
        return []
def pipes(self) -> List[dict]:
    """Return the manifold's model list by re-querying GitHub."""
    available_models = self.get_github_models()
    return available_models
def pipe(self, body: dict) -> Union[str, Generator, Iterator]:
    """Forward a chat-completion request to the GitHub Models endpoint.

    Args:
        body: request payload; ``body["model"]`` carries a
            ``<pipe id>.<model id>`` prefix that is stripped before sending.

    Returns:
        The assistant message content (str) for non-streaming calls, an
        iterator of raw response lines when streaming, the raw response dict
        when no choices are present, or an ``"Error: ..."`` string on failure.
    """
    print(f"pipe:{__name__}")
    headers = {
        "Authorization": f"Bearer {self.valves.GITHUB_PAT}",
        "Content-Type": "application/json",
    }
    # Only the parameters the endpoint accepts; everything else is dropped.
    allowed_params = {
        "messages",
        "temperature",
        "top_p",
        "stop",
        "model",
        "max_tokens",
        "stream_options",
        "stream",  # Include stream in allowed params
    }
    # Remap the model name to the model id (strip the "<pipe>." prefix).
    body["model"] = ".".join(body["model"].split(".")[1:])
    filtered_body = {k: v for k, v in body.items() if k in allowed_params}
    # o1 models do not support streaming: force it off for them while
    # honoring the caller's preference for every other model.
    is_o1_model = "o1" in body["model"]
    should_stream = filtered_body.get("stream", False)
    if is_o1_model:
        filtered_body["stream"] = False
    # Log fields that were filtered out as a single line.
    if len(body) != len(filtered_body):
        print(
            f"Dropped params: {', '.join(set(body.keys()) - set(filtered_body.keys()))}"
        )
    # Pre-bind r so the except clause can reference it safely: if
    # requests.post() itself raises (DNS failure, refused connection), the
    # original code hit an unhandled NameError on `r.text`.
    r = None
    try:
        r = requests.post(
            url=f"{self.valves.GITHUB_MODELS_BASE_URL}/chat/completions",
            json=filtered_body,
            headers=headers,
            stream=should_stream and not is_o1_model,  # Only stream for non-o1 models
        )
        if r.status_code != 200:
            # The API may nest a JSON-encoded message inside the error
            # payload; unwrap it when possible so users see the real reason.
            error_data = r.json()
            error_info = error_data.get("error", {})
            message = error_info.get("message", "")
            try:
                message = json.loads(message).get("message", message)
            except Exception:
                pass
            details = error_info.get("details", "")
            # Extract the rate-limit wait time in seconds from the details,
            # falling back to the message text.
            match = re.search(r'Please wait (\d+) seconds before retrying', details)
            if not match:
                match = re.search(r'Please wait (\d+) seconds before retrying', message)
            if match:
                wait_seconds = int(match.group(1))
                delta = timedelta(seconds=wait_seconds)
                readable_time = str(delta)  # e.g. "0:01:30" for 90 seconds
                message = f"Rate limit exceeded. Please wait {readable_time} before retrying."
            return f"Error: {message}"
        r.raise_for_status()
        if should_stream and not is_o1_model:
            # Hand back the raw line iterator so the caller can stream.
            return r.iter_lines()
        else:
            response_json = r.json()
            if response_json.get("choices"):
                return response_json["choices"][0]["message"]["content"]
            return response_json
    except Exception as e:
        # r is None when the request never produced a response object.
        response_text = r.text if r is not None else ""
        return f"Error: {e} {response_text}"