Function
filter
v0.1.0
Checkpoint Summarization Filter
Manage context use by summarizing conversation history.
Function ID
checkpoint_summarization_filter
Creator
@projectmoon
Downloads
215+

Function Content
python
"""
title: Checkpoint Summary Filter
author: projectmoon
author_url: https://git.agnos.is/projectmoon/open-webui-filters
version: 0.1.0
license: AGPL-3.0+
required_open_webui_version: 0.3.9
"""

# Documentation: https://git.agnos.is/projectmoon/open-webui-filters

# System imports
import asyncio
import hashlib
import uuid
import json
import re
import logging

from typing import Optional, List, Dict, Callable, Any, NewType, Tuple, Awaitable, ClassVar
from typing_extensions import TypedDict, NotRequired
from collections import deque

# Libraries available to OpenWebUI
from pydantic import BaseModel as PydanticBaseModel, Field
import chromadb
from chromadb import Collection as ChromaCollection
from chromadb.api.types import Document as ChromaDocument

# OpenWebUI imports
from config import CHROMA_CLIENT
from apps.rag.main import app as rag_app
from apps.ollama.main import app as ollama_app
from apps.ollama.main import show_model_info, ModelNameForm
from utils.misc import get_last_user_message, get_last_assistant_message
from main import generate_chat_completions

from apps.webui.models.chats import Chats
from apps.webui.models.models import Models
from apps.webui.models.users import Users

# Embedding hook (not yet used). EmbeddingFunc is the type of a callable
# that maps a text to its embedding vector.
EmbeddingFunc = NewType("EmbeddingFunc", Callable[[str], List[Any]])
EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION

# Prompts
#
# System prompt for the summarizer model. The text below is wrapped for
# readability only: every run of whitespace (line breaks and the blank
# lines between paragraphs) is collapsed to a single space before use.
SUMMARIZER_PROMPT = " ".join("""
You are a chat conversation summarizer. Your task is to summarize the given
portion of an ongoing conversation. First, determine if the conversation is
a regular chat between the user and the assistant, or if the conversation is
part of a story or role-playing session.

Summarize the important parts of the given chat between the user and the
assistant. Limit your summary to one paragraph. Make sure your summary is
detailed. Write the summary as if you are summarizing part of a larger
conversation. Do not refer to "you" or "me" in the summary. Write in the
third person perspective.

If the conversation is a regular chat, write your summary referring to the
ongoing conversation as a chat. If the conversation is a regular chat, refer
to the user and the assistant as user and assistant. If the conversation is
a regular chat, do not refer to yourself as the assistant. Do not make up a
name for the user. If the conversation is a regular chat, summarize all
important parts of the chat.

If the conversation is a story or role-playing session, write your summary
referring to the conversation as an ongoing story. If the conversation is a
story or role-playing session, do not refer to the user or assistant in your
summary. If the conversation is a story or role-playing session, only use the
names of the characters, places, and events in the story.
""".split())

class Message(TypedDict):
    id: NotRequired[str]
    role: str
    content: str

# Metadata attached to a stored message: its role and the chapter it
# belongs to. Both keys are required.
MessageInsertMetadata = TypedDict(
    "MessageInsertMetadata",
    {"role": str, "chapter": str},
)

# Shape of a message record to insert into a store: its id, text,
# metadata, and embedding vector. All keys are required.
MessageInsert = TypedDict(
    "MessageInsert",
    {
        "message_id": str,
        "content": str,
        "metadata": MessageInsertMetadata,
        "embeddings": List[Any],
    },
)


class BaseModel(PydanticBaseModel):
    """
    Project-wide pydantic base class.

    Allows fields of arbitrary (non-pydantic) types, such as the chromadb
    client held by Checkpointer.
    """

    # pydantic v2 configuration; replaces the deprecated inner `class Config`.
    model_config = {"arbitrary_types_allowed": True}

class SummarizerResponse(BaseModel):
    """Successful result of a summarization request."""

    # The summary text produced by the summarizer model (required).
    summary: str = Field(...)


