Function
filter
v0.2
Message bad words Filter
Use the Message Bad Words Filter to protect our APIs.
Function ID
message_bad_words_filter
Downloads
96+

Function Content
python
"""
title: Message bad words Filter
author: Yanyutin753
author_url: https://github.com/Yanyutin753
funding_url: https://github.com/open-webui
version: 0.2
"""

import time
from typing import Optional

from pydantic import BaseModel


class Filter:
    """Filter that blocks or masks configured bad words in user chat messages.

    Keywords come from the comma-separated ``CHAT_FILTER_WORDS`` valve and are
    matched with an Aho-Corasick automaton (``wordsSearch``), so each message
    is scanned in one pass however many keywords are configured. Matching is
    exact and case-sensitive: characters are compared by ``ord`` value.
    """

    class hashNode:
        """Node of the flattened automaton used at search time.

        ``m_values`` maps a character code to the next ``hashNode``. It already
        contains the transitions inherited from the node's failure chain (all
        except the root's), so a search only has to fall back to the root node
        when a lookup misses.
        """

        def __init__(self):
            self.End = False  # True once at least one keyword ends here
            self.Results = []  # keyword indices ending here, longest first, no duplicates
            self.m_values = {}  # char code -> hashNode
            # Lower / upper bound of the char codes in m_values; only kept for
            # compatibility, lookups no longer use them.
            self.min_flag = 0xFFFF
            self.max_flag = 0

        def Add(self, c, node3):
            """Register a transition on char code ``c`` to ``node3``."""
            if self.min_flag > c:
                self.min_flag = c
            if self.max_flag < c:
                self.max_flag = c
            self.m_values[c] = node3

        def SetResults(self, index):
            """Record that keyword ``index`` ends here (duplicates ignored)."""
            self.End = True
            if index not in self.Results:
                self.Results.append(index)

        def HasKey(self, c):
            """Return True if a transition on char code ``c`` exists."""
            return c in self.m_values

        def TryGetValue(self, c):
            """Return the node reached on char code ``c``, or None."""
            # Every key of m_values lies inside [min_flag, max_flag] by
            # construction, so the old range pre-check added nothing.
            return self.m_values.get(c)

    class trieNode:
        """Node of the keyword trie used while building the automaton."""

        def __init__(self):
            self.Index = 0  # BFS position; also the index of the matching hashNode
            self.Layer = 0  # depth in the trie (0 = root / not yet assigned)
            self.End = False  # True once at least one keyword ends here
            self.Char = ""  # char code on the edge from Parent (an int once set)
            self.Results = []  # keyword indices: own keyword first, then suffix matches
            self.m_values = {}  # char code -> child trieNode
            self.Failure = None  # Aho-Corasick failure link
            self.Parent = None

        def Add(self, c):
            """Return the child for char code ``c``, creating it if needed."""
            if c in self.m_values:
                return self.m_values[c]
            node = Filter.trieNode()
            node.Parent = self
            node.Char = c
            self.m_values[c] = node
            return node

        def SetResults(self, index):
            """Record that keyword ``index`` ends at this node."""
            self.End = True
            self.Results.append(index)

    class wordsSearch:
        """Aho-Corasick multi-keyword matcher.

        Call ``SetKeywords`` once, then ``FindFirst``, ``FindAll``,
        ``ContainsAny`` or ``Replace`` on any number of texts. A hit is a dict
        with keys ``Keyword``, ``Success``, ``Start`` and ``End`` (inclusive
        character positions in the text), and ``Index`` (position of the keyword
        in the list passed to ``SetKeywords``).
        """

        def __init__(self):
            # An empty root lets the search methods run (and never match)
            # before SetKeywords is called; a plain {} used to crash them.
            self._first = Filter.hashNode()
            self._keywords = []
            self._indexs = []

        def SetKeywords(self, keywords):
            """Build the automaton for ``keywords``.

            Empty strings are ignored. They would turn the root into a match
            and, through failure links, make every node report a hit.
            """
            self._keywords = list(keywords)
            self._indexs = list(range(len(self._keywords)))

            root = Filter.trieNode()
            nodes_by_layer = {}

            # 1. Build the trie and remember each new node by its depth.
            for i, word in enumerate(self._keywords):
                if not word:
                    continue
                nd = root
                for depth, ch in enumerate(word, start=1):
                    nd = nd.Add(ord(ch))
                    if nd.Layer == 0:
                        nd.Layer = depth
                        nodes_by_layer.setdefault(depth, []).append(nd)
                nd.SetResults(i)

            # 2. Breadth-first order: a node's failure target (always shallower)
            #    is then resolved before the node itself.
            all_nodes = [root]
            for depth in sorted(nodes_by_layer):
                all_nodes.extend(nodes_by_layer[depth])

            # 3. Failure links. root.Failure stays None until the end, which is
            #    what terminates the walk-up loop below.
            for i, nd in enumerate(all_nodes[1:], start=1):
                nd.Index = i
                r = nd.Parent.Failure
                c = nd.Char
                while r is not None and c not in r.m_values:
                    r = r.Failure
                if r is None:
                    nd.Failure = root
                else:
                    nd.Failure = r.m_values[c]
                    # Keywords that are suffixes of this path also end here.
                    for item in nd.Failure.Results:
                        nd.SetResults(item)
            root.Failure = root

            # 4. Flatten into hashNodes. Each node gets its own transitions plus
            #    those of its failure chain (excluding the root). The nearest node
            #    in the chain wins, matching goto/failure semantics.
            hash_nodes = [Filter.hashNode() for _ in all_nodes]
            for trie_node, hash_node in zip(all_nodes, hash_nodes):
                src = trie_node
                while True:
                    for code, child in src.m_values.items():
                        if not hash_node.HasKey(code):
                            hash_node.Add(code, hash_nodes[child.Index])
                    for item in src.Results:
                        hash_node.SetResults(item)
                    src = src.Failure
                    if src is root:
                        break

            self._first = hash_nodes[0]

        def _step(self, ptr, code):
            """Advance from state ``ptr`` (None = start) on char code ``code``."""
            if ptr is not None:
                nxt = ptr.TryGetValue(code)
                if nxt is not None:
                    return nxt
            return self._first.TryGetValue(code)

        def _make_hit(self, item, end):
            """Build the hit dict for keyword ``item`` ending at position ``end``."""
            keyword = self._keywords[item]
            return {
                "Keyword": keyword,
                "Success": True,
                "End": end,
                "Start": end + 1 - len(keyword),
                "Index": self._indexs[item],
            }

        def FindFirst(self, text):
            """Return the hit that ends earliest in ``text``, or None.

            If several keywords end at that position, the longest is reported.
            """
            ptr = None
            for index, ch in enumerate(text):
                ptr = self._step(ptr, ord(ch))
                if ptr is not None and ptr.End:
                    return self._make_hit(ptr.Results[0], index)
            return None

        def FindAll(self, text):
            """Return a list of every (possibly overlapping) hit in ``text``."""
            ptr = None
            key_list = []
            for index, ch in enumerate(text):
                ptr = self._step(ptr, ord(ch))
                if ptr is not None and ptr.End:
                    key_list.extend(self._make_hit(item, index) for item in ptr.Results)
            return key_list

        def ContainsAny(self, text):
            """Return True if any keyword occurs in ``text``."""
            ptr = None
            for ch in text:
                ptr = self._step(ptr, ord(ch))
                if ptr is not None and ptr.End:
                    return True
            return False

        def Replace(self, text, replaceChar="*"):
            """Return ``text`` with each character of every match replaced.

            ``replaceChar`` is substituted once per matched character, so a
            multi-character value lengthens the output.
            """
            result = list(text)
            ptr = None
            for i, ch in enumerate(text):
                ptr = self._step(ptr, ord(ch))
                if ptr is not None and ptr.End:
                    # Results[0] is the longest keyword ending at position i.
                    length = len(self._keywords[ptr.Results[0]])
                    for j in range(i + 1 - length, i + 1):
                        result[j] = replaceChar
            return "".join(result)

    class Valves(BaseModel):
        # Master switch; when False, messages pass through untouched.
        ENABLE_MESSAGE_FILTER: Optional[bool] = True
        # Comma-separated bad words; surrounding whitespace and blank entries are ignored.
        CHAT_FILTER_WORDS: Optional[str] = "fuck,SB"
        # True: mask bad words with REPLACE_FILTER_WORDS. False: reject the message.
        ENABLE_REPLACE_FILTER_WORDS: Optional[bool] = True
        # Written once per character of each matched word ("*" if unset).
        REPLACE_FILTER_WORDS: Optional[str] = "*"

    def __init__(self):
        self.valves = self.Valves()
        # CHAT_FILTER_WORDS value the current ``search`` was built from, so the
        # automaton is rebuilt only when the valve changes.
        self.TEM_CHAT_FILTER_WORDS = None
        self.search = None

    @staticmethod
    def _parse_keywords(raw: Optional[str]) -> list:
        """Split the comma-separated valve into clean keywords.

        Surrounding whitespace is stripped and blank entries are dropped.
        An empty keyword would otherwise match at every position.
        """
        if not raw:
            return []
        return [word.strip() for word in str(raw).split(",") if word.strip()]

    def _check_or_censor(self, text: str, start_time: float) -> str:
        """Apply the configured action to one piece of user text.

        In replace mode, return ``text`` with bad words masked. Otherwise return
        ``text`` unchanged, or raise ``Exception`` naming the first bad word.
        """
        if not self.valves.ENABLE_REPLACE_FILTER_WORDS:
            filter_condition = self.search.FindFirst(text)
            if filter_condition:
                filter_word = filter_condition["Keyword"]
                detail_message = (
                    f"Open WebUI: Your message contains bad words (`{filter_word}`) "
                    "and cannot be sent. Please create a new topic and try again."
                )
                print(
                    f"The time taken to check the filter words: {time.time() - start_time:.6f}s"
                )
                raise Exception(detail_message)
            return text

        replace_char = self.valves.REPLACE_FILTER_WORDS
        if replace_char is None:
            # "".join would fail on None entries; fall back to the default mask.
            replace_char = "*"
        censored = self.search.Replace(text, replace_char)
        print(f"Replace bad words in content: {censored}")
        return censored

    def inlet(self, body: dict, __user__: Optional[dict] = None) -> dict:
        """Check or censor user messages before the request reaches the model.

        Every user message in ``body["messages"]`` is processed. Plain string
        content is processed directly. For list content (multimodal parts), each
        ``{"type": "text", "text": ...}`` part is processed; other content is
        left alone.

        Raises:
            Exception: in block mode (ENABLE_REPLACE_FILTER_WORDS False) when a
                user message contains a bad word.
        """
        print(f"inlet:{__name__}")
        print(f"inlet:body:{body}")
        print(f"inlet:user:{__user__}")
        print(f"{self.valves}")

        if not self.valves.ENABLE_MESSAGE_FILTER:
            return body

        # Rebuild the automaton lazily when the configured word list changes.
        if self.valves.CHAT_FILTER_WORDS != self.TEM_CHAT_FILTER_WORDS:
            self.search = self.wordsSearch()
            self.search.SetKeywords(self._parse_keywords(self.valves.CHAT_FILTER_WORDS))
            self.TEM_CHAT_FILTER_WORDS = self.valves.CHAT_FILTER_WORDS

        # __user__ may be None; treat a missing user/role like "admin" so the
        # filter still applies instead of crashing.
        user = __user__ or {}
        if user.get("role", "admin") not in ("user", "admin"):
            return body
        if not (self.search and self.valves.CHAT_FILTER_WORDS):
            return body

        start_time = time.time()
        for message in reversed(body.get("messages", [])):
            if message.get("role") != "user":
                continue
            content = message.get("content")
            if isinstance(content, str):
                message["content"] = self._check_or_censor(content, start_time)
            elif isinstance(content, list):
                # Multimodal messages used to bypass the filter entirely.
                for part in content:
                    if (
                        isinstance(part, dict)
                        and part.get("type") == "text"
                        and isinstance(part.get("text"), str)
                    ):
                        part["text"] = self._check_or_censor(part["text"], start_time)
        print(
            f"The time taken to check the filter words: {time.time() - start_time:.6f}s"
        )
        return body

    def outlet(self, body: dict, __user__: Optional[dict] = None) -> dict:
        """Pass the model response through unchanged (logging only)."""
        print(f"outlet:{__name__}")
        print(f"outlet:body:{body}")
        print(f"outlet:user:{__user__}")

        return body