"""
title: Message bad words Filter
author: Yanyutin753
author_url: https://github.com/Yanyutin753
funding_url: https://github.com/open-webui
version: 0.2
"""
import time
from typing import Optional
from pydantic import BaseModel
class Filter:
    """Open WebUI filter that blocks or masks configured bad words.

    ``inlet`` scans user messages with an Aho-Corasick automaton
    (``wordsSearch``) built from the comma-separated ``CHAT_FILTER_WORDS``
    valve. Depending on ``ENABLE_REPLACE_FILTER_WORDS`` a match either masks
    the word in place or rejects the request with an exception.
    """

    class hashNode:
        """Search-time automaton state.

        Besides its own children, each state copies the transitions and
        results of its failure chain (see ``wordsSearch.SetKeywords``), so a
        search step is one lookup plus, on a miss, a fallback to the root.
        """

        def __init__(self):
            self.End = False  # True once any keyword ends at this state
            self.Results = []  # keyword indexes ending here, no duplicates
            self.m_values = {}  # code point -> hashNode
            # Smallest / largest code point among m_values: a cheap range test
            # that rejects most characters before the dict lookup.
            self.min_flag = 0xFFFF
            self.max_flag = 0

        def Add(self, c, node3):
            """Register the transition on code point ``c`` to ``node3``."""
            self.min_flag = min(self.min_flag, c)
            self.max_flag = max(self.max_flag, c)
            self.m_values[c] = node3

        def SetResults(self, index):
            """Mark this state terminal for keyword ``index`` (deduplicated)."""
            self.End = True
            if index not in self.Results:
                self.Results.append(index)

        def HasKey(self, c):
            """Return True if a transition on code point ``c`` exists."""
            return c in self.m_values

        def TryGetValue(self, c):
            """Return the state reached on code point ``c``, or None."""
            if self.min_flag <= c <= self.max_flag:
                return self.m_values.get(c)
            return None

    class trieNode:
        """Build-time trie node; converted into ``hashNode`` once complete."""

        def __init__(self):
            self.Index = 0  # position in the breadth-first node list
            self.Layer = 0  # depth in the trie; 0 means root / not yet assigned
            self.End = False
            self.Char = ""  # code point of the edge from Parent ("" for root)
            self.Results = []  # keyword indexes ending here (own first)
            self.m_values = {}  # code point -> child trieNode
            self.Failure = None  # Aho-Corasick failure link
            self.Parent = None

        def Add(self, c):
            """Return the child for code point ``c``, creating it if needed."""
            node = self.m_values.get(c)
            if node is None:
                node = Filter.trieNode()
                node.Parent = self
                node.Char = c
                self.m_values[c] = node
            return node

        def SetResults(self, index):
            """Mark this node as the end of keyword ``index``."""
            self.End = True
            self.Results.append(index)

    class wordsSearch:
        """Aho-Corasick multi-keyword search over Python strings.

        Match results are dicts with keys ``Keyword``, ``Success``, ``Start``
        and ``End`` (inclusive character offsets into the searched text) and
        ``Index`` (the keyword's position in the list given to SetKeywords).
        """

        def __init__(self):
            # Empty root state until SetKeywords() runs, so searching an
            # unconfigured instance finds nothing instead of crashing.
            self._first = Filter.hashNode()
            self._keywords = []
            self._indexs = []

        def SetKeywords(self, keywords):
            """Build the automaton for ``keywords`` (an iterable of strings).

            Empty strings are ignored; they keep their index slot so reported
            ``Index`` values still match the caller's list positions.
            """
            self._keywords = list(keywords)
            self._indexs = list(range(len(self._keywords)))

            root = Filter.trieNode()
            # Depth -> nodes first created at that depth.
            allNodeLayer = {}
            for i, word in enumerate(self._keywords):
                if not word:
                    # An empty keyword would mark the root as terminal and
                    # could leak into the results of every other state.
                    continue
                nd = root
                for depth, ch in enumerate(word, start=1):
                    nd = nd.Add(ord(ch))
                    if nd.Layer == 0:
                        nd.Layer = depth
                        allNodeLayer.setdefault(depth, []).append(nd)
                nd.SetResults(i)

            # Breadth-first order: root, then all depth-1 nodes, depth-2, ...
            allNode = [root]
            for depth in sorted(allNodeLayer):
                allNode.extend(allNodeLayer[depth])

            # Failure links, computed parents-before-children. root.Failure is
            # still None here, which is what stops the climb at the root.
            for i in range(1, len(allNode)):
                nd = allNode[i]
                nd.Index = i
                c = nd.Char
                r = nd.Parent.Failure
                while r is not None and c not in r.m_values:
                    r = r.Failure
                nd.Failure = root if r is None else r.m_values[c]
                # Keywords ending at the failure state also end here; they are
                # appended after this node's own keyword, so Results[0] is the
                # longest keyword ending at this node.
                for key2 in nd.Failure.Results:
                    nd.SetResults(key2)
            root.Failure = root

            # Flatten into hashNodes: each state also inherits the transitions
            # and results of its failure chain (the root itself is the search
            # fallback, so the chain walk stops there).
            allNode2 = [Filter.hashNode() for _ in allNode]
            for oldNode, newNode in zip(allNode, allNode2):
                for key, child in oldNode.m_values.items():
                    newNode.Add(key, allNode2[child.Index])
                for item in oldNode.Results:
                    newNode.SetResults(item)
                node = oldNode.Failure
                while node is not root:
                    for key, child in node.m_values.items():
                        if not newNode.HasKey(key):
                            newNode.Add(key, allNode2[child.Index])
                    for item in node.Results:
                        newNode.SetResults(item)
                    node = node.Failure
            self._first = allNode2[0]

        def _step(self, ptr, c):
            """Advance from state ``ptr`` (None = start) on code point ``c``."""
            tn = None if ptr is None else ptr.TryGetValue(c)
            if tn is None:
                tn = self._first.TryGetValue(c)
            return tn

        def _terminal_states(self, text):
            """Yield ``(index, state)`` for each position where a keyword ends."""
            ptr = None
            for index, ch in enumerate(text):
                ptr = self._step(ptr, ord(ch))
                if ptr is not None and ptr.End:
                    yield index, ptr

        def _match(self, item, end):
            """Build the result dict for keyword ``item`` ending at ``end``."""
            keyword = self._keywords[item]
            return {
                "Keyword": keyword,
                "Success": True,
                "End": end,
                "Start": end + 1 - len(keyword),
                "Index": self._indexs[item],
            }

        def FindFirst(self, text):
            """Return the first match in ``text`` (earliest end), or None."""
            for index, state in self._terminal_states(text):
                return self._match(state.Results[0], index)
            return None

        def FindAll(self, text):
            """Return every match in ``text``, in order of end position."""
            return [
                self._match(item, index)
                for index, state in self._terminal_states(text)
                for item in state.Results
            ]

        def ContainsAny(self, text):
            """Return True if any keyword occurs in ``text``."""
            for _ in self._terminal_states(text):
                return True
            return False

        def Replace(self, text, replaceChar="*"):
            """Return ``text`` with each character of every match replaced.

            At each end position only the longest keyword ending there
            (``Results[0]``) is masked; shorter ones are covered by it.
            """
            result = list(text)
            for index, state in self._terminal_states(text):
                length = len(self._keywords[state.Results[0]])
                for j in range(index + 1 - length, index + 1):
                    result[j] = replaceChar
            return "".join(result)

    class Valves(BaseModel):
        # Master switch for filtering in inlet().
        ENABLE_MESSAGE_FILTER: Optional[bool] = True
        # Comma-separated bad words; surrounding whitespace is ignored.
        CHAT_FILTER_WORDS: Optional[str] = "fuck,SB"
        # True: mask bad words; False: reject the message with an error.
        ENABLE_REPLACE_FILTER_WORDS: Optional[bool] = True
        # Substituted for every character of a matched word ("*" if None).
        REPLACE_FILTER_WORDS: Optional[str] = "*"

    def __init__(self):
        self.valves = self.Valves()
        # CHAT_FILTER_WORDS value the current automaton was built from; used
        # to rebuild only when the valve changes.
        self.TEM_CHAT_FILTER_WORDS = None
        self.search = None

    @staticmethod
    def _split_filter_words(words):
        """Split the comma-separated valve into stripped, non-empty words."""
        if not words:
            return []
        return [word.strip() for word in str(words).split(",") if word.strip()]

    def _filter_text(self, text, start_time):
        """Return ``text`` with bad words masked, or raise if masking is off.

        Raises:
            Exception: when ENABLE_REPLACE_FILTER_WORDS is falsy and ``text``
                contains a bad word.
        """
        if not self.valves.ENABLE_REPLACE_FILTER_WORDS:
            filter_condition = self.search.FindFirst(text)
            if filter_condition:
                filter_word = filter_condition["Keyword"]
                detail_message = (
                    f"Open WebUI: Your message contains bad words (`{filter_word}`) "
                    "and cannot be sent. Please create a new topic and try again."
                )
                print(
                    f"The time taken to check the filter words: {time.time() - start_time:.6f}s"
                )
                raise Exception(detail_message)
            return text
        replace_char = self.valves.REPLACE_FILTER_WORDS
        if replace_char is None:
            # The valve is Optional; None would make "".join() fail.
            replace_char = "*"
        replaced = self.search.Replace(text, replace_char)
        print(f"Replace bad words in content: {replaced}")
        return replaced

    def inlet(self, body: dict, __user__: Optional[dict] = None) -> dict:
        """Check user messages for bad words before the chat completion call.

        Rebuilds the keyword automaton whenever CHAT_FILTER_WORDS changed,
        then scans every user message in ``body["messages"]``.

        Args:
            body: Request body; user message text may be replaced in place.
            __user__: Requesting user info; may be None.

        Returns:
            The (possibly modified) ``body``.

        Raises:
            Exception: when replacement is disabled and a user message
                contains a bad word; the message text is user-facing.
        """
        print(f"inlet:{__name__}")
        print(f"inlet:body:{body}")
        print(f"inlet:user:{__user__}")
        print(f"{self.valves}")

        if (
            self.valves.ENABLE_MESSAGE_FILTER
            and self.valves.CHAT_FILTER_WORDS != self.TEM_CHAT_FILTER_WORDS
        ):
            self.search = self.wordsSearch()
            self.search.SetKeywords(
                self._split_filter_words(self.valves.CHAT_FILTER_WORDS)
            )
            self.TEM_CHAT_FILTER_WORDS = self.valves.CHAT_FILTER_WORDS

        # __user__ defaults to None; treat a missing user like the default role.
        user = __user__ or {}
        if user.get("role", "admin") not in ["user", "admin"]:
            return body
        if not (
            self.valves.ENABLE_MESSAGE_FILTER
            and self.search
            and self.valves.CHAT_FILTER_WORDS
        ):
            return body

        start_time = time.time()
        for message in reversed(body.get("messages") or []):
            if message.get("role") != "user":
                continue
            content = message.get("content")
            if isinstance(content, str):
                message["content"] = self._filter_text(content, start_time)
            elif isinstance(content, list):
                # Multimodal messages carry a list of parts, presumably
                # OpenAI-style {"type": "text", "text": ...} entries; their
                # text used to bypass the filter entirely.
                for part in content:
                    if (
                        isinstance(part, dict)
                        and part.get("type") == "text"
                        and isinstance(part.get("text"), str)
                    ):
                        part["text"] = self._filter_text(part["text"], start_time)
        print(
            f"The time taken to check the filter words: {time.time() - start_time:.6f}s"
        )
        return body

    def outlet(self, body: dict, __user__: Optional[dict] = None) -> dict:
        """Log the response body and user, and return the body unchanged."""
        print(f"outlet:{__name__}")
        print(f"outlet:body:{body}")
        print(f"outlet:user:{__user__}")
        return body