class Summarizer(BaseModel):
    """Asks a chat model to summarize a list of conversation messages."""

    # conversation messages to summarize, in order.
    messages: List[dict]
    # ID of the model that performs the summarization.
    model: str
    # system prompt instructing the model how to summarize.
    prompt: str = SUMMARIZER_PROMPT

    async def summarize(self) -> Optional[SummarizerResponse]:
        """
        Run the summarization request (non-streaming).

        Sends `self.prompt` as the system message, then the conversation
        messages, then a final user message asking for a summary.

        Returns:
            The summary, or None when the completion response carries no
            choices.
        """
        # use the configured prompt (previously the module constant was
        # hard-coded here and `self.prompt` was silently ignored).
        sys_message: Message = {"role": "system", "content": self.prompt}
        user_message: Message = {
            "role": "user",
            "content": "Make a detailed summary of the conversation up to this point."
        }

        messages = [sys_message] + self.messages + [user_message]

        request = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "keep_alive": "10s"
        }

        resp = await generate_chat_completions(request)
        choices = resp["choices"] if "choices" in resp else None
        if not choices:
            # missing, None, or empty choices: no summary available.
            return None

        content: str = choices[0]["message"]["content"]
        return SummarizerResponse(summary=content)


class Checkpoint(BaseModel):
    """A stored summary of a chat up to (and including) one message."""

    # chat id
    chat_id: str

    # the message ID this checkpoint was created from.
    message_id: str

    # index of the message in the message list the checkpoint was
    # created from. in the inlet function, we do not have access to
    # incoming message ids, so this index is used as a fallback to
    # decide where to drop old context when the message id cannot be
    # found.
    message_index: int = 0

    # the "slug", or chain of messages, that led to this point.
    slug: str

    # actual summary of messages.
    summary: str

    @classmethod
    def from_json(cls, obj: dict) -> Optional["Checkpoint"]:
        """
        Build a checkpoint from its decoded JSON form.

        `message_index` is optional and defaults to 0, matching the field
        default.

        Returns:
            The checkpoint, or None when `obj` is not a mapping, lacks a
            required key, or holds values that fail validation.
        """
        try:
            return cls(
                chat_id=obj["chat_id"],
                message_id=obj["message_id"],
                message_index=obj.get("message_index", 0),
                slug=obj["slug"],
                summary=obj["summary"]
            )
        except (KeyError, TypeError, AttributeError, ValueError):
            # ValueError covers pydantic's ValidationError.
            return None

    def to_json(self) -> str:
        """Serialize this checkpoint to a JSON string."""
        return self.model_dump_json()


