Whitepaper
Docs
Sign In
Function
Function
filter
v0.1
Context Manager
Function ID
context_manager
Creator
@keke
Downloads
152+
Truncates the chat context to a token limit and a maximum number of turns, and logs data for each chat turn.
Get
README
No README available
Function Code
Show
"""
title: Context Manager
description: 1. Truncate chat context length with token limit and max turns, system message excluded. 2. Log the chat turn data, marked with a 'log_type' field.
author: Kejun Luo
version: 0.1
"""

import tiktoken
from pydantic import BaseModel, Field
from typing import Optional, Callable, Any, Awaitable
import time
import logging
import json

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Filter:
    """Chat filter that bounds the context sent to the model and logs each turn.

    ``inlet`` trims the conversation to at most ``max_turns`` turns and roughly
    ``token_limit`` tokens. Leading ``system`` messages are kept untouched and
    are not counted toward either limit. ``outlet`` counts the reply tokens and
    logs one JSON record (``"log_type": "chat_turn"``) per completed turn.

    NOTE(review): per-turn data (user, token counts, start time) is kept on the
    instance between ``inlet`` and ``outlet``. If the host reuses a single
    instance for concurrent requests, these values can interleave — confirm.
    """

    class Valves(BaseModel):
        # presumably the host's ordering priority among filters — confirm
        priority: int = Field(default=0, description="Priority level")
        max_turns: int = Field(
            default=25,
            description="Number of conversation turns to retain. Set '0' for unlimited",
        )
        token_limit: int = Field(
            default=10000,
            description="Number of token limit to retain. Set '0' for unlimited",
        )

    def __init__(self):
        self.valves = self.Valves()
        # Set by the truncate_* helpers when any message is dropped this turn.
        self.limit_exceeded = False
        self.encoding = tiktoken.get_encoding("o200k_base")
        self.input_tokens = 0
        self.output_tokens = 0
        self.user = None
        self.model_base = None
        self.model_name = None
        self.start_time = None
        self.elapsed_time = None
        self.input_message_count = None

    def log_chat_turn(self):
        """Log data for a single chat turn as one JSON line at INFO level.

        Nothing is logged unless the user, model id/name, token counts,
        elapsed time and input message count have all been recorded.
        """
        if all(
            [
                self.user,
                self.model_base,
                self.model_name,
                self.input_tokens is not None,
                self.output_tokens is not None,
                self.elapsed_time is not None,
                self.input_message_count is not None,
            ]
        ):
            log_data = {
                "log_type": "chat_turn",
                "user": self.user,
                "model_base": self.model_base,
                "model_name": self.model_name,
                "input_tokens": self.input_tokens,
                "output_tokens": self.output_tokens,
                "elapsed_seconds": round(self.elapsed_time, 0),
                "input_message_count": self.input_message_count,
            }
            logger.info(json.dumps(log_data))

    async def inlet(
        self,
        body: dict,
        __event_emitter__: Callable[[Any], Awaitable[None]],
        __model__: Optional[dict] = None,
        __user__: Optional[dict] = None,
    ) -> dict:
        """Truncate chat context length with token limit and max turns, system message excluded.

        Leading ``system`` messages are preserved as-is and do not count toward
        ``max_turns`` or ``token_limit``; only the conversation after them is
        truncated.

        Args:
            body: Request payload whose ``"messages"`` list is replaced with
                the truncated list.
            __event_emitter__: Used to emit a status event when truncation happened.
            __model__: Not used by this method.
            __user__: Caller info; its ``"email"`` (if present) is recorded for logging.

        Returns:
            The same ``body`` dict with ``"messages"`` updated.
        """
        messages = body["messages"]
        self.limit_exceeded = False

        # Split off the leading system prompt(s) so truncation never drops them.
        split = 0
        while split < len(messages) and messages[split].get("role") == "system":
            split += 1
        system_messages = messages[:split]
        chat_messages = messages[split:]

        chat_messages = self.truncate_turns(chat_messages)
        chat_messages = self.truncate_tokens(chat_messages)
        final_messages = system_messages + chat_messages

        await self.show_exceeded_status(__event_emitter__, len(final_messages))
        # __user__ is optional; a missing user just means the turn is not logged.
        self.init_log_data((__user__ or {}).get("email"), final_messages)
        body["messages"] = final_messages
        return body

    def truncate_turns(self, messages: list) -> list:
        """Keep only the last ``max_turns`` turns plus the trailing message.

        A turn is counted as two messages; the final (current) message is not
        part of the count, so at most ``2 * max_turns + 1`` messages are kept.
        ``max_turns <= 0`` disables this limit and returns ``messages`` unchanged.
        """
        max_turns = self.valves.max_turns
        if max_turns <= 0:
            return messages
        current_turns = (len(messages) - 1) // 2
        if current_turns <= max_turns:
            return messages
        self.limit_exceeded = True
        return messages[-(max_turns * 2 + 1):]

    def truncate_tokens(self, messages: list) -> list:
        """Keep the newest messages that fit within ``token_limit`` tokens.

        Messages are walked from newest to oldest. When adding a message would
        exceed the limit, walking stops at that message unless it is a user
        message; user messages are always kept so the retained context starts
        with one. ``token_limit <= 0`` disables this limit and returns
        ``messages`` unchanged.
        """
        limit = self.valves.token_limit
        if limit <= 0:
            return messages
        kept = []
        current_toks = 0
        for msg in reversed(messages):
            toks = self.count_text_tokens(msg)
            # the first message must be a user message, so a user message should not be truncated.
            if current_toks + toks > limit and msg.get("role", "") != "user":
                self.limit_exceeded = True
                break
            kept.append(msg)
            current_toks += toks
        # Collected newest-first; restore chronological order (avoids O(n^2) insert(0, ...)).
        kept.reverse()
        return kept

    def init_log_data(self, user_email: Optional[str], messages: list):
        """Record the user, start time, input token count and message count for this turn."""
        self.user = user_email
        self.start_time = time.time()
        self.input_tokens = sum(self.count_text_tokens(msg) for msg in messages)
        self.input_message_count = len(messages)

    def outlet(
        self,
        body: dict,
        __model__: Optional[dict] = None,
    ) -> dict:
        """Count reply tokens, record model info and elapsed time, then log the turn.

        Args:
            body: Response payload; the last entry of ``"messages"`` is treated
                as the reply.
            __model__: Model info; its ``"id"`` and ``"name"`` are logged.

        Returns:
            ``body`` unchanged.
        """
        messages = body.get("messages") or []
        self.output_tokens = self.count_text_tokens(messages[-1]) if messages else 0
        # A missing model leaves id/name as None, which makes log_chat_turn skip logging.
        model = __model__ or {}
        self.model_base = model.get("id")
        self.model_name = model.get("name")
        if self.start_time:
            self.elapsed_time = time.time() - self.start_time
        self.log_chat_turn()
        return body

    async def show_exceeded_status(
        self, __event_emitter__: Callable[[Any], Awaitable[None]], message_count: int
    ) -> None:
        """Emit a completed status event if any message was truncated this turn."""
        if self.limit_exceeded and __event_emitter__ is not None:
            await __event_emitter__(
                {
                    "type": "status",
                    "data": {
                        "description": f"Context limit reached - keeping last {message_count} messages.",
                        "done": True,
                    },
                }
            )

    def count_text_tokens(self, msg: dict) -> int:
        """Return the number of tokens in a message's text content.

        ``content`` may be a string, or a list of parts of which only
        ``{"type": "text"}`` parts are counted. Any other content type counts as 0.
        """
        content = msg.get("content", "")
        if isinstance(content, str):
            return self._count_tokens(content)
        if isinstance(content, list):
            # Multi-modal content: count only the text parts.
            return sum(
                self._count_tokens(item.get("text") or "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        return 0

    def _count_tokens(self, text: str) -> int:
        """Return the token count of ``text`` with the configured encoding."""
        # disallowed_special=() makes tiktoken encode strings such as
        # "<|endoftext|>" in user text as plain text instead of raising ValueError.
        return len(self.encoding.encode(text, disallowed_special=()))