Whitepaper
Docs
Sign In
Function
Function
filter
v0.2
Message bad words Filter
Function ID
message_bad_words_filter
Creator
@yangyutin753
Downloads
162+
Use Message bad words Filter to protect our APIs
Get
README
No README available
Function Code
Show
""" title: Message bad words Filter author: Yanyutin753 author_url: https://github.com/Yanyutin753 funding_url: https://github.com/open-webui version: 0.2 """ import time from typing import Optional from pydantic import BaseModel class Filter: class hashNode: def __init__(self): self.End = False self.Results = [] self.m_values = {} self.min_flag = 0xFFFF self.max_flag = 0 def Add(self, c, node3): if self.min_flag > c: self.min_flag = c if self.max_flag < c: self.max_flag = c self.m_values[c] = node3 def SetResults(self, index): if not self.End: self.End = True if not (index in self.Results): self.Results.append(index) def HasKey(self, c): return c in self.m_values def TryGetValue(self, c): if self.min_flag <= c <= self.max_flag: if c in self.m_values: return self.m_values[c] return None class trieNode: def __init__(self): self.Index = 0 self.Index = 0 self.Layer = 0 self.End = False self.Char = "" self.Results = [] self.m_values = {} self.Failure = None self.Parent = None def Add(self, c): if c in self.m_values: return self.m_values[c] node = Filter.trieNode() node.Parent = self node.Char = c self.m_values[c] = node return node def SetResults(self, index): if not self.End: self.End = True self.Results.append(index) class wordsSearch: def __init__(self): self._first = {} self._keywords = [] self._indexs = [] def SetKeywords(self, keywords): self._keywords = keywords self._indexs = [] for i in range(len(keywords)): self._indexs.append(i) root = Filter.trieNode() allNodeLayer = {} for i in range(len(self._keywords)): p = self._keywords[i] nd = root for j in range(len(p)): nd = nd.Add(ord(p[j])) if nd.Layer == 0: nd.Layer = j + 1 if nd.Layer in allNodeLayer: allNodeLayer[nd.Layer].append(nd) else: allNodeLayer[nd.Layer] = [] allNodeLayer[nd.Layer].append(nd) nd.SetResults(i) allNode = [root] for key in allNodeLayer.keys(): for nd in allNodeLayer[key]: allNode.append(nd) for i in range(len(allNode)): if i == 0: continue nd = allNode[i] nd.Index = i r = nd.Parent.Failure c = nd.Char while r is not None and c not in r.m_values: r = r.Failure if r is None: nd.Failure = root else: nd.Failure = r.m_values[c] for key2 in nd.Failure.Results: nd.SetResults(key2) root.Failure = root allNode2 = [] for i in range(len(allNode)): allNode2.append(Filter.hashNode()) for i in range(len(allNode2)): oldNode = allNode[i] newNode = allNode2[i] for key in oldNode.m_values: index = oldNode.m_values[key].Index newNode.Add(key, allNode2[index]) for index in range(len(oldNode.Results)): item = oldNode.Results[index] newNode.SetResults(item) oldNode = oldNode.Failure while oldNode != root: for key in oldNode.m_values: if not newNode.HasKey(key): index = oldNode.m_values[key].Index newNode.Add(key, allNode2[index]) for index in range(len(oldNode.Results)): item = oldNode.Results[index] newNode.SetResults(item) oldNode = oldNode.Failure self._first = allNode2[0] def FindFirst(self, text): ptr = None for index in range(len(text)): t = ord(text[index]) if ptr is None: tn = self._first.TryGetValue(t) else: tn = ptr.TryGetValue(t) if tn is None: tn = self._first.TryGetValue(t) if tn is not None: if tn.End: item = tn.Results[0] keyword = self._keywords[item] return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item], } ptr = tn return None def FindAll(self, text): ptr = None key_list = [] for index in range(len(text)): t = ord(text[index]) if ptr is None: tn = self._first.TryGetValue(t) else: tn = ptr.TryGetValue(t) if tn is None: tn = self._first.TryGetValue(t) if tn is not None: if tn.End: for j in range(len(tn.Results)): item = tn.Results[j] keyword = self._keywords[item] key_list.append( { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item], } ) ptr = tn return key_list def ContainsAny(self, text): ptr = None for index in range(len(text)): t = ord(text[index]) if ptr is None: tn = self._first.TryGetValue(t) else: tn = ptr.TryGetValue(t) if tn is None: tn = self._first.TryGetValue(t) if tn is not None: if tn.End: return True ptr = tn return False def Replace(self, text, replaceChar="*"): result = list(text) ptr = None for i in range(len(text)): t = ord(text[i]) if ptr is None: tn = self._first.TryGetValue(t) else: tn = ptr.TryGetValue(t) if tn is None: tn = self._first.TryGetValue(t) if tn is not None: if tn.End: maxLength = len(self._keywords[tn.Results[0]]) start = i + 1 - maxLength for j in range(start, i + 1): result[j] = replaceChar ptr = tn return "".join(result) class Valves(BaseModel): ENABLE_MESSAGE_FILTER: Optional[bool] = True CHAT_FILTER_WORDS: Optional[str] = "fuck,SB" ENABLE_REPLACE_FILTER_WORDS: Optional[bool] = True REPLACE_FILTER_WORDS: Optional[str] = "*" def __init__(self): self.valves = self.Valves() self.TEM_CHAT_FILTER_WORDS = None self.search = None def inlet(self, body: dict, __user__: Optional[dict] = None) -> dict: # Modify the request body or validate it before processing by the chat completion API. # This function is the pre-processor for the API where various checks on the input can be performed. # It can also modify the request before sending it to the API. print(f"inlet:{__name__}") print(f"inlet:body:{body}") print(f"inlet:user:{__user__}") print(f"{self.valves}") if ( self.valves.ENABLE_MESSAGE_FILTER and self.valves.CHAT_FILTER_WORDS != self.TEM_CHAT_FILTER_WORDS ): self.search = self.wordsSearch() self.search.SetKeywords(str(self.valves.CHAT_FILTER_WORDS).split(",")) self.TEM_CHAT_FILTER_WORDS = self.valves.CHAT_FILTER_WORDS if __user__.get("role", "admin") in ["user", "admin"]: messages = body.get("messages", []) if ( self.valves.ENABLE_MESSAGE_FILTER and self.search and self.valves.CHAT_FILTER_WORDS ): start_time = time.time() for message in reversed(messages): if message.get("role") == "user": content = message.get("content") if not isinstance(content, list): if not self.valves.ENABLE_REPLACE_FILTER_WORDS: print( str(self.valves.ENABLE_REPLACE_FILTER_WORDS) + self.valves.CHAT_FILTER_WORDS ) filter_condition = self.search.FindFirst(content) if filter_condition: filter_word = filter_condition["Keyword"] detail_message = ( f"Open WebUI: Your message contains bad words (`{filter_word}`) " "and cannot be sent. Please create a new topic and try again." ) print( f"The time taken to check the filter words: {time.time() - start_time:.6f}s" ) raise Exception(detail_message) else: message["content"] = self.search.Replace( content, self.valves.REPLACE_FILTER_WORDS ) print( f"Replace bad words in content: {message['content']}" ) print( f"The time taken to check the filter words: {time.time() - start_time:.6f}s" ) return body def outlet(self, body: dict, __user__: Optional[dict] = None) -> dict: # Modify or analyze the response body after processing by the API. # This function is the post-processor for the API, which can be used to modify the response # or perform additional checks and analytics. print(f"outlet:{__name__}") print(f"outlet:body:{body}") print(f"outlet:user:{__user__}") return body