class Checkpointer(BaseModel):
    """
    Manages summary checkpoints in a single chat.

    All chats share one ChromaDB collection, which holds:
      - one document per checkpoint, whose ID is the checkpoint slug and
        whose metadata is {"chat_id": ..., "type": "checkpoint"};
      - one state document per chat, whose ID is the chat ID and whose
        content is {"current_checkpoint": <slug>}.
    """
    chat_id: str
    summarizer_model: str = ""
    chroma_client: chromadb.ClientAPI
    messages: List[dict] = Field(default_factory=list)  # stripped set of messages
    full_messages: List[dict] = Field(default_factory=list)  # all the messages
    # embeddings are not used for lookups yet; the default maps every
    # input to 0.
    embedding_func: EmbeddingFunc = (lambda a: 0)

    collection_name: ClassVar[str] = "chat_checkpoints"

    def _get_collection(self) -> ChromaCollection:
        """Return the shared checkpoint collection, creating it if needed."""
        return self.chroma_client.get_or_create_collection(
            name=Checkpointer.collection_name
        )

    def _insert_checkpoint(self, checkpoint: Checkpoint):
        """Store `checkpoint` and make it this chat's current checkpoint."""
        coll = self._get_collection()
        checkpoint_doc = checkpoint.to_json()
        # Insert the checkpoint itself with slug as ID.
        coll.upsert(
            ids=[checkpoint.slug],
            documents=[checkpoint_doc],
            metadatas=[{ "chat_id": self.chat_id, "type": "checkpoint" }],
            embeddings=[self.embedding_func(checkpoint_doc)]
        )

        # Update the chat info doc for this chat.
        coll.upsert(
            ids=[self.chat_id],
            documents=[json.dumps({ "current_checkpoint": checkpoint.slug })],
            embeddings=[self.embedding_func(self.chat_id)]
        )

    def _calculate_slug(self) -> Optional[str]:
        """
        Hash the IDs of `self.messages` (newest first) into a slug.

        Returns None when there are no messages.
        """
        if not self.messages:
            return None

        message_ids = [msg["id"] for msg in reversed(self.messages)]
        return hashlib.sha256("|".join(message_ids).encode()).hexdigest()

    def _get_state(self) -> dict:
        """Load this chat's state document ({"current_checkpoint": None} if absent)."""
        resp = self._get_collection().get(ids=[self.chat_id], include=["documents"])
        state: dict = (json.loads(resp["documents"][0])
                       if resp["documents"] and len(resp["documents"]) > 0
                       else { "current_checkpoint": None })
        return state

    def _find_message_index(self, message_id: str) -> Optional[int]:
        """Return the position of `message_id` in `self.full_messages`, or None."""
        for idx, message in enumerate(self.full_messages):
            # messages without an id simply never match.
            if message.get("id") == message_id:
                return idx
        return None

    def nuke_checkpoints(self):
        """Delete all checkpoints for this chat, plus its state document."""
        coll = self._get_collection()

        checkpoints = coll.get(
            include=["documents"],
            where={"chat_id": self.chat_id}
        )

        coll.delete(
            ids=[self.chat_id] + checkpoints["ids"]
        )

    async def create_checkpoint(self) -> Optional[str]:
        """
        Summarize `self.messages` and store the result as the current checkpoint.

        Returns:
            The new checkpoint's slug, or None when there are no messages
            or the summarizer produced no summary.

        Raises:
            ValueError: if the last summarized message cannot be found in
                `self.full_messages`, so no message index can be recorded.
        """
        if not self.messages:
            return None

        summarizer = Summarizer(model=self.summarizer_model, messages=self.messages)
        resp = await summarizer.summarize()
        if not resp:
            return None

        slug = self._calculate_slug()
        checkpoint_message = self.messages[-1]
        checkpoint_index = self._find_message_index(checkpoint_message["id"])
        if checkpoint_index is None:
            # previously this surfaced as an opaque validation error on
            # the int `message_index` field.
            raise ValueError(
                f"Checkpoint message {checkpoint_message['id']} "
                f"not found in full message list."
            )

        checkpoint = Checkpoint(
            chat_id = self.chat_id,
            slug = slug,
            message_id = checkpoint_message["id"],
            message_index = checkpoint_index,
            summary = resp.summary
        )

        self._insert_checkpoint(checkpoint)
        return slug

    def get_checkpoint(self, slug: Optional[str]) -> Optional[Checkpoint]:
        """Load the checkpoint stored under `slug`, or None if absent/invalid."""
        if not slug:
            return None

        resp = self._get_collection().get(ids=[slug], include=["documents"])
        checkpoint = (resp["documents"][0]
                      if resp["documents"] and len(resp["documents"]) > 0
                      else None)

        if checkpoint:
            return Checkpoint.from_json(json.loads(checkpoint))
        else:
            return None

    def get_current_checkpoint(self) -> Optional[Checkpoint]:
        """Load the checkpoint this chat's state document points to, if any."""
        state = self._get_state()
        return self.get_checkpoint(state.get("current_checkpoint"))


#########################
# Utilities
#########################

class SessionInfo(BaseModel):
    """Chat, message and session identifiers of the current filter call."""

    # All three identifiers are required.
    chat_id: str = Field(...)
    message_id: str = Field(...)
    session_id: str = Field(...)


def extract_session_info(event_emitter) -> Optional["SessionInfo"]:
    """
    The latest innovation in hacky workarounds.

    Reads the chat/message/session IDs out of the first closure cell of
    the event emitter callable, which is expected to hold a dict with
    "chat_id", "message_id" and "session_id" keys (an implementation
    detail of the host application; may break on upgrades).

    Returns:
        The session info, or None if the emitter does not have that
        shape (e.g. it is None, has no closure, or the dict lacks a key).
    """
    try:
        info = event_emitter.__closure__[0].cell_contents
        return SessionInfo(
            chat_id=info["chat_id"],
            message_id=info["message_id"],
            session_id=info["session_id"]
        )
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Best-effort only. ValueError covers empty closure cells and
        # pydantic validation errors; unlike the former bare `except:`,
        # task cancellation and KeyboardInterrupt are no longer swallowed.
        return None


