"""
title: Content Violation Filter
author: jannikstdl
date: 2024-10-13
version: 1.0
description: A filter function that checks messages for content violations using Ollama moderation models 'llama-guard3' or 'shieldgemma' and flags them accordingly.
ShieldGemma: https://ollama.com/library/shieldgemma
LlamaGuard3: https://ollama.com/library/llama-guard3
"""
from typing import Optional, Callable, Any, Awaitable
from pydantic import BaseModel, Field
import requests
from open_webui.utils.misc import get_last_user_message
class Filter:
    """Open WebUI inlet filter that screens the last user message for content
    violations with an Ollama moderation model ('llama-guard3' or
    'shieldgemma') and rejects flagged requests by raising an Exception.
    """

    class Valves(BaseModel):
        # Admin-configurable settings exposed in the Open WebUI filter UI.
        priority: int = Field(
            default=0, description="Priority level for the filter operations."
        )
        ollama_api_base_url: str = Field(
            default="http://localhost:11434",
            description="Base URL for the Ollama API.",
        )
        moderation_model: str = Field(
            default="llama-guard3:latest",
            description=(
                "Model to use for content moderation. Options are "
                "'llama-guard3' and 'shieldgemma'."
            ),
        )
        enabled_for_admins: bool = Field(
            default=True, description="Whether moderation is enabled for admin users."
        )
        enable_logs: bool = Field(
            default=False, description="Whether to print logs to console (Error/Info)."
        )

    # llama-guard3 hazard taxonomy: category code -> human-readable label.
    # Keys are lowercase because the model response is lowercased before lookup.
    LLAMA_GUARD_CATEGORIES = {
        "s1": "Violent Crimes",
        "s2": "Non-Violent Crimes",
        "s3": "Sex-Related Crimes",
        "s4": "Child Sexual Exploitation",
        "s5": "Defamation",
        "s6": "Specialized Advice",
        "s7": "Privacy",
        "s8": "Intellectual Property",
        "s9": "Indiscriminate Weapons",
        "s10": "Hate",
        "s11": "Suicide & Self-Harm",
        "s12": "Sexual Content",
        "s13": "Elections",
    }

    def __init__(self) -> None:
        self.valves = self.Valves()

    def log(self, message: str, log_type: str = "INFO") -> None:
        """Print *message* to the console when logging is enabled.

        Errors are highlighted with an ANSI light-red escape sequence.
        """
        if not self.valves.enable_logs:
            return
        if log_type == "ERROR":
            # ANSI code for light red text
            print(f"\033[91m{log_type}: {message}\033[0m")
        else:
            print(f"{log_type}: {message}")

    async def check_content(
        self,
        content: str,
        __event_emitter__: Optional[Callable[[Any], Awaitable[None]]],
    ) -> None:
        """Ask the configured moderation model whether *content* is safe.

        Raises Exception with a user-facing denial message when the model
        flags the content. On API/configuration errors the request is allowed
        through (fail-open): the error is logged and reported via the event
        emitter instead of blocking the user.
        """
        headers = {
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.valves.moderation_model,
            "messages": [{"role": "user", "content": content}],
            "stream": False,
        }
        try:
            r = requests.post(
                url=f"{self.valves.ollama_api_base_url}/api/chat",
                json=payload,
                headers=headers,
                # Without a timeout a stalled Ollama server would hang the
                # whole request pipeline indefinitely.
                timeout=30,
            )
            r.raise_for_status()
            response = r.json()
        except Exception as e:
            status_message = (
                "Content Moderation Filter: ERROR - Check your filter "
                "configuration or see logs if enabled in filter options "
                "for more info."
            )
            self.log(f"Content Moderation Filter: Error: {e}", "ERROR")
            # inlet() may pass None for the emitter; guard so the original
            # error is not masked by a TypeError here.
            if __event_emitter__ is not None:
                await __event_emitter__(
                    {
                        "type": "status",
                        "data": {
                            "description": status_message,
                            "done": True,
                        },
                    }
                )
            return
        moderation_result = response["message"]["content"].strip().lower()
        if self.valves.moderation_model.startswith("llama-guard3"):
            if moderation_result.startswith("unsafe"):
                # llama-guard3 replies "unsafe" on the first line followed by
                # the violated category codes, which may be newline- and/or
                # comma-separated (e.g. "unsafe\nS1,S10").
                code_lines = moderation_result.split("\n")[1:]
                codes = [
                    code.strip()
                    for line in code_lines
                    for code in line.split(",")
                    if code.strip()
                ]
                # Fall back to the raw code when it is not in the taxonomy.
                category_descriptions = [
                    self.LLAMA_GUARD_CATEGORIES.get(code, code) for code in codes
                ]
                self.log(
                    f"### Content Moderation Filter ###\n"
                    f"--- A message was flagged as content violation ---\n"
                    f"User Message: {content}\n"
                    f"Model: {self.valves.moderation_model}\n"
                    f"Flagged Categories: {category_descriptions}"
                )
                raise Exception(
                    f"Your request has been denied due to a content violation: {', '.join(category_descriptions)}"
                )
        elif self.valves.moderation_model.startswith("shieldgemma"):
            # shieldgemma answers a plain yes/no verdict.
            if moderation_result == "yes":
                self.log(
                    f"### Content Moderation Filter ###\n"
                    f"--- A message was flagged as content violation ---\n"
                    f"User Message: {content}\n"
                    f"Model: {self.valves.moderation_model}"
                )
                raise Exception(
                    "Your request has been denied due to a content violation"
                )

    async def inlet(
        self,
        body: dict,
        __user__: Optional[dict] = None,
        __event_emitter__: Callable[[Any], Awaitable[None]] = None,
    ) -> dict:
        """Moderate the last user message before it reaches the model.

        Admin users are skipped unless `enabled_for_admins` is set. Requests
        with no user context (`__user__ is None`) are not moderated.
        """
        if __user__ is not None:
            is_admin = __user__.get("role") == "admin"
            if not is_admin or self.valves.enabled_for_admins:
                user_message = get_last_user_message(body["messages"])
                await self.check_content(user_message, __event_emitter__)
        return body