Function
filter
v1.0
Content Moderation Filter
Ollama Content Moderation Filter based on Llama Guard 3 or ShieldGemma
Function ID
content_moderation_filter
Creator
@jnkstr29
Downloads
16+

Function Content
python
"""
title: Content Violation Filter
author: jannikstdl
date: 2024-10-13
version: 1.0
description: A filter function that checks messages for content violations using Ollama moderation models 'llama-guard3' or 'shieldgemma' and flags them accordingly.
ShieldGemma: https://ollama.com/library/shieldgemma
LlamaGuard3: https://ollama.com/library/llama-guard3
"""

import asyncio
from typing import Optional, Callable, Any, Awaitable

import requests
from pydantic import BaseModel, Field

from open_webui.utils.misc import get_last_user_message


class Filter:
    """Open WebUI inlet filter that screens user messages with an Ollama moderation model.

    Supported models are Llama Guard 3 (names starting with ``llama-guard3``)
    and ShieldGemma (names starting with ``shieldgemma``). A flagged message is
    rejected by raising ``Exception``; its text is presumably surfaced to the
    user by Open WebUI. If the moderation request itself fails, the error is
    reported as a status event and the message is let through unmoderated
    (fail-open).
    """

    # Llama Guard 3 hazard category codes, lower-cased to match the normalized
    # model output, mapped to human-readable names.
    LLAMA_GUARD_CATEGORIES = {
        "s1": "Violent Crimes",
        "s2": "Non-Violent Crimes",
        "s3": "Sex-Related Crimes",
        "s4": "Child Sexual Exploitation",
        "s5": "Defamation",
        "s6": "Specialized Advice",
        "s7": "Privacy",
        "s8": "Intellectual Property",
        "s9": "Indiscriminate Weapons",
        "s10": "Hate",
        "s11": "Suicide & Self-Harm",
        "s12": "Sexual Content",
        "s13": "Elections",
        "s14": "Code Interpreter Abuse",
    }

    class Valves(BaseModel):
        priority: int = Field(
            default=0, description="Priority level for the filter operations."
        )
        ollama_api_base_url: str = Field(
            default="http://localhost:11434",
            description="Base URL for the Ollama API.",
        )
        moderation_model: str = Field(
            default="llama-guard3:latest",
            description=(
                "Model to use for content moderation. Options are "
                "'llama-guard3' and 'shieldgemma'."
            ),
        )
        enabled_for_admins: bool = Field(
            default=True, description="Whether moderation is enabled for admin users."
        )
        enable_logs: bool = Field(
            default=False, description="Whether to print logs to console (Error/Info)."
        )
        request_timeout: float = Field(
            default=60.0,
            description="Timeout in seconds for the moderation request to Ollama.",
        )

    def __init__(self) -> None:
        self.valves = self.Valves()

    def log(self, message: str, log_type: str = "INFO") -> None:
        """Print ``message`` to the console when ``enable_logs`` is on.

        ``ERROR`` messages are printed in light red using ANSI escape codes.
        """
        if self.valves.enable_logs:
            if log_type == "ERROR":
                # ANSI code for light red text
                print(f"\033[91m{log_type}: {message}\033[0m")
            else:
                print(f"{log_type}: {message}")

    @staticmethod
    async def _emit_status(
        emitter: Optional[Callable[[Any], Awaitable[None]]], description: str
    ) -> None:
        """Send a completed status event through ``emitter`` if one was provided."""
        if emitter is None:
            return
        await emitter(
            {
                "type": "status",
                "data": {
                    "description": description,
                    "done": True,
                },
            }
        )

    def _check_llama_guard(self, content: str, moderation_result: str) -> None:
        """Raise if the normalized Llama Guard 3 verdict starts with ``unsafe``.

        Args:
            content: The moderated user message (used only for logging).
            moderation_result: Stripped, lower-cased model output.

        Raises:
            Exception: If the verdict is ``unsafe``; the message lists the
                violated categories when the model reported any.
        """
        if not moderation_result.startswith("unsafe"):
            return
        # The verdict line is followed by the violated category codes, which
        # Llama Guard 3 emits as a comma-separated list (e.g. "unsafe\ns1,s10").
        codes = [
            code.strip()
            for line in moderation_result.splitlines()[1:]
            for code in line.split(",")
            if code.strip()
        ]
        category_descriptions = [
            self.LLAMA_GUARD_CATEGORIES.get(code, code) for code in codes
        ]
        self.log(
            f"### Content Moderation Filter ###\n"
            f"--- A message was flagged as content violation ---\n"
            f"User Message: {content}\n"
            f"Model: {self.valves.moderation_model}\n"
            f"Flagged Categories: {category_descriptions}"
        )
        message = "Your request has been denied due to a content violation"
        if category_descriptions:
            message += f": {', '.join(category_descriptions)}"
        raise Exception(message)

    def _check_shieldgemma(self, content: str, moderation_result: str) -> None:
        """Raise if the normalized ShieldGemma verdict is ``yes`` (policy violated).

        Args:
            content: The moderated user message (used only for logging).
            moderation_result: Stripped, lower-cased model output.

        Raises:
            Exception: If the verdict starts with ``yes``.
        """
        # Match the prefix so trailing punctuation or explanation after the
        # Yes/No verdict does not let a violation slip through.
        if moderation_result.startswith("yes"):
            self.log(
                f"### Content Moderation Filter ###\n"
                f"--- A message was flagged as content violation ---\n"
                f"User Message: {content}\n"
                f"Model: {self.valves.moderation_model}"
            )
            raise Exception(
                "Your request has been denied due to a content violation"
            )

    async def check_content(
        self,
        content: str,
        __event_emitter__: Optional[Callable[[Any], Awaitable[None]]],
    ) -> None:
        """Classify ``content`` with the configured Ollama moderation model.

        Args:
            content: The user message text to classify.
            __event_emitter__: Open WebUI event emitter used to report request
                failures as a status event; may be ``None``.

        Raises:
            Exception: If the model flags ``content`` as a content violation.
        """
        model = self.valves.moderation_model
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": content}],
            "stream": False,
        }
        base_url = self.valves.ollama_api_base_url.rstrip("/")
        try:
            # requests is synchronous; run it in a worker thread so the event
            # loop is not blocked while the model produces its verdict.
            r = await asyncio.to_thread(
                requests.post,
                url=f"{base_url}/api/chat",
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=self.valves.request_timeout,
            )
            r.raise_for_status()
            # Parsed inside the try so an unexpected response shape is reported
            # like any other request failure instead of crashing the filter.
            moderation_result = r.json()["message"]["content"].strip().lower()
        except Exception as e:
            status_message = "Content Moderation Filter: ERROR - Check your filter configuration or see logs if enabled in filter options for more infos."
            self.log(f"Content Moderation Filter: Error: {e}", "ERROR")
            await self._emit_status(__event_emitter__, status_message)
            # Fail-open: the message continues through the pipeline unmoderated.
            return

        model_name = model.lower()
        if model_name.startswith("llama-guard3"):
            self._check_llama_guard(content, moderation_result)
        elif model_name.startswith("shieldgemma"):
            self._check_shieldgemma(content, moderation_result)
        else:
            self.log(
                f"Content Moderation Filter: Unsupported moderation model "
                f"'{model}'; message was not checked.",
                "ERROR",
            )

    async def inlet(
        self,
        body: dict,
        __user__: Optional[dict] = None,
        __event_emitter__: Optional[Callable[[Any], Awaitable[None]]] = None,
    ) -> dict:
        """Moderate the last user message of ``body`` before it reaches the model.

        Moderation is skipped when no user is given, when the user is an admin
        and ``enabled_for_admins`` is off, or when there is no user message.

        Returns:
            ``body`` unchanged.

        Raises:
            Exception: If the last user message is flagged as a content violation.
        """
        if __user__ is None:
            return body
        if __user__.get("role") == "admin" and not self.valves.enabled_for_admins:
            return body

        messages = body.get("messages") or []
        # get_last_user_message presumably returns None when there is no user
        # turn; an empty message has nothing to moderate either way.
        user_message = get_last_user_message(messages) if messages else None
        if user_message:
            await self.check_content(user_message, __event_emitter__)
        return body