def predicted_token_use(messages) -> int:
    """
    Estimate the token use of the most recent message.

    `messages` may hold chat message dicts (their "content" is used) or
    plain strings. Returns 0 for an empty list.

    Naive assumptions:
     - 1 word = 1 token.
     - 1 period, comma, colon, or semicolon = 1 token.
    """
    # was `len(self.messages == 0)`: a NameError (no `self` here).
    if not messages:
        return 0

    message = messages[-1]
    content = (message.get("content") or "") if isinstance(message, dict) else message
    if not isinstance(content, str):
        content = str(content)

    # Whitespace separates tokens; the capture groups keep punctuation as
    # tokens of its own. Unmatched groups (None) and empty strings between
    # adjacent separators are not counted.
    pieces = re.split(r"\s|(;)|(,)|(\.)|(:)|\n", content)
    return sum(1 for piece in pieces if piece)

def is_big_convo(messages, num_ctx: int=8192) -> bool:
    """
    Detect a large pre-existing conversation.

    A message's token use is its info["eval_count"] plus
    info["prompt_eval_count"] (a missing "info" or missing count counts
    as 0). The conversation is "big" if ANY message reaches `num_ctx`.
    Position does not matter: a message in the middle that reached the
    limit means a context shift already happened.
    """
    def tokens_of(msg) -> int:
        if "info" not in msg:
            return 0
        stats = msg["info"]
        generated = stats["eval_count"] if "eval_count" in stats else 0
        prompted = stats["prompt_eval_count"] if "prompt_eval_count" in stats else 0
        return generated + prompted

    return any(tokens_of(msg) >= num_ctx for msg in messages)


def hit_context_limit(
        messages,
        num_ctx: int=8192,
        wiggle_room: int=1000
) -> Tuple[bool, int]:
    """
    Estimate whether the context limit has been (nearly) reached.

    Only the newest message is inspected. Its info["eval_count"] plus
    info["prompt_eval_count"] (missing values count as 0) is compared
    with `num_ctx - wiggle_room`; the wiggle room catches contexts that
    are almost full, since summarizing happens on output rather than
    before input (the inlet function doesn't have enough info).

    Returns:
        (hit, amount_over): `hit` is True once usage reaches the
        threshold; `amount_over` is how far usage exceeds `num_ctx`
        itself, clamped at 0, and is always 0 when `hit` is False.
    """
    if not messages:
        return False, 0

    newest = messages[-1]
    stats = newest["info"] if "info" in newest else {}
    used = ((stats["eval_count"] if "eval_count" in stats else 0)
            + (stats["prompt_eval_count"] if "prompt_eval_count" in stats else 0))

    if used < num_ctx - wiggle_room:
        return False, 0

    return True, max(used - num_ctx, 0)

def extract_base_model_id(model: dict) -> Optional[str]:
    """
    Return the ID of the base model that a model definition builds on.

    Reads model["info"]["base_model_id"]. When that key is present but
    empty, the model is itself a base model and its own "id" is returned.

    Returns None when the base model cannot be determined: `model` is not
    a dict (e.g. a bare model-ID string), it has no "info" dict, or "info"
    has no "base_model_id" key.
    """
    if not isinstance(model, dict):
        return None

    info = model.get("info")
    if not isinstance(info, dict) or "base_model_id" not in info:
        return None

    return info["base_model_id"] or model.get("id")

def extract_owu_model_param(model_obj: dict, param_name: str) -> Optional[Any]:
    """
    Extract a parameter value from the DB definition of a model
    that is based on another model.

    Reads model_obj["info"]["params"][param_name]. Returns None when any
    level is missing or None; plain models without an "info" entry no
    longer raise KeyError.
    """
    info = model_obj.get("info") or {}
    params = info.get("params") or {}
    return params.get(param_name, None)

def extract_owu_base_model_param(base_model_id: str, param_name: str):
    """
    Extract a parameter value from the DB definition of an ollama base model.

    Returns None when the model has no DB entry, has no params, or does
    not set `param_name`.
    """
    base_model = Models.get_model_by_id(base_model_id)
    if not base_model or base_model.params is None:
        return None

    # Dump to a local dict instead of overwriting `base_model.params`
    # (the previous code replaced the model's attribute as a side effect).
    params = base_model.params.model_dump()
    return params.get(param_name, None)

