Function
filter
v0.0.2
Narrative Memory Filter
Super hacky, very basic automatic narrative memory filter for OpenWebUI that may or may not actually enhance narrative generation!
Function ID
narrative_memory_filter
Creator
@projectmoon
Downloads
254+

Function Content
python
"""
title: Memory Filter
author: projectmoon
author_url: https://git.agnos.is/projectmoon/open-webui-filters
version: 0.0.2
license: AGPL-3.0+
required_open_webui_version: 0.3.9
"""

# Documentation: https://git.agnos.is/projectmoon/open-webui-filters
#
# Changelog:
#  0.0.1 - Initial release, proof of concept
#  0.0.2 - Slightly less hacky (but still hacky) way of getting chat IDs

# System imports
import asyncio
import hashlib
import uuid
import json

from typing import Optional, List, Dict, Callable, Any, NewType, Tuple, Awaitable
from typing_extensions import TypedDict, NotRequired

# Libraries available to OpenWebUI
import markdown
from bs4 import BeautifulSoup
from pydantic import BaseModel as PydanticBaseModel, Field
import chromadb
from chromadb import Collection as ChromaCollection
from chromadb.api.types import Document as ChromaDocument

# OpenWebUI imports
from config import CHROMA_CLIENT
from apps.rag.main import app
from utils.misc import get_last_user_message, get_last_assistant_message
from main import generate_chat_completions

# OpenWebUI aliases
EMBEDDING_FUNCTION = app.state.EMBEDDING_FUNCTION

# Custom type declarations
EmbeddingFunc = NewType('EmbeddingFunc', Callable[[str], List[Any]])

# Prompts
ENRICHMENT_SUMMARY_PROMPT = """
You are tasked with analyzing the following Characters and Plot Details
sections and reducing this set of information into lists of the most
important points needed for the continuation of the narrative you are
writing. Remove duplicate or conflicting information. If there is conflicting
information, decide on something consistent and interesting for the story.

Your reply must consist of two sections: Characters and Plot Details. These
sections must be markdown ### Headers. Under each header, respond with a
list of bullet points. Each bullet point must be one piece of relevant information.

Limit each bullet point to one sentence. Respond ONLY with the Characters and
Plot Details sections, with the bullet points under them, and nothing else.
Do not respond with any commentary. ONLY respond with the bullet points.
""".replace("\n", " ").strip()

QUERY_PROMPT = """
You are tasked with generating questions for a vector database
about the narrative presented below. The queries must be questions about
parts of the story that you need more details on. The questions must be
about past events in the story, or questions about the characters involved
or mentioned in the scene (their appearance, mental state, past actions, etc).

Your reply must consist of two sections: Characters and Plot Details. These
sections must be markdown ### Headers. Under each header, respond with a
list of bullet points. Each bullet point must be a single question or sentence
that will be given to the vector database. Generate a maximum of 5 Character
queries and 5 Plot Detail queries.

Limit each bullet point to one sentence. Respond ONLY with the Characters and
Plot Details sections, with the bullet points under them, and nothing else.
Do not respond with any commentary. ONLY respond with the bullet points.
""".replace("\n", " ").strip()

SUMMARIZER_PROMPT = """
You are a narrative summarizer. Summarize the given message as if it's
part of a story. Your response must have two separate sections: Characters
and Plot Details. These sections should be markdown ### Headers. Under each
section, respond with a list of bullet points. This knowledge will be stored
in a vector database for your future use.

The Characters section should note any characters in the scene, and important
things that happen to them. Describe the characters' appearances, actions,
mental states, and emotional states. The Plot Details section should have a
list of important plot details in this scene.

The bullet points you generate must be in the context of storing future
knowledge about the story. Do not focus on useless details: only focus on
information that you could lose in the future as your context window shifts.

Limit each bullet point to one sentence. The sentence MUST be in the PAST TENSE.
Respond ONLY with the Characters and Plot Details sections, with the bullet points
under them, and nothing else. Do not respond with any commentary. ONLY respond with
the bullet points.
""".replace("\n", " ").strip()

class Message(TypedDict):
    id: NotRequired[str]
    role: str
    content: str

class MessageInsertMetadata(TypedDict):
    role: str
    chapter: str

class MessageInsert(TypedDict):
    message_id: str
    content: str
    metadata: MessageInsertMetadata
    embeddings: List[Any]


class BaseModel(PydanticBaseModel):
    class Config:
        arbitrary_types_allowed = True

class SummarizerResponse(BaseModel):
    characters: List[str]
    plot: List[str]


class Summarizer(BaseModel):
    message: str
    model: str
    prompt: str = SUMMARIZER_PROMPT

    def extract_section(self, soup: BeautifulSoup, section_name: str) -> List[str]:
        for h3 in soup.find_all('h3'):
            heading = h3.get_text().strip()
            if heading != section_name:
                continue

            # Find the next sibling which should be a <ul> or <ol>
            ul = h3.find_next_sibling('ul')
            ol = h3.find_next_sibling('ol')
            list_items = []

            if ul:
                list_items = [li.get_text().strip() for li in ul.find_all('li')]
            elif ol:
                list_items = [li.get_text().strip() for li in ol.find_all('li')]

            return list_items

        return []

    def sanitize_section(self, bullet_points: List[str]) -> List[str]:
        return [
            bullet.strip().lstrip("-*•123456789").strip()
            for bullet in bullet_points
        ]

    async def summarize(self) -> SummarizerResponse:
        # Use this summarizer's own prompt, so callers that pass
        # QUERY_PROMPT or an enrichment prompt actually get it.
        messages: List[Message] = [
            { "role": "system", "content": self.prompt },
            { "role": "user", "content": self.message }
        ]

        request = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "keep_alive": "10s"
        }

        resp = await generate_chat_completions(request)
        if "choices" in resp and len(resp["choices"]) > 0:
            content: str = resp["choices"][0]["message"]["content"]
            html = markdown.markdown(content)
            soup = BeautifulSoup(html, "html.parser")

            character_results = self.extract_section(soup, "Characters")
            character_results = self.sanitize_section(character_results)
            plot_points = self.extract_section(soup, "Plot Details")
            plot_points = self.sanitize_section(plot_points)

            return SummarizerResponse(characters=character_results, plot=plot_points)
        else:
            return SummarizerResponse(characters=[], plot=[])


class Chapter(BaseModel):
    """
    Focuses on a single 'chapter,' or chunk of a conversation. Provides
    methods to search for data in this section of conversational story
    history.
    """
    convo_id: Optional[str]
    client: chromadb.ClientAPI
    chapter_id: str
    messages: List[Message]
    embedding_func: EmbeddingFunc

    def create_metadata(self) -> Dict:
        return { "convo_id": self.convo_id, "chapter": self.chapter_id }

    def get_collection(self) -> Optional[ChromaCollection]:
        try:
            coll = self.client.get_collection("stories")

            if not self.convo_id:
                self.convo_id = (
                    coll.metadata["current_convo_id"]
                    if "current_convo_id" in coll.metadata
                    else None
                )

            return coll
        except ValueError:
            return None

    def _create_inserts(self, summary: SummarizerResponse) -> List[MessageInsert]:
        inserts = []
        plot_points = summary.plot
        character_points = summary.characters

        for plot_point in plot_points:
            inserts.append({
                'id': str(uuid.uuid4()),
                'content': plot_point,
                'metadata': {
                    "convo_id": self.convo_id,
                    "chapter": self.chapter_id,
                    "type": "plot"
                },
                'embedding': self.embedding_func(plot_point)
            })

        for character_point in character_points:
            inserts.append({
                'id': str(uuid.uuid4()),
                'content': character_point,
                'metadata': {
                    "convo_id": self.convo_id,
                    "chapter": self.chapter_id,
                    "type": "character"
                },
                'embedding': self.embedding_func(character_point)
            })

        return inserts

    def chapter_state(self) -> dict:
        """Useful for storing current place in chapter, and convo switching."""
        coll = self.get_collection()
        result = coll.get(ids=[f"chapter-{self.chapter_id}"], include=["metadatas"])
        # Chroma returns a dict (GetResult), not an object with attributes.
        if len(result["metadatas"]) > 0:
            return result["metadatas"][0]
        else:
            return {}

    def embed(self, summary: SummarizerResponse):
        """
        Store plot points for this chapter in ChromaDB.
        """
        coll = self.get_collection()
        if coll is None or not self.convo_id:
            return

        inserts = self._create_inserts(summary)
        if len(inserts) > 0:
            documents = [entry['content'] for entry in inserts]
            metadatas = [entry['metadata'] for entry in inserts]
            ids = [entry['id'] for entry in inserts]
            embeddings = [entry['embedding'] for entry in inserts]
            coll.upsert(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)

    def query_plot(self, search_term):
        return self.query(search_term, "plot")

    def query_characters(self, search_term):
        return self.query(search_term, "character")

    def query(self, search_term: str, type: str) -> List[ChromaDocument]:
        coll = self.get_collection()
        if coll and self.convo_id:
            term_embedding = self.embedding_func(search_term)
            results = coll.query(
                query_embeddings=[term_embedding],
                include=["documents", "metadatas"],
                where={
                    "$and": [
                        { "convo_id": self.convo_id },
                        { "chapter": self.chapter_id },
                        { "type": type }
                    ]
                },
                n_results = 5
            )

            # Flatten the nested lists: Chroma returns one list of
            # documents (and of metadatas) per query embedding.
            if 'documents' in results:
                docs = [ doc for doc_list in results['documents'] for doc in doc_list ]
                metadatas = [ md for md_list in results['metadatas'] for md in md_list ]
                results = []
                for (doc, metadata) in zip(docs, metadatas):
                    results.append({ "doc": doc, "metadata": metadata })
                return results
            else:
                return []
        else:
            return []


class Story(BaseModel):
    """Container for chapters. Manages an entire conversation."""
    convo_id: Optional[str] = None
    client: chromadb.ClientAPI
    messages: List[Message]
    embedding_func: EmbeddingFunc

    def _collection_name(self):
        return "stories"

    def create_metadata(self):
        try:
            coll = self.client.get_collection(self._collection_name())
            if coll:
                # If we have pre-specified a convo id, update metadata
                # of collection accordingly.
                if self.convo_id:
                    metadata = coll.metadata
                    metadata['current_convo_id'] = self.convo_id
                    metadata["hnsw:space"] = "cosine"
                    coll = self.client.get_or_create_collection(
                        name=self._collection_name(), metadata=metadata
                    )
                else:
                    # Otherwise pull it out of the database.
                    self.convo_id = (
                        coll.metadata['current_convo_id']
                        if 'current_convo_id' in coll.metadata
                        else None
                    )

            return coll.metadata
        except ValueError:
            return { "current_convo_id": "", "current_chapter": 1 }

    def convo_state(self) -> dict:
        """Retrieve information about the current conversation."""
        if not self.convo_id or self.convo_id == "":
            return {}

        convo_state_id = f"convo-{self.convo_id}"
        coll = self.get_collection()
        result = coll.get(ids=[convo_state_id], include=["metadatas"])
        # Chroma returns a dict (GetResult), not an object with attributes.
        if len(result["metadatas"]) > 0:
            return result["metadatas"][0]
        else:
            # insert convo state
            # TODO do something useful with convo summary
            convo_summary = f"State for convo {self.convo_id}"
            convo_metadata = { "current_chapter": 1 }
            coll.add(
                ids=[convo_state_id],
                documents=[convo_summary], # maybe store convo summary here?
                embeddings=[self.embedding_func(convo_summary)],
                metadatas=[convo_metadata]
            )
            return convo_metadata

    def switch_convo(self):
        """Force a switch of current conversation."""
        if not self.convo_id:
            # If we have only a user message (i.e. start of
            # conversation), forcibly set the convo ID to an empty string.
            if len(self.messages) < 2:
                self.convo_id = ""
            else:
                # Otherwise attempt to get the collection, which forces
                # metadata creation and updates.
                self.get_collection()

    def get_collection(self):
        """Retrieve the collection, with its context set to the current convo ID."""
        try:
            coll = self.client.get_collection(self._collection_name())
            if coll:
                # If we have pre-specified a convo id, update metadata
                # of collection accordingly.
                if self.convo_id:
                    metadata = coll.metadata
                    metadata['current_convo_id'] = self.convo_id
                    metadata["hnsw:space"] = "cosine"
                    return self.client.get_or_create_collection(
                        name=self._collection_name(), metadata=metadata
                    )
                else:
                    # Otherwise pull existing convo id out of the database.
                    self.convo_id = (
                        coll.metadata['current_convo_id']
                        if 'current_convo_id' in coll.metadata
                        else None
                    )

            return coll
        except ValueError:
            # If the stories collection does not exist, create it
            # completely from scratch.
            metadata = { "current_convo_id": "", "hnsw:space": "cosine" }
            return self.client.get_or_create_collection(self._collection_name(), metadata=metadata)

    def _current_chapter(self) -> int:
        try:
            return self.convo_state()["current_chapter"]
        except Exception:
            return 1

    def _current_chapter_object(self) -> Chapter:
        return Chapter(
            convo_id = self.convo_id,
            chapter_id=str(self._current_chapter()),
            messages=self.messages,
            client=self.client,
            embedding_func=self.embedding_func
        )

    def embed_summary(self, summary: SummarizerResponse):
        self._current_chapter_object().embed(summary)

    def query_plot(self, term: str) -> List[ChromaDocument]:
        return self._current_chapter_object().query_plot(term)

    def query_characters(self, term: str) -> List[ChromaDocument]:
        return self._current_chapter_object().query_characters(term)


# Utils
class SessionInfo(BaseModel):
    chat_id: str
    message_id: str
    session_id: str

def extract_session_info(event_emitter) -> Optional[SessionInfo]:
    """The latest innovation in hacky workarounds."""
    try:
        info = event_emitter.__closure__[0].cell_contents
        return SessionInfo(
            chat_id=info["chat_id"],
            message_id=info["message_id"],
            session_id=info["session_id"]
        )
    except Exception:
        return None

def create_enrichment_summary_prompt(
        narrative: str,
        character_details: List[str],
        plot_details: List[str]
) -> str:
    prompt = ENRICHMENT_SUMMARY_PROMPT
    prompt += " Here are the original Character and Plot Details sections."
    prompt += " Summarize them according to the instructions.\n\n"

    snippets = "## Character Details:\n"
    for character_detail in character_details:
        snippets += f"- {character_detail}\n"
    snippets = snippets.strip()
    snippets += "\n"

    snippets += "\n\n## Plot Details:\n"
    for plot_point in plot_details:
        snippets += f"- {plot_point}\n"
    snippets = snippets.strip()
    snippets += "\n"

    snippets = snippets.strip()
    prompt += snippets + "\n\n"
    prompt += "Additionally, the narrative you must continue is provided below."
    prompt += "\n\n-----\n\n"
    prompt += narrative
    return prompt.strip()

def create_context(results: SummarizerResponse) -> Optional[str]:
    if not results:
        return None

    character_details = results.characters
    plot_details = results.plot

    snippets = "## Relevant Character Details:\n"
    snippets += "These are relevant bits of information about characters in the story.\n"
    for character_detail in character_details:
        snippets += f"- {character_detail}\n"
    snippets = snippets.strip()
    snippets += "\n"

    snippets += "\n\n## Relevant Plot Details:\n"
    snippets += "These are relevant plot details that happened earlier in the story.\n"
    for plot_point in plot_details:
        snippets += f"- {plot_point}\n"
    snippets = snippets.strip()
    snippets += "\n"

    message = (
        "\n\nUse the following context as information about the story, inside XML tags.\n\n"
        f"\n{snippets}\n"
        "When answering to user:\n"
        "- Use the context to enhance your knowledge of the story.\n"
        "- If you don't know, do not ask for clarification.\n"
        "Do not mention that you obtained the information from the context.\n"
        "Do not mention the context.\n"
        f"Continue the story according to the user's directions."
    )
    return message

def split_messages(messages, keep_amount):
    if len(messages) <= keep_amount:
        return messages[:], []

    recent_messages = messages[-keep_amount:]
    old_messages = messages[:-keep_amount]
    return recent_messages, old_messages

def chunk_messages(messages, chunk_size):
    return [messages[i:i + chunk_size] for i in range(0, len(messages), chunk_size)]

def llm_messages_to_user_messages(messages):
    return [
        {'role': 'user', 'content': msg['content']}
        for msg in messages if msg['role'] == 'assistant'
    ]


# Das Filter
class Filter:
    class Valves(BaseModel):
        def summarizer_model(self, body):
            if self.summarizer_model_id == "":
                # This will be the model ID in the convo. If not base
                # model, it will cause problems.
                return body["model"]
            else:
                return self.summarizer_model_id

        summarizer_model_id: str = Field(
            default="",
            description="Model used to summarize the conversation. Must be a base model.",
        )

        n_last_messages: int = Field(
            default=4, description="Number of last messages to retain."
        )

    class UserValves(BaseModel):
        pass

    def __init__(self):
        self.valves = self.Valves()

    def extract_convo_id(self, messages):
        """Extract ID of first message to use as conversation ID."""
        if len(messages) > 0:
            first_user_message = next(
                (message for message in messages if message.get("role") == "user"), None
            )
            if first_user_message and 'id' in first_user_message:
                return first_user_message['id']
            else:
                raise ValueError("No messages found to extract conversation ID")
        else:
            raise ValueError("No messages found to extract conversation ID")

    async def summarize(self, messages) -> Optional[SummarizerResponse]:
        message_to_summarize = get_last_assistant_message(messages)
        if message_to_summarize:
            summarizer = Summarizer(model=self.summarizer_model_id, message=message_to_summarize)
            return await summarizer.summarize()
        else:
            return None

    async def send_outlet_status(self, event_emitter, done: bool):
        description = (
            "Analyzing Narrative (do not reply until this is done)"
            if not done else
            "Narrative analysis complete (you may now reply)."
        )

        await event_emitter({
            "type": "status",
            "data": {
                "description": description,
                "done": done,
            },
        })

    async def set_enriching_status(self, state: str):
        if not self.event_emitter:
            return

        done = state == "done"
        description = "Enriching Narrative"

        if state == "init":
            description = f"{description}: Initializing..."
        if state == "searching":
            description = f"{description}: Searching..."
        if state == "analyzing":
            description = f"{description}: Analyzing..."

        description = (
            description if not done else "Enrichment Complete"
        )

        await self.event_emitter({
            "type": "status",
            "data": {
                "description": description,
                "done": done,
            },
        })

    async def outlet(
        self,
        body: dict,
        __user__: Optional[dict],
        __event_emitter__: Callable[[Any], Awaitable[None]],
    ) -> dict:
        # Useful things to have around.
        self.session_info = extract_session_info(__event_emitter__)
        self.event_emitter = __event_emitter__
        self.summarizer_model_id = self.valves.summarizer_model(body)

        await self.send_outlet_status(__event_emitter__, False)
        messages = body['messages']

        # summarize into plot points.
        summary = await self.summarize(messages)

        story = Story(
            convo_id=self.session_info.chat_id,
            client=CHROMA_CLIENT,
            embedding_func=EMBEDDING_FUNCTION,
            messages=messages
        )

        story.switch_convo()

        if summary:
            story.embed_summary(summary)

        await self.send_outlet_status(__event_emitter__, True)
        return body

    async def generate_enrichment_queries(self, messages) -> SummarizerResponse:
        last_response = get_last_assistant_message(messages)
        user_input = get_last_user_message(messages)

        query_message = ""
        if last_response:
            query_message += f"## Assistant\n\n{last_response}\n\n"
        if user_input:
            query_message += f"## User\n\n{user_input}\n\n"

        query_message = query_message.strip()

        summarizer = Summarizer(
            model=self.summarizer_model_id,
            message=query_message,
            prompt=QUERY_PROMPT
        )

        return await summarizer.summarize()

    async def summarize_enrichment(
            self,
            messages,
            character_results: List[ChromaDocument],
            plot_results: List[ChromaDocument]
    ) -> SummarizerResponse:
        last_response = get_last_assistant_message(messages)
        user_input = get_last_user_message(messages)

        character_details = [r['doc'] for r in character_results]
        plot_details = [r['doc'] for r in plot_results]

        narrative_message = ""
        if last_response:
            narrative_message += f"## Assistant\n\n{last_response}\n\n"
        if user_input:
            narrative_message += f"## User\n\n{user_input}\n\n"

        narrative_message = narrative_message.strip()

        summarization_prompt = create_enrichment_summary_prompt(
            narrative=narrative_message,
            plot_details=plot_details,
            character_details=character_details
        )

        summarizer = Summarizer(
            model=self.summarizer_model_id,
            message=narrative_message,
            prompt=summarization_prompt
        )

        return await summarizer.summarize()

    async def enrich(self, story: Story, messages) -> Optional[SummarizerResponse]:
        if len(messages) < 2:
            return None

        await self.set_enriching_status("searching")
        query_generation_result = await self.generate_enrichment_queries(messages)
        character_results = [
            result
            for query in query_generation_result.characters
            for result in story.query_characters(query)
        ]
        plot_results = [
            result
            for query in query_generation_result.plot
            for result in story.query_plot(query)
        ]

        await self.set_enriching_status("analyzing")
        return await self.summarize_enrichment(messages, character_results, plot_results)

    async def update_system_message(self, messages, system_message):
        story = Story(
            convo_id=self.session_info.chat_id,
            client=CHROMA_CLIENT,
            embedding_func=EMBEDDING_FUNCTION,
            messages=messages
        )

        story.switch_convo()

        if story.convo_id == "":
            return

        enrichment_summary: Optional[SummarizerResponse] = await self.enrich(story, messages)

        if enrichment_summary:
            context = create_context(enrichment_summary)
        else:
            context = None

        if context:
            system_message["content"] += context

    async def inlet(
        self,
        body: dict,
        __user__: Optional[dict],
        __event_emitter__: Callable[[Any], Awaitable[None]]
    ) -> dict:
        # Useful properties to have around.
        self.session_info = extract_session_info(__event_emitter__)
        self.event_emitter = __event_emitter__
        self.summarizer_model_id = self.valves.summarizer_model(body)

        await self.set_enriching_status("init")
        messages = body["messages"]

        # Ensure we always keep the system prompt
        system_prompt = next(
            (message for message in messages if message.get("role") == "system"), None
        )

        if system_prompt:
            all_messages = [
                message for message in messages if message.get("role") != "system"
            ]
            recent_messages, old_messages = split_messages(all_messages, self.valves.n_last_messages)
        else:
            system_prompt = { "id": str(uuid.uuid4()), "role": "system", "content": "" }
            recent_messages, old_messages = split_messages(messages, self.valves.n_last_messages)

        await self.update_system_message(messages, system_prompt)
        recent_messages.insert(0, system_prompt)
        body["messages"] = recent_messages
        await self.set_enriching_status("done")
        return body
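All three prompts ask the model for a `### Characters` and a `### Plot Details` heading, each followed by one-sentence bullet points. `Summarizer` recovers those lists by rendering the reply to HTML with `markdown` and walking it with BeautifulSoup. The sketch below is a rough, dependency-free illustration of the same reply format that groups bullets under headings with plain regular expressions instead. It is a swapped-in technique, not the filter's own parser, and `extract_sections` is a hypothetical name.

```python
import re
from typing import Dict, List, Optional

# "### Heading" lines, the header level the prompts ask for.
HEADING_RE = re.compile(r"^###\s+(?P<title>.+?)\s*$")
# "- item", "* item", "• item", "1. item" or "1) item" list lines.
BULLET_RE = re.compile(r"^\s*(?:[-*•]|\d+[.)])\s+(?P<text>.+?)\s*$")


def extract_sections(reply: str) -> Dict[str, List[str]]:
    """Group bullet points under the most recent ### heading."""
    sections: Dict[str, List[str]] = {}
    current: Optional[str] = None
    for line in reply.splitlines():
        heading = HEADING_RE.match(line)
        if heading:
            current = heading.group("title")
            sections.setdefault(current, [])
            continue
        bullet = BULLET_RE.match(line)
        # Bullets that appear before any heading are ignored.
        if bullet and current is not None:
            sections[current].append(bullet.group("text"))
    return sections


reply = """### Characters
- Mira lost her lantern in the mine.
- Tomas was wary of Mira.

### Plot Details
1. The mine collapsed behind them.
2. 1984 was carved into the tunnel wall.
"""

sections = extract_sections(reply)
print(sections["Characters"])  # ['Mira lost her lantern in the mine.', 'Tomas was wary of Mira.']
print(sections["Plot Details"])
```

There is one behavioral difference to be aware of. The filter's `sanitize_section` calls `lstrip("-*•123456789")`, which removes any leading run of those characters, so a bullet such as "1984 was carved into the tunnel wall." comes back as "was carved into the tunnel wall.". Matching the list marker explicitly, as above, keeps leading numbers in the text.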
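In `inlet`, the filter keeps the system prompt, creating an empty one if none exists. It then drops everything except the last `n_last_messages` non-system messages and afterwards appends the retrieved context to that system prompt. The following is a minimal sketch of the trimming step in isolation, assuming plain dict messages. `trim_history` is a hypothetical helper, not part of the filter.

```python
import uuid
from typing import Dict, List

Message = Dict[str, str]


def trim_history(messages: List[Message], keep: int) -> List[Message]:
    """Return the system prompt followed by the last `keep` non-system messages."""
    system = next((m for m in messages if m.get("role") == "system"), None)
    if system is None:
        # Like the filter, start from an empty system prompt so that
        # retrieved context has somewhere to go.
        system = {"id": str(uuid.uuid4()), "role": "system", "content": ""}
    others = [m for m in messages if m.get("role") != "system"]
    # Guard keep == 0 explicitly: others[-0:] would return the whole list.
    recent = others[-keep:] if keep > 0 else []
    return [system] + recent


history = [
    {"role": "system", "content": "You are a narrator."},
    {"role": "user", "content": "u1"},
    {"role": "assistant", "content": "a1"},
    {"role": "user", "content": "u2"},
    {"role": "assistant", "content": "a2"},
    {"role": "user", "content": "u3"},
]
print([m["content"] for m in trim_history(history, 4)])
# ['You are a narrator.', 'a1', 'u2', 'a2', 'u3']
```

The `keep > 0` guard is a deliberate difference from `split_messages`. Because `messages[-0:]` is the whole list in Python, a value of 0 there keeps every message rather than none.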
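`Chapter.query` has to flatten Chroma's query result because `Collection.query` returns one inner list per query embedding (`documents[i]` holds the matches for `query_embeddings[i]`). The sketch below performs the same pairing on a hand-built dict shaped like that result, so it runs without ChromaDB. `flatten_query_result` and the sample data are illustrative assumptions.

```python
from typing import Any, Dict, List, Optional


def flatten_query_result(
    results: Dict[str, Optional[List[List[Any]]]]
) -> List[Dict[str, Any]]:
    """Pair every returned document with its metadata, across all queries."""
    docs = results.get("documents") or []
    metas = results.get("metadatas") or []
    flat: List[Dict[str, Any]] = []
    # Outer lists: one entry per query embedding. Inner lists: the matches.
    for doc_list, meta_list in zip(docs, metas):
        for doc, meta in zip(doc_list, meta_list):
            flat.append({"doc": doc, "metadata": meta})
    return flat


# Hand-built data in the shape of a query made with two query embeddings.
example = {
    "ids": [["p1", "p2"], ["c1"]],
    "documents": [["Mira found a map.", "The bridge burned."], ["Tomas lied."]],
    "metadatas": [[{"type": "plot"}, {"type": "plot"}], [{"type": "character"}]],
}
for row in flatten_query_result(example):
    print(row["metadata"]["type"], "|", row["doc"])
```

Pairing inside each inner list keeps every document attached to its own metadata. The `or []` fallback covers a field that is missing or `None`.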
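`extract_session_info` gets the chat ID by reading the first closure cell of the `__event_emitter__` callable. This only works if OpenWebUI builds the emitter as a closure over a dict containing `chat_id`, `message_id` and `session_id`, which is an internal detail the plugin relies on (hence the `try`/`except`). The mechanism itself is plain Python, as this self-contained sketch shows. `make_emitter` is a hypothetical stand-in for OpenWebUI's emitter factory.

```python
from typing import Any, Callable, Dict, Optional


def make_emitter(info: Dict[str, str]) -> Callable[[Dict[str, Any]], None]:
    """Hypothetical stand-in for a framework-built event emitter."""
    def emit(event: Dict[str, Any]) -> None:
        # `info` is a free variable here, so Python stores it in a closure cell.
        print(info["chat_id"], event.get("type"))
    return emit


def read_first_cell(fn: Callable[..., Any]) -> Optional[Any]:
    """Return whatever fn's first closure cell holds, or None if it has none."""
    cells = getattr(fn, "__closure__", None)
    if not cells:
        return None
    return cells[0].cell_contents


emitter = make_emitter({"chat_id": "c-1", "message_id": "m-1", "session_id": "s-1"})
print(read_first_cell(emitter)["chat_id"])  # c-1
print(read_first_cell(lambda: None))        # None
```

Which value sits at index `[0]` depends on what the inner function happens to capture, so any change to the emitter factory can silently break this lookup. That is why the sketch returns `None` instead of raising when there is no closure.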