Whitepaper
Docs
Sign In
Function
Function
filter
v1.0
Content Moderation Filter
Function ID
content_moderation_filter
Creator
@jnkstr29
Downloads
102+
Ollama Content Moderation Filter based on LlamaGuard3 or ShieldGemma
Get
README
No README available
Function Code
Show
"""
title: Content Violation Filter
author: jannikstdl
date: 2024-10-13
version: 1.0
description: A filter function that checks messages for content violations using Ollama moderation models 'llama-guard3' or 'shieldgemma' and flags them accordingly.

ShieldGemma: https://ollama.com/library/shieldgemma
LlamaGuard3: https://ollama.com/library/llama-guard3
"""

import asyncio
from typing import Any, Awaitable, Callable, Optional

import requests
from pydantic import BaseModel, Field

from open_webui.utils.misc import get_last_user_message


class Filter:
    """Inlet filter that asks an Ollama moderation model to vet user messages.

    The last user message of each request is sent to the configured
    moderation model. When the model flags it, an ``Exception`` is raised
    whose message describes the violation. When the moderation service
    itself cannot be reached or answers unexpectedly, the filter fails open:
    the problem is logged, a status event is emitted, and the request is
    allowed through.
    """

    # (connect, read) timeout in seconds for the moderation request. The read
    # timeout is generous because the model may have to be loaded first.
    _REQUEST_TIMEOUT_SECONDS = (10, 120)

    # Human-readable names for the Llama Guard hazard codes. Keys are
    # lowercase because the model output is lowercased before lookup.
    _LLAMA_GUARD_CATEGORIES = {
        "s1": "Violent Crimes",
        "s2": "Non-Violent Crimes",
        "s3": "Sex-Related Crimes",
        "s4": "Child Sexual Exploitation",
        "s5": "Defamation",
        "s6": "Specialized Advice",
        "s7": "Privacy",
        "s8": "Intellectual Property",
        "s9": "Indiscriminate Weapons",
        "s10": "Hate",
        "s11": "Suicide & Self-Harm",
        "s12": "Sexual Content",
        "s13": "Elections",
        "s14": "Code Interpreter Abuse",
    }

    class Valves(BaseModel):
        """User-configurable settings for the filter."""

        priority: int = Field(
            default=0, description="Priority level for the filter operations."
        )
        ollama_api_base_url: str = Field(
            default="http://localhost:11434",
            description="Base URL for the Ollama API.",
        )
        moderation_model: str = Field(
            default="llama-guard3:latest",
            description=(
                "Model to use for content moderation. Options are "
                "'llama-guard3' and 'shieldgemma'."
            ),
        )
        enabled_for_admins: bool = Field(
            default=True, description="Whether moderation is enabled for admin users."
        )
        enable_logs: bool = Field(
            default=False,
            description="Whether to print logs to console (Error/Info).",
        )

    def __init__(self) -> None:
        self.valves = self.Valves()

    def log(self, message: str, log_type: str = "INFO") -> None:
        """Print ``message`` to the console when logging is enabled.

        Args:
            message: Text to print.
            log_type: Prefix for the line; ``"ERROR"`` is printed in light red.
        """
        if not self.valves.enable_logs:
            return
        if log_type == "ERROR":
            # ANSI code for light red text
            print(f"\033[91m{log_type}: {message}\033[0m")
        else:
            print(f"{log_type}: {message}")

    async def _query_moderation_model(self, content: str) -> str:
        """Send ``content`` to the moderation model and return its verdict.

        The verdict is the model's reply text, stripped and lowercased.

        Raises:
            Exception: Any transport, HTTP-status, JSON or payload-shape
                error; the caller treats all of them as a service failure.
        """
        payload = {
            "model": self.valves.moderation_model,
            "messages": [{"role": "user", "content": content}],
            "stream": False,
        }
        base_url = self.valves.ollama_api_base_url.rstrip("/")
        # requests is synchronous: run it in a worker thread so the event
        # loop is not blocked, and bound the wait with a timeout.
        response = await asyncio.to_thread(
            requests.post,
            url=f"{base_url}/api/chat",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        return response.json()["message"]["content"].strip().lower()

    @classmethod
    def _describe_llama_guard_categories(cls, moderation_result: str) -> list[str]:
        """Translate the category codes of an ``unsafe`` verdict into names.

        Everything after the first line is treated as category codes,
        separated by commas and/or newlines. Unknown codes are returned
        unchanged.
        """
        codes = []
        for line in moderation_result.split("\n")[1:]:
            codes.extend(code.strip() for code in line.split(","))
        return [cls._LLAMA_GUARD_CATEGORIES.get(code, code) for code in codes if code]

    async def check_content(
        self,
        content: str,
        __event_emitter__: Optional[Callable[[Any], Awaitable[None]]] = None,
    ) -> None:
        """Check ``content`` with the moderation model and reject violations.

        Args:
            content: The message text to check.
            __event_emitter__: Optional async callback used to report a
                status event when the moderation service fails.

        Raises:
            Exception: When the moderation model flags ``content``.
        """
        try:
            moderation_result = await self._query_moderation_model(content)
        except Exception as e:
            # Deliberate fail-open boundary: a broken moderation backend is
            # reported but does not block the user's request.
            status_message = "Content Moderation Filter: ERROR - Check your filter configuration or see logs if enabled in filter options for more infos."
            self.log(f"Content Moderation Filter: Error: {e}", "ERROR")
            if __event_emitter__ is not None:
                await __event_emitter__(
                    {
                        "type": "status",
                        "data": {
                            "description": status_message,
                            "done": True,
                        },
                    }
                )
            return

        model = self.valves.moderation_model
        if model.startswith("llama-guard3"):
            if moderation_result.startswith("unsafe"):
                category_descriptions = self._describe_llama_guard_categories(
                    moderation_result
                )
                self.log(
                    f"### Content Moderation Filter ###\n"
                    f"--- A message was flagged as content violation ---\n"
                    f"User Message: {content}\n"
                    f"Model: {model}\n"
                    f"Flagged Categories: {category_descriptions}"
                )
                raise Exception(
                    f"Your request has been denied due to a content violation: {', '.join(category_descriptions)}"
                )
        elif model.startswith("shieldgemma"):
            if moderation_result == "yes":
                self.log(
                    f"### Content Moderation Filter ###\n"
                    f"--- A message was flagged as content violation ---\n"
                    f"User Message: {content}\n"
                    f"Model: {model}"
                )
                raise Exception(
                    "Your request has been denied due to a content violation"
                )

    async def inlet(
        self,
        body: dict,
        __user__: Optional[dict] = None,
        __event_emitter__: Optional[Callable[[Any], Awaitable[None]]] = None,
    ) -> dict:
        """Moderate the last user message of ``body`` before it is processed.

        Moderation is skipped when no user is supplied, and for admin users
        when the ``enabled_for_admins`` valve is off.

        Args:
            body: Request body; its ``"messages"`` list is inspected.
            __user__: Dict describing the requesting user (``"role"`` is read).
            __event_emitter__: Optional async status callback.

        Returns:
            ``body`` unchanged.

        Raises:
            Exception: When the message is flagged as a content violation.
        """
        if __user__ is None:
            return body
        if __user__.get("role") == "admin" and not self.valves.enabled_for_admins:
            return body

        user_message = get_last_user_message(body.get("messages") or [])
        # NOTE(review): presumably the helper yields nothing when there is no
        # user turn; an empty message has nothing to moderate either way.
        if user_message:
            await self.check_content(user_message, __event_emitter__)
        return body