def extract_ollama_response_param(model: Optional[dict], param_name: str) -> Optional[str]:
    """
    Extract a parameter value from an ollama show API response.

    The response's "parameters" entry is parsed as lines of
    "<name> <value>". The name must match `param_name` exactly, so e.g.
    "mirostat" does not match a "mirostat_eta" line.

    Returns:
        The raw value text of the first matching line (stripped; "" if the
        line has no value), or None when `model` is empty/None, has no
        "parameters", or no line matches.
    """
    if not model or "parameters" not in model:
        return None

    for line in model["parameters"].splitlines():
        # Split off the key instead of `lstrip(param_name)`, which strips
        # a *set of characters* rather than a prefix.
        parts = line.split(None, 1)
        if parts and parts[0] == param_name:
            return parts[1].strip() if len(parts) > 1 else ""

    return None

async def get_model_from_ollama(model_id: str, user_id) -> Optional[dict]:
    """
    Fetch model information through ollama's show API.

    The request is made on behalf of the user with `user_id`. Any error
    from the request is printed and turned into a None result.
    """
    requesting_user = Users.get_user_by_id(user_id)
    try:
        details = await show_model_info(ModelNameForm(name=model_id), user=requesting_user)
    except Exception as err:
        print(f"Could not get model info: {err}")
        return None
    return details

async def calculate_num_ctx(
        chat_id: str, user_id, model: dict, default_num_ctx: int = 2048
) -> int:
    """
    Attempt to discover the current num_ctx parameter in many
    different ways, in this order:

    1. the OpenWebUI chat's own parameters,
    2. the OpenWebUI model definition,
    3. the OpenWebUI definition of its base model,
    4. the base model's parameters as reported by ollama.

    Values that are missing, non-numeric, or not positive are skipped.
    Always returns an int; the ollama lookup yields text, which is
    converted. Falls back to `default_num_ctx` when nothing usable is found.
    """
    def as_num_ctx(value) -> Optional[int]:
        # Normalize a candidate value to a positive int, or None if unusable.
        try:
            num = int(value)
        except (TypeError, ValueError):
            return None
        return num if num > 0 else None

    # first check the open-webui chat parameters.
    chat = Chats.get_chat_by_id_and_user_id(chat_id, user_id)
    if chat:
        # this might look odd, but the chat field is a json blob of
        # useful info.
        chat_data = json.loads(chat.chat)
        num_ctx = as_num_ctx((chat_data.get("params") or {}).get("num_ctx"))
        if num_ctx:
            return num_ctx

    # then check open web ui model def
    num_ctx = as_num_ctx(extract_owu_model_param(model, "num_ctx"))
    if num_ctx:
        return num_ctx

    # then check open web ui base model def.
    base_model_id = extract_base_model_id(model)
    if not base_model_id:
        # fall back to default in case of weirdness.
        return default_num_ctx

    num_ctx = as_num_ctx(extract_owu_base_model_param(base_model_id, "num_ctx"))
    if num_ctx:
        return num_ctx

    # THEN check ollama directly (None if the lookup failed).
    base_model = await get_model_from_ollama(base_model_id, user_id)
    if base_model:
        num_ctx = as_num_ctx(extract_ollama_response_param(base_model, "num_ctx"))
        if num_ctx:
            return num_ctx

    # finally, return default.
    return default_num_ctx



class Filter:
    """
    OpenWebUI filter that manages context use by summarizing chats.

    outlet(): after a response, checks whether the newest message's token
    counts come within `wiggle_room` of num_ctx and, if so, summarizes the
    messages since the last checkpoint into a new checkpoint.

    inlet(): before a request, applies the current checkpoint. Its summary
    is appended to (or becomes) the system prompt, and messages before the
    checkpointed message are dropped. A user message of "!nuke" deletes
    all checkpoints for the chat instead.
    """

    class Valves(BaseModel):
        def summarizer_model(self, body: dict, model: Optional[dict] = None) -> Optional[str]:
            """
            Resolve the ID of the model used for summarization.

            Returns the `summarizer_model_id` valve when it is set.
            Otherwise falls back to the chat's own model: its base model
            ID when the model definition records one, else the model's ID.

            Args:
                body: request body; `body["model"]` is used when no model
                    definition is given.
                model: full model definition (e.g. `__model__`). Preferred
                    over `body["model"]`, which may be only a model-ID
                    string (assumption — verify against the OpenWebUI
                    version in use).
            """
            if self.summarizer_model_id:
                return self.summarizer_model_id

            candidate = model if isinstance(model, dict) else body.get("model")
            if not isinstance(candidate, dict):
                # presumably a bare model-ID string (or None); use as-is.
                return candidate

            info = candidate.get("info")
            if isinstance(info, dict) and "base_model_id" in info:
                return extract_base_model_id(candidate)
            return candidate.get("id")

        summarize_large_contexts: bool = Field(
            default=False,
            description=(
                "Whether or not to use a large context model to summarize large "
                "pre-existing conversations."
            )
        )
        wiggle_room: int = Field(
            default=1000,
            description=(
                "Amount of token 'wiggle room' for estimating when a context shift occurs. "
                "Subtracted from num_ctx when checking if summarization is needed."
            )
        )
        summarizer_model_id: str = Field(
            default="",
            description="Model used to summarize the conversation. Must be a base model.",
        )
        large_summarizer_model_id: str = Field(
            default="",
            description=(
                "Model used to summarize large pre-existing contexts. "
                "Must be a base model with a context size large enough "
                "to fit the conversation."
            )
        )

    class UserValves(BaseModel):
        pass

    def __init__(self):
        self.valves = self.Valves()
        # Per-request state, set at the start of inlet()/outlet().
        self.user = None
        self.model = None
        self.session_info = None
        self.event_emitter = None
        self.summarizer_model_id = None

    def load_current_chat(self) -> dict:
        """Load this chat's JSON blob from the OpenWebUI DB ({} if not found)."""
        chat = Chats.get_chat_by_id_and_user_id(self.session_info.chat_id, self.user["id"])
        if not chat:
            return {}

        # the chat property of the model is the json blob that holds
        # all the interesting stuff
        return json.loads(chat.chat)

    def get_messages_for_checkpointing(self, messages, num_ctx, last_checkpointed_id):
        """
        Assemble list of messages to checkpoint, based on current
        state and valve settings.

        Takes every message after the one with id `last_checkpointed_id`
        (all messages when it is None or not found). If the conversation
        is big (see is_big_convo) and large-context summarization is
        disabled, only the last 4 of those messages are kept.
        """
        message_chain = []
        for message in reversed(messages):
            if last_checkpointed_id is not None and message.get("id") == last_checkpointed_id:
                break
            message_chain.append(message)
        message_chain.reverse()

        # now we check if we are a big conversation, and if valve
        # settings allow that kind of summarization.
        if is_big_convo(messages, num_ctx) and not self.valves.summarize_large_contexts:
            # must summarize using small model. for now, drop to last
            # N messages.
            print((
                "Dropping all but last 4 messages to summarize "
                "large convo without large model."
            ))
            message_chain = message_chain[-4:]

        return message_chain

    async def create_checkpoint(
            self,
            messages: List[dict],
            last_checkpointed_id: Optional[str]=None,
            num_ctx: int=8192
    ):
        """
        Summarize the messages since the last checkpoint into a new one.

        Progress and errors are reported through the status emitter;
        exceptions from summarization are reported, not raised.
        """
        if len(messages) == 0:
            return

        # Validate before announcing anything, so the status is never left
        # stuck on "summarizing (do not reply yet)".
        last_message = messages[-1] # should check for role = assistant
        curr_message_id: Optional[str] = last_message.get("id") if last_message else None
        if not curr_message_id:
            return

        # strip messages down to what is in the current checkpoint.
        message_chain = self.get_messages_for_checkpointing(
            messages, num_ctx, last_checkpointed_id
        )
        if not message_chain:
            # nothing new since the last checkpoint.
            return

        print(f"[{self.session_info.chat_id}] Detected context shift. Summarizing.")
        await self.set_summarizing_status(done=False)

        # prefer the model resolved by inlet/outlet (valve, or the chat's
        # base model when the valve is empty).
        summarizer_model = self.summarizer_model_id or self.valves.summarizer_model_id
        if is_big_convo(message_chain, num_ctx) and self.valves.summarize_large_contexts:
            print(f"[{self.session_info.chat_id}] Summarizing LARGE context!")
            summarizer_model = self.valves.large_summarizer_model_id or summarizer_model

        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            summarizer_model=summarizer_model,
            chroma_client=CHROMA_CLIENT,
            full_messages=messages,
            messages=message_chain
        )

        try:
            slug = await checkpointer.create_checkpoint()
        except Exception as e:
            print(f"[{self.session_info.chat_id}] Error creating summary: {str(e)}")
            await self.set_summarizing_status(
                done=True, message=f"Error summarizing: {str(e)}"
            )
            return

        if slug is None:
            # the summarizer returned nothing; do not report success.
            print(f"[{self.session_info.chat_id}] Summarizer returned no summary.")
            await self.set_summarizing_status(
                done=True, message="Error summarizing: no summary was returned."
            )
            return

        await self.set_summarizing_status(done=True)
        print((f"[{self.session_info.chat_id}] Summarization checkpoint created: "
               f"{slug}"))

    def update_chat_with_checkpoint(self, messages: List[dict], checkpoint: Checkpoint):
        """
        Apply `checkpoint` to an outgoing message list.

        Returns a new list: one system message (the existing system
        prompt with the summary appended, or the summary alone), followed
        by the non-system messages from the checkpoint onward. If the
        recorded message index no longer exists, `messages` is returned
        unchanged.
        """
        convo_messages = [
            message for message in messages if message.get("role") != "system"
        ]

        if len(convo_messages) < checkpoint.message_index:
            # do not mess with anything if the index doesn't even
            # exist anymore. need a new checkpoint.
            return messages

        # keep system prompt, if it's there, and add summary to it.
        # summary will become system prompt if there is no system prompt.
        system_prompt = next(
            (message for message in messages if message.get("role") == "system"), None
        )

        summary_message = f"Summary of conversation so far:\n\n{checkpoint.summary}"

        if system_prompt:
            # copy instead of mutating the caller's message dict.
            system_prompt = {
                **system_prompt,
                "content": f"{system_prompt['content']}\n\n{summary_message}",
            }
        else:
            system_prompt = { "role": "system", "content": summary_message }

        # drop old messages from the conversation only, so the original
        # system prompt cannot be kept a second time.
        convo_messages = self.apply_checkpoint(checkpoint, convo_messages)
        print(f"[{self.session_info.chat_id}] Applying summary:\n\n{checkpoint.summary}")
        return [system_prompt] + convo_messages

    async def send_message(self, message: str):
        """Show `message` as a completed status in the UI, if an emitter exists."""
        if not self.event_emitter:
            return

        await self.event_emitter({
            "type": "status",
            "data": {
                "description": message,
                "done": True,
            },
        })

    async def set_summarizing_status(self, done: bool, message: Optional[str]=None):
        """Emit the summarization status; `message` overrides the default text."""
        if not self.event_emitter:
            return

        if not done:
            description = (
                "Summarizing conversation due to reaching context limit (do not reply yet)."
            )
        else:
            description = (
                "Summarization complete (you may now reply)."
            )

        if message:
            description = message

        await self.event_emitter({
            "type": "status",
            "data": {
                "description": description,
                "done": done,
            },
        })

    def apply_checkpoint(
            self, checkpoint: Checkpoint, messages: List[dict]
    ) -> List[dict]:
        """
        Possibly shorten the message context based on a checkpoint.
        This works two ways: if the messages have IDs (outlet
        filter), split by message ID (very reliable). Otherwise,
        attempt to split on the recorded message index (inlet
        filter; not very reliable). The checkpointed message itself is
        kept.
        """

        # first attempt to drop everything before the checkpointed
        # message id. None (not 0) means "not found", so a match at
        # index 0 is no longer mistaken for a miss.
        split_point = next(
            (idx for idx, message in enumerate(messages)
             if message.get("id") == checkpoint.message_id),
            None,
        )

        # if we can't find the ID to split on, fall back to message
        # index if possible. this can happen during message
        # regeneration, for example. or if we're called from the inlet
        # filter, which doesn't have access to message ids.
        if split_point is None:
            index = max(checkpoint.message_index, 0)
            split_point = index if index <= len(messages) else 0

        orig = len(messages)
        messages = messages[split_point:]
        print((f"[{self.session_info.chat_id}] Dropped context to {len(messages)} "
               f"messages (from {orig})"))
        return messages

    async def handle_nuke(self, body):
        """Delete all checkpoints for the chat and ask the model to confirm it."""
        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            chroma_client=CHROMA_CLIENT
        )
        checkpointer.nuke_checkpoints()
        await self.send_message("Deleted all checkpoints for chat.")

        body["messages"][-1]["content"] = (
            "Respond only with: 'Deleted all checkpoints for chat.'"
        )

        body["messages"] = body["messages"][-1:]
        return body

    async def outlet(
        self,
        body: dict,
        __user__: Optional[dict],
        __model__: Optional[dict],
        __event_emitter__: Callable[[Any], Awaitable[None]],
    ) -> dict:
        """After a response: create a new checkpoint if the context is nearly full."""
        # Useful things to have around.
        self.user = __user__
        self.model = __model__
        self.session_info = extract_session_info(__event_emitter__)
        self.event_emitter = __event_emitter__

        # global filters apply to requests coming in through proxied
        # API. If we're not an OpenWebUI chat, abort mission.
        if not self.session_info or not self.user:
            return body

        if not self.model or self.model.get("owned_by") != "ollama":
            return body

        self.summarizer_model_id = self.valves.summarizer_model(body, self.model)
        messages = body["messages"]

        num_ctx = await calculate_num_ctx(
            chat_id=self.session_info.chat_id,
            user_id=self.user["id"],
            model=self.model
        )

        # apply current checkpoint ONLY for purposes of calculating if
        # we have hit num_ctx within current checkpoint.
        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            chroma_client=CHROMA_CLIENT
        )

        checkpoint = checkpointer.get_current_checkpoint()
        messages_for_ctx_check = (self.apply_checkpoint(checkpoint, messages)
                                  if checkpoint else messages)

        hit_limit, amount_over = hit_context_limit(
            messages=messages_for_ctx_check,
            num_ctx=num_ctx,
            wiggle_room=self.valves.wiggle_room
        )

        if hit_limit:
            # we need the FULL message list to do proper summarizing,
            # because we might be summarizing a huge context.
            await self.create_checkpoint(
                messages=messages,
                num_ctx=num_ctx,
                last_checkpointed_id=checkpoint.message_id if checkpoint else None
            )

        print(f"[{self.session_info.chat_id}] Done checking for summarization")
        return body

    async def inlet(
        self,
        body: dict,
        __user__: Optional[dict],
        __model__: Optional[dict],
        __event_emitter__: Callable[[Any], Awaitable[None]]
    ) -> dict:
        """Before a request: apply the current checkpoint, or handle "!nuke"."""
        # Useful properties to have around.
        self.user = __user__
        self.model = __model__
        self.session_info = extract_session_info(__event_emitter__)
        self.event_emitter = __event_emitter__

        # global filters apply to requests coming in through proxied
        # API. If we're not an OpenWebUI chat, abort mission.
        if not self.session_info:
            return body

        if not self.model or self.model.get("owned_by") != "ollama":
            return body

        self.summarizer_model_id = self.valves.summarizer_model(body, self.model)

        # super basic external command handling (delete checkpoints).
        user_msg = get_last_user_message(body["messages"])
        if user_msg and user_msg.strip() == "!nuke":
            return await self.handle_nuke(body)

        # apply current checkpoint to the chat: adds most recent
        # summary to system prompt, and drops all messages before the
        # checkpoint.
        checkpointer = Checkpointer(
            chat_id=self.session_info.chat_id,
            chroma_client=CHROMA_CLIENT
        )

        checkpoint = checkpointer.get_current_checkpoint()
        if checkpoint:
            print((
                f"Using checkpoint {checkpoint.slug} for "
                f"conversation {self.session_info.chat_id}"
            ))

            body["messages"] = self.update_chat_with_checkpoint(body["messages"], checkpoint)

        return body