We're Hiring!
Whitepaper
Docs
Sign In
Function
Function
filter
v6.3
Super Memory Refractor
Last Updated
3 months ago
Created
3 months ago
Function ID
super_memory_refractor
Creator
@brycewg
Downloads
105+
Get
Sponsored by Open WebUI Inc.
We are hiring!
Shape the way humanity engages with
intelligence
.
Description
1.自动分析存储记忆,提取用户对话中需要记忆的事实信息。 2.为每一条新存入的记忆,都打上一个精确到分钟的时间戳,让ai知道每段记忆记住的时间。 3.定期后台自动生成记忆摘要,将零碎记忆合并为一个完整的记忆模块 4.可以自动更新记忆,拥有查重功能,不会重复记忆已有的事实 5.全面的性能数据:完整展示每一次对话的首字时间显示、总耗时、AI的生成速度 (TPS) 和输出的Token数
README
Function Code
Show
""" title: 超级记忆助手 v6.3 description: 1.自动分析存储记忆,提取用户对话中需要记忆的事实信息。 2.为每一条新存入的记忆,都打上一个精确到分钟的时间戳,让ai知道每段记忆记住的时间。 3.定期后台自动生成记忆摘要,将零碎记忆合并为一个完整的记忆模块 4.可以自动更新记忆,拥有查重功能,不会重复记忆已有的事实 5.全面的性能数据:完整展示每一次对话的首字时间显示、总耗时、AI的生成速度 (TPS) 和输出的Token数 author: 南风 (二改Bryce) version: 6.3 required_open_webui_version: >= 0.5.0 """ # ==================== 导入必要的库 ==================== import json import asyncio import time import datetime import re import os import hashlib from typing import Optional, Callable, Awaitable, Any, List, Tuple import pytz import aiohttp from fastapi.requests import Request from pydantic import BaseModel, Field from open_webui.main import app as webui_app from open_webui.models.users import Users from open_webui.routers.memories import ( add_memory, AddMemoryForm, query_memory, QueryMemoryForm, delete_memory_by_id, ) from open_webui.utils.misc import get_last_assistant_message # ==================== 提示词库 ==================== FACT_EXTRACTION_PROMPT = """你正在帮助维护用户的“记忆”——就像一个个独立的“日记条目”。你将收到最近几条对话。你的任务是判断用户的【最新一条】消息中,有哪些细节值得作为“记忆”被长期保存。【核心指令】1. **只分析用户最新一条消息**:仅从用户的最新发言中识别新的或变更的个人信息。旧消息仅供理解上下文。2. **处理信息变更**:如果用户最新消息与旧信息冲突,只提取更新后的信息。3. **事实独立**:每条记忆都应是独立的“事实”。如果一句话包含多个信息点,请拆分。4. **提取有价值信息**:目标是捕捉任何有助于AI在未来提供更个性化服务的信息。5. **响应明确指令**:如果用户明确要求“记住”,必须提取该信息。6. **忽略短期信息**:不要记录临时或无意义的细节。7. **指定格式返回**:将结果以【JSON字符串数组】的格式返回。如无信息,【只】返回空数组(`[]`)。不要加任何解释。---### 【示例】**示例 1**_输入对话:_- user: ```我爱吃橘子```- assistant: ```太棒了!```- user: ```其实我讨厌吃橘子```_正确输出:_["用户讨厌吃橘子"]**示例 2**_输入对话:_- user: ```我是一名初级数据分析师。请记住我的重要汇报在3月15日。```_正确输出:_["用户是一名初级数据分析师", "用户在3月15日有一次重要汇报"]""" FACT_CONSOLIDATION_PROMPT = """你正在管理用户的"记忆"。你的任务是清理一个可能包含相关、重叠或冲突信息的记忆列表。 **【处理规则】** 1. 你将收到一个JSON格式的记忆列表,每条含"fact"和"created_at"时间戳。 2. 生成一个清理后的最终事实列表,确保: - **完全重复**的信息:只保留`created_at`最新的那一条 - **直接冲突**的信息:只保留`created_at`最新的那一条 - **并列可枚举**的相关属性(如家庭成员、爱好、技能等):合并为一条综合描述 - **部分相似但不冲突**的信息:两者都保留 3. 返回最终结果时,使用一个简单的【JSON字符串数组】格式。不要添加任何解释。 **【示例】** 示例1 - 并列家庭成员合并: _输入:_ `[{"fact": "用户有个妹妹", "created_at": 1635500000}, {"fact": "用户有个弟弟", "created_at": 1636000000}]` _正确输出:_ ["用户有一个妹妹和一个弟弟"] 示例2 - 冲突信息保留最新: _输入:_ `[{"fact": "用户最喜欢的颜色是青色", "created_at": 1635500000}, {"fact": "用户最喜欢的颜色是红色", "created_at": 1636000000}]` _正确输出:_ ["用户最喜欢的颜色是红色"] 示例3 - 相关爱好合并: _输入:_ `[{"fact": "用户喜欢打篮球", "created_at": 1635500000}, {"fact": "用户喜欢踢足球", "created_at": 1636000000}]` _正确输出:_ ["用户喜欢打篮球和踢足球"] 示例4 - 部分相似但不冲突保留: _输入:_ `[{"fact": "用户喜欢橘子", "created_at": 1635500000}, {"fact": "用户喜欢熟透的橘子", "created_at": 1636000000}]` _正确输出:_ ["用户喜欢橘子", "用户喜欢熟透的橘子"]""" MEMORY_SUMMARIZATION_PROMPT = """你是一个记忆摘要助手。你的任务是将关于一个用户的多条相关但零散的记忆,合并成一个简洁、全面、高质量的摘要。**【核心指令】**1. **整合信息**:捕获所有输入记忆中的关键信息点。2. **解决冲突**:如果信息有矛盾,优先采纳时间戳最新的信息。3. **消除冗余**:去除重复或无意义的细节。4. **保持精华**:保留用户的核心偏好、身份特征、重要目标和关系。5. **自然流畅**:最终输出的摘要应该像一段自然语言,而不是一个列表。6. **格式要求**:【必须】只返回一个单一段落的文本摘要。不要加任何解释或多余的文字。**【示例】**_输入的多条记忆:_- "2024年05月10日11点:用户喜欢喝咖啡"- "2024年05月12日15点:用户偏好美式咖啡"- "2024年05月20日09点:用户提到他每天早上都要喝一杯咖啡提神"_正确的输出摘要:_"用户是一个咖啡爱好者,每天早上习惯喝一杯美式咖啡来提神。"现在,请分析以下的相关记忆,并提供一个简洁、高质量的摘要。""" # ==================== 主类定义 ==================== class Filter: _background_tasks = set() # 全局共享 valves(便于后台任务读取最新配置) _global_valves = None # 内存缓存:存储已计算的嵌入向量 {cache_key: {"embedding": [...], "text": "原始文本", "timestamp": 时间戳}} _embedding_cache = {} # 缓存统计信息 _cache_stats = {"hits": 0, "misses": 0, "cleanups": 0} # 记忆整合任务状态跟踪 _last_summarization_status = {"timestamp": None, "result": None, "user_count": 0} # 按量摘要计数器 {user_id: count} _user_memory_counters = {} # 跟踪正在为哪些用户运行摘要任务,防止并发 {user_id} _summarization_running_for_user = set() # 保存原始的delete_memory_by_id函数引用(用于猴子补丁) _original_delete_memory_by_id = None # L2 缓存与后台任务状态跟踪 _reconciled_users = set() _warmed_users = set() class Valves(BaseModel): enabled: bool = Field( default=True, description="【总开关】控制整个超级记忆插件的启用或禁用。关闭后,所有记忆相关的功能将不会运行。", ) api_url: str = Field( default="https://api.openai.com/v1/chat/completions", description="【核心API】用于记忆提取、整合和摘要等自然语言处理任务的LLM API端点。请确保地址正确,且与下方指定的模型兼容。", ) api_key: str = Field( default="", description="【核心API密钥】访问上述LLM API所必需的密钥。请确保其有效并有足够额度。", ) model: str = Field( default="gpt-4o-mini", description="【核心模型】指定用于处理记忆(提取、整合、摘要)的语言模型。推荐使用能力强、速度快的模型,如gpt-4o-mini。", ) show_stats: bool = Field( default=True, description="【状态显示】在每次对话后,于WebUI底部状态栏显示详细的性能和记忆统计数据,如首字延迟、耗时、记忆总数等。", ) messages_to_consider: int = Field( default=6, description="【记忆提取范围】设定在提取新记忆时,需要分析的最新对话轮次数量。例如,设为6将分析最近的三轮用户-AI对话。适当增加可提供更丰富的上下文,但也会增加API成本。", ) timezone: str = Field( default="Asia/Shanghai", description="【时间戳时区】为所有新记忆添加时间戳时所使用的时区。请使用标准时区标识符(如 'Asia/Shanghai'),以确保记忆时间的准确性。", ) # 实时记忆整合阈值 consolidation_threshold: float = Field( default=0.75, description="【实时整合触发阈值】当插件提取到新信息时,会根据此相似度阈值(0.0-1.0)查找与之相关的旧记忆。较高的值会使查找更精确但可能错过相关性稍弱的记忆,较低的值反之。这些找到的记忆将与新信息一同被分析,以决定是更新还是合并。", ) # 语义去重阈值 semantic_dedup_threshold: float = Field( default=0.9, description="【语义去重判断阈值】在保存最终记忆前,用于判断两条信息是否在语义上重复的相似度阈值(0.0-1.0)。值越高,去重要求越严格(几乎完全相同才算重复)。建议设为较高值(如0.9)以避免错误地丢弃相似但不相同的信息。", ) # 最大记忆查询数量 max_memory_query_k: int = Field( default=1000, description="【记忆检索上限】在进行记忆相关操作(如后台摘要、相似度查询)时,从数据库中一次性检索的最大记忆条数。此设置用于控制单次操作的内存消耗和性能开销。", ) # 智能去重查询数量 intelligent_dedup_k: int = Field( default=5, description="【智能去重分析数量】在进行智能去重时,为每条新信息查找最相关的K条已有记忆,并将它们一同交由LLM判断是「新增」、「更新」还是「重复」。K值越大,上下文越丰富,判断可能越准,但API成本也越高。", ) # 按量摘要任务配置 enable_on_demand_summarization: bool = Field( default=True, description="【按量摘要-开关】启用或禁用按量记忆摘要功能。启用后,当某个用户的记忆累积到一定数量时,会自动触发一次后台摘要任务。", ) summarize_after_n_memories: int = Field( default=10, description="【按量摘要-触发阈值】设定在为某个用户累积了多少条新记忆后,应触发一次摘要任务。这取代了旧的定时器机制。", ) summarization_cluster_threshold: float = Field( default=0.65, description="【摘要-聚类阈值】在摘要任务中,用于将相似的旧记忆聚类成簇的相似度阈值(0.0-1.0)。值越高,要求记忆间的关联性越强才能被归为一类进行摘要。", ) summarization_min_cluster_size: int = Field( default=3, description="【摘要-最小簇大小】设定一个记忆簇(一组相关的记忆)至少需要包含多少条记忆,才能被触发摘要处理。这可以防止对少量相关信息进行不必要的摘要。", ) summarization_min_memory_age_days: int = Field( default=1, description="【摘要-记忆最小年龄】设定记忆需要「陈旧」到多少天以上,才会被纳入摘要的处理范围。这可以防止对近期、可能仍在变化的记忆进行过早的固化处理。", ) # 嵌入模型配置 embedding_api_url: str = Field( default="https://api.openai.com/v1/embeddings", description="【嵌入API】当嵌入模式为'api'时,指定用于生成嵌入向量的API端点地址。", ) embedding_api_key: str = Field( default="", description="【嵌入API密钥】访问嵌入API所需的密钥。如果留空,将默认使用上方的【核心API密钥】。", ) embedding_model: str = Field( default="text-embedding-3-small", description="【嵌入模型名称】指定使用的嵌入模型名称。对于'local'模式,这通常是sentence-transformers的模型标识;对于'api'模式,这是服务提供方指定的模型名称(如 'text-embedding-3-small')。", ) enable_embedding_cache: bool = Field( default=True, description="【嵌入缓存-开关】启用或禁用嵌入向量的内存缓存。启用后,插件会将计算过的文本嵌入向量存储在内存中,显著提升后续操作的性能并减少不必要的重复计算或API调用。", ) embedding_cache_size_limit: int = Field( default=10000, description="【嵌入缓存-容量上限】设定内存中最多可以缓存多少条嵌入向量。当缓存数量超过此限制时,将自动清理最旧的条目以控制内存占用。", ) # L2 磁盘缓存配置 persist_embeddings_to_disk: bool = Field( default=True, description="【L2缓存-开关】将嵌入向量持久化到磁盘,以在重启后恢复,大幅减少API调用。", ) def __init__(self): self.valves = self.Valves() self.start_time = None # 改造一:重新加入用于计算首字时间的变量 self.time_to_first_token = None self.first_chunk_received = False # 应用猴子补丁 Filter._apply_monkey_patches() @classmethod def _apply_monkey_patches(cls): """应用猴子补丁,拦截delete_memory_by_id函数""" if cls._original_delete_memory_by_id is None: # 保存原始函数引用 cls._original_delete_memory_by_id = delete_memory_by_id # 替换模块中的函数 import open_webui.routers.memories open_webui.routers.memories.delete_memory_by_id = cls._delete_memory_with_cache_cleanup # 也替换当前模块的导入 import sys current_module = sys.modules[__name__] current_module.delete_memory_by_id = cls._delete_memory_with_cache_cleanup print("[INFO] Monkey patch applied: delete_memory_by_id now includes cache cleanup") @classmethod async def _delete_memory_with_cache_cleanup(cls, memory_id: str, user): """包装函数:在删除记忆前先清理对应的L1和L2缓存条目""" try: # 1. 在删除前,尝试获取记忆内容用于生成L1缓存键 content_to_cleanup = None try: # 使用现有的query_memory方法获取所有记忆,然后找到对应ID的记忆 dummy_request = cls._get_dummy_request_static() query_result = await query_memory( dummy_request, QueryMemoryForm(content=" ", k=10000), # 使用大数值确保获取所有记忆 user, ) if (query_result and hasattr(query_result, 'documents') and hasattr(query_result, 'ids') and query_result.documents and query_result.ids and len(query_result.documents) > 0 and len(query_result.ids) > 0): docs = query_result.documents[0] ids = query_result.ids[0] # 查找匹配的记忆ID for i, doc_id in enumerate(ids): if doc_id == memory_id and i < len(docs): content_to_cleanup = docs[i] break except Exception as e: print(f"[WARNING] Failed to get memory content before deletion for L1 cache cleanup: {e}") # 2. 如果找到了记忆内容,清理相关的L1缓存条目 if content_to_cleanup: cls._cleanup_cache_for_content(content_to_cleanup, user) # 3. 清理L2磁盘缓存(基于ID,更可靠) cls._cleanup_l2_cache_for_memory_id(memory_id, user) # 4. 调用原始的删除函数 result = await cls._original_delete_memory_by_id(memory_id, user) print(f"[DEBUG] Memory {memory_id} deleted with L1/L2 cache cleanup") return result except Exception as e: print(f"[ERROR] Error in delete_memory_with_cache_cleanup: {e}") # 如果包装函数出错,仍然尝试调用原始函数 return await cls._original_delete_memory_by_id(memory_id, user) @staticmethod def _get_dummy_request_static(): """静态方法版本的_get_dummy_request,用于类方法中""" from fastapi.requests import Request from open_webui.main import app as webui_app return Request(scope={"type": "http", "app": webui_app}) @classmethod def _cleanup_cache_for_content(cls, memory_content: str, user): """(遗留)根据记忆内容清理相关的缓存条目,主要用于L1。L2清理应基于ID。""" try: # 解析记忆内容,获取去除时间戳的版本 content_without_timestamp = cls._parse_memory_content_static(memory_content)[0] # 为不同的API嵌入模型配置生成可能的缓存键 embedding_models = [ "text-embedding-3-small", "text-embedding-ada-002", "text-embedding-3-large", Filter._global_valves.embedding_model if Filter._global_valves else "" ] removed_keys = [] for model in set(filter(None, embedding_models)): # 清理 L1 cache_key = cls._generate_cache_key_static(content_without_timestamp, model) if cache_key in cls._embedding_cache: del cls._embedding_cache[cache_key] removed_keys.append(cache_key[:8] + "...") if removed_keys: cls._cache_stats["cleanups"] += len(removed_keys) print(f"[DEBUG] Cleaned {len(removed_keys)} L1 cache entries for content") except Exception as e: print(f"[WARNING] Error during L1 cache cleanup for memory content: {e}") @classmethod def _cleanup_l2_cache_for_memory_id(cls, memory_id: str, user): """根据 memory_id 清理所有模型下的L2磁盘缓存""" v = cls._global_valves if not v or not v.persist_embeddings_to_disk: return try: base_dir = cls._get_base_cache_dir() # 遍历所有可能的模型目录进行清理 if not os.path.isdir(base_dir): return cleaned_count = 0 for model_dir_name in os.listdir(base_dir): user_mem_dir = os.path.join(base_dir, model_dir_name, "users", user.id, "mem") if os.path.isdir(user_mem_dir): file_path = os.path.join(user_mem_dir, f"{memory_id}.json") if os.path.exists(file_path): try: os.remove(file_path) cleaned_count += 1 print(f"[DEBUG] Deleted L2 cache file: {file_path}") except OSError as e: print(f"[ERROR] Failed to delete L2 cache file {file_path}: {e}") if cleaned_count > 0: print(f"[INFO] Cleaned {cleaned_count} L2 cache files for memory_id {memory_id}") except Exception as e: print(f"[ERROR] Error during L2 cache cleanup for memory_id {memory_id}: {e}") @staticmethod def _parse_memory_content_static(content: str) -> tuple: """静态方法版本的_parse_memory_content""" import re import datetime import pytz match = re.match(r"^(\d{4}年\d{2}月\d{2}日\d{2}点\d{2}分):(.*)", content) if match: try: # 先解析成朴素的datetime对象 dt_obj = datetime.datetime.strptime( match.group(1), "%Y年%m月%d日%H点%M分" ) # 使用默认时区 try: target_tz = pytz.timezone("Asia/Shanghai") except pytz.UnknownTimeZoneError: target_tz = pytz.utc # 将朴素的datetime对象本地化为具有时区信息的对象 dt_obj_with_tz = target_tz.localize(dt_obj) return match.group(2), dt_obj_with_tz.timestamp() except ValueError: return match.group(2), 0 return content, 0 @staticmethod def _generate_cache_key_static(text: str, embedding_model: str) -> str: """静态方法版本的_generate_cache_key""" import hashlib # 组合文本内容和嵌入模型 combined = f"{text}|{embedding_model}" # 生成MD5哈希作为缓存键 return hashlib.md5(combined.encode("utf-8")).hexdigest() def _valves(self): """后台循环读取的全局阀值(便于后续支持热更)""" return getattr(Filter, "_global_valves", self.valves) # (This space is intentionally left blank after removing the old methods) def inlet(self, body: dict, __user__: Optional[dict] = None) -> dict: self.start_time = time.time() # 重置首字时间状态 self.time_to_first_token = None self.first_chunk_received = False # 刷新全局阀值引用(便于后台循环读取较新的配置) Filter._global_valves = self.valves return body # 改造一(续):重新加入 stream 函数 def stream(self, event: dict) -> dict: if not self.first_chunk_received: self.time_to_first_token = time.time() - self.start_time self.first_chunk_received = True return event async def outlet( self, body: dict, __event_emitter__: Callable[[Any], Awaitable[None]], __user__: Optional[dict] = None, ): # 刷新全局阀值引用 Filter._global_valves = self.valves if not self.valves.enabled or not __user__ or len(body.get("messages", [])) < 2: return body user = Users.get_user_by_id(__user__["id"]) # 为该用户调度一次性的后台任务(预热与对账) self._schedule_user_background_tasks(user) # 记录对话响应结束时间(记忆处理开始前) conversation_end_time = time.time() try: memory_result = await self._handle_fact_mode(body, user) except Exception as e: print(f"CRITICAL ERROR in outlet: {e}") import traceback traceback.print_exc() memory_result = {"status": "error", "message": "插件执行出错"} stats_result = self._calculate_stats(body, conversation_end_time) if self.valves.show_stats: await self._show_status( __event_emitter__, memory_result, stats_result, user ) return body # ==================== L2缓存与后台任务 ==================== def _schedule_user_background_tasks(self, user): """为用户调度一次性的预热和对账任务""" v = self._valves() user_id = user.id # 调度对账任务 (固定开启) if v.persist_embeddings_to_disk and user_id not in self._reconciled_users: self._reconciled_users.add(user_id) print(f"[INFO] Scheduling L2 cache reconciliation for user {user_id}") task = asyncio.create_task(self._reconcile_disk_cache_for_user(user)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) # 调度预热任务 (使用 embedding_cache_size_limit 作为预热数量) if v.persist_embeddings_to_disk and v.embedding_cache_size_limit > 0 and user_id not in self._warmed_users: self._warmed_users.add(user_id) print(f"[INFO] Scheduling L2 cache warmup for user {user_id}") task = asyncio.create_task(self._warmup_embeddings_from_disk(user)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) async def _reconcile_disk_cache_for_user(self, user): """对账:清理磁盘上已不存在于DB中的记忆缓存(孤儿文件)""" v = self._valves() # 对账策略固定为开启,孤儿策略固定为删除 if not v.persist_embeddings_to_disk: return print(f"[INFO] Starting L2 cache reconciliation for user {user.id}...") start_time = time.time() try: # 1. 获取DB中所有记忆ID query_result = await query_memory( self._get_dummy_request(), QueryMemoryForm(content=" ", k=v.max_memory_query_k), user, ) db_ids = set(query_result.ids[0]) if query_result and query_result.ids else set() print(f"[DEBUG] Found {len(db_ids)} memories in DB for user {user.id}") # 2. 遍历磁盘缓存,检查孤儿文件 base_dir = self._get_base_cache_dir() if not os.path.isdir(base_dir): return deleted_count = 0 for model_dir_name in os.listdir(base_dir): user_mem_dir = os.path.join(base_dir, model_dir_name, "users", user.id, "mem") if not os.path.isdir(user_mem_dir): continue for filename in os.listdir(user_mem_dir): if filename.endswith(".json"): memory_id = filename[:-5] if memory_id not in db_ids: file_path = os.path.join(user_mem_dir, filename) try: os.remove(file_path) deleted_count += 1 print(f"[DEBUG] Deleted orphan L2 cache file: {file_path}") except OSError as e: print(f"[ERROR] Failed to delete orphan file {file_path}: {e}") elapsed = time.time() - start_time print(f"[INFO] Reconciliation for user {user.id} finished in {elapsed:.2f}s. Deleted {deleted_count} orphan files.") except Exception as e: print(f"[ERROR] Error during L2 cache reconciliation for user {user.id}: {e}") async def _warmup_embeddings_from_disk(self, user): """预热:从磁盘加载Top K记忆的嵌入向量到内存缓存""" v = self._valves() warmup_k = v.embedding_cache_size_limit if not v.persist_embeddings_to_disk or warmup_k <= 0: return print(f"[INFO] Starting L2 cache warmup for user {user.id} (top {warmup_k})...") start_time = time.time() try: # 1. 获取Top K记忆 query_result = await query_memory( self._get_dummy_request(), QueryMemoryForm(content=" ", k=warmup_k), user, ) if not (query_result and query_result.documents and query_result.ids): print("[DEBUG] No memories found for warmup.") return docs = query_result.documents[0] ids = query_result.ids[0] # 2. 准备批量加载 texts_to_warmup = [self._parse_memory_content(doc)[0] for doc in docs] # 3. 调用批量接口,只从磁盘加载 warmed_embeddings = await self._get_embeddings_batch( texts_to_warmup, ids=ids, user=user, disk_only=True ) warmed_count = sum(1 for emb in warmed_embeddings if emb) elapsed = time.time() - start_time print(f"[INFO] Warmup for user {user.id} finished in {elapsed:.2f}s. Warmed {warmed_count}/{len(docs)} embeddings into L1 cache.") except Exception as e: print(f"[ERROR] Error during L2 cache warmup for user {user.id}: {e}") # ==================== 摘要与记忆处理 ==================== async def _trigger_summarization_if_needed(self, user) -> bool: """检查并触发按量摘要任务""" v = self._valves() if not v.enable_on_demand_summarization: return False user_id = user.id # 增加计数器 current_count = self._user_memory_counters.get(user_id, 0) + 1 self._user_memory_counters[user_id] = current_count print(f"[DEBUG] User {user_id} new memory count: {current_count}") # 检查是否达到阈值 if current_count >= v.summarize_after_n_memories: if user_id in self._summarization_running_for_user: print(f"[INFO] Summarization for user {user_id} is already running. Skipping new trigger.") return False print(f"[INFO] Summarization threshold reached for user {user_id}. Triggering background task.") # 重置计数器 self._user_memory_counters[user_id] = 0 # 创建非阻塞的后台任务 task = asyncio.create_task(self._process_user_summarization(user)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) return True return False async def _process_user_summarization(self, user): """为单个用户执行一次完整的记忆摘要流程""" user_id = user.id if user_id in self._summarization_running_for_user: print(f"[INFO] Summarization for user {user_id} is already in progress.") return 0 self._summarization_running_for_user.add(user_id) try: v = self._valves() start_time = time.time() # 更新全局状态 self._last_summarization_status["timestamp"] = start_time self._last_summarization_status["result"] = f"进行中({user.name or user_id[:8]})" print(f"Checking memories for user {user_id}...") all_user_memories = await self._query_memories_by_similarity( " ", user, k=v.max_memory_query_k, threshold=0.0 ) print(f"User {user_id} has {len(all_user_memories)} total memories") min_age_seconds = v.summarization_min_memory_age_days * 24 * 3600 now_ts = time.time() eligible_memories = [ mem for mem in all_user_memories if (now_ts - self._parse_memory_content(mem["content"])[1]) > min_age_seconds ] print(f"User {user_id} has {len(eligible_memories)} eligible old memories (>{v.summarization_min_memory_age_days} days old)") if len(eligible_memories) < v.summarization_min_cluster_size: print(f"User {user_id} has not enough eligible old memories to summarize (need at least {v.summarization_min_cluster_size}).") self._last_summarization_status["result"] = "无需整合" return 0 eligible_ids = {mem["id"] for mem in eligible_memories} clusters, processed_ids = [], set() for mem in eligible_memories: if mem["id"] in processed_ids: continue current_cluster = [mem] processed_ids.add(mem["id"]) related_mems = await self._query_memories_by_similarity( mem["content_without_timestamp"], user, k=10, threshold=v.summarization_cluster_threshold ) for related in related_mems: if related["id"] not in processed_ids and related["id"] in eligible_ids: current_cluster.append(related) processed_ids.add(related["id"]) if len(current_cluster) >= v.summarization_min_cluster_size: clusters.append(current_cluster) if not clusters: print(f"No memory clusters found for user {user_id}.") self._last_summarization_status["result"] = "无需整合" return 0 print(f"Found {len(clusters)} clusters to summarize for user {user_id}.") successful_clusters = 0 for cluster in clusters: try: cluster_content = "\n".join([f"- {m['content']}" for m in cluster]) summary_content = await self._call_llm(MEMORY_SUMMARIZATION_PROMPT, cluster_content) if summary_content: content_with_timestamp = self._add_timestamp_to_content(summary_content) await add_memory( request=self._get_dummy_request(), form_data=AddMemoryForm(content=content_with_timestamp), user=user, ) for m in cluster: await delete_memory_by_id(m["id"], user) successful_clusters += 1 print(f"Successfully summarized {len(cluster)} memories into one for user {user_id}.") except Exception as e: print(f"Error summarizing a cluster for user {user_id}: {e}") # 更新最终状态 if successful_clusters > 0: self._last_summarization_status["result"] = f"整合{successful_clusters}组" else: self._last_summarization_status["result"] = "整合失败" return successful_clusters except Exception as e: print(f"Error in user summarization process for {user_id}: {e}") self._last_summarization_status["result"] = "出错" return 0 finally: # 无论成功失败,最后都从集合中移除用户ID,允许下次触发 self._summarization_running_for_user.discard(user_id) # ==================== 记忆处理主逻辑 ==================== async def _handle_fact_mode(self, body: dict, user) -> dict: conversation_text = self._stringify_conversation(body["messages"]) if not conversation_text: return {"status": "skipped", "message": "无消息"} try: new_facts = await self._call_llm_for_json( FACT_EXTRACTION_PROMPT, conversation_text ) if not new_facts: return {"status": "success", "message": "无新事实", "net_count_delta": 0} print(f"[DEBUG] 提取到 {len(new_facts)} 条新事实: {new_facts}") facts_to_consolidate = [] for fact in new_facts: # 使用与智能去重一致的放宽阈值,提升召回效果 search_threshold = self._clamp01( max(0.6, self.valves.consolidation_threshold - 0.05) ) print(f"[DEBUG] 为事实 '{fact}' 查找相关记忆,阈值: {search_threshold}") related_memories = await self._query_memories_by_similarity( fact, user, k=self.valves.intelligent_dedup_k, threshold=search_threshold, ) print(f"[DEBUG] 找到 {len(related_memories)} 条相关记忆用于上下文整合") consolidation_input = [{"fact": fact, "created_at": time.time()}] for mem in related_memories: content, ts = self._parse_memory_content(mem["content"]) consolidation_input.append({"fact": content, "created_at": ts}) print( f"[DEBUG] 添加相关记忆到整合上下文: ID={mem['id']}, 内容='{content}'" ) facts_to_consolidate.append(consolidation_input) final_facts_to_save = [] for fact_group in facts_to_consolidate: prompt_input_json = json.dumps(fact_group, ensure_ascii=False) cleaned_facts = await self._call_llm_for_json( FACT_CONSOLIDATION_PROMPT, prompt_input_json ) final_facts_to_save.extend(cleaned_facts) # 增强去重:使用语义相似度检查而不仅仅是字符串比较 unique_facts = await self._semantic_deduplication(final_facts_to_save, user) saved_count = 0 updated_count = 0 skipped_count = 0 net_count_delta = 0 # 用于显示层的净增量补偿 for fact in unique_facts: result = await self._store_memory_intelligent(fact, user) if result["success"]: if result["status"] == "update": updated_count += 1 # 新存1条 - 删除deleted_count条 net_count_delta += 1 - result.get("deleted_count", 0) elif result["status"] == "partial_update": # 有删除错误,但总量变化仍按实际删除数计算 net_count_delta += 1 - result.get("deleted_count", 0) elif result["status"] == "new": saved_count += 1 net_count_delta += 1 elif result["status"] == "duplicate": skipped_count += 1 # 重复不改变数量 # 优化的状态消息:显示最终结果和具体条数 if saved_count > 0 and updated_count > 0: final_message = f"新增{saved_count}条, 更新{updated_count}条" elif saved_count > 0: final_message = "新增记忆" if saved_count == 1 else f"新增{saved_count}条记忆" elif updated_count > 0: final_message = "更新记忆" if updated_count == 1 else f"更新{updated_count}条记忆" elif skipped_count > 0: final_message = "信息重复" else: final_message = "无新记忆" return { "status": "success", "message": final_message, "net_count_delta": net_count_delta, } except Exception as e: import traceback traceback.print_exc() return {"status": "error", "message": f"事实整合失败: {e}", "net_count_delta": 0} # 改造三:将时间戳格式精确到分钟 def _add_timestamp_to_content(self, content: str) -> str: try: target_tz = pytz.timezone(self.valves.timezone) except pytz.UnknownTimeZoneError: target_tz = pytz.utc now = datetime.datetime.now(target_tz) return f"{now.strftime('%Y年%m月%d日%H点%M分')}:{content}" async def _store_new_memory(self, content: str, user) -> bool: # 使用智能去重系统 result = await self._store_memory_intelligent(content, user) return result.get("success", False) async def _store_memory_intelligent(self, content: str, user) -> dict: """智能存储记忆,包含去重逻辑""" action, ids_to_delete = await self._intelligent_deduplication(content, user) if action == "skip": return { "success": False, "status": "duplicate", "message": "信息重复,已跳过", "deleted_count": 0, } # 删除要替换的旧记忆 deleted_count = 0 delete_errors = [] print(f"[DEBUG] 准备删除 {len(ids_to_delete)} 条旧记忆: {ids_to_delete}") for mem_id in ids_to_delete: try: print(f"[DEBUG] 正在删除记忆ID: {mem_id}") await delete_memory_by_id(mem_id, user) deleted_count += 1 print(f"[DEBUG] 成功删除记忆ID: {mem_id}") except Exception as e: error_msg = f"删除记忆失败 {mem_id}: {e}" print(f"[ERROR] {error_msg}") delete_errors.append(error_msg) # 存储新记忆 try: content_with_timestamp = self._add_timestamp_to_content(content) print(f"[DEBUG] 正在存储新记忆: {content_with_timestamp}") await add_memory( request=self._get_dummy_request(), form_data=AddMemoryForm(content=content_with_timestamp), user=user, ) print(f"[DEBUG] 成功存储新记忆") # 触发按量摘要检查 summarization_triggered = await self._trigger_summarization_if_needed(user) if delete_errors: return { "success": True, "status": "partial_update", "message": f"已存储新记忆,但删除过程有{len(delete_errors)}个错误", "deleted_count": deleted_count, "delete_errors": delete_errors, "summarization_triggered": summarization_triggered, } elif deleted_count > 0: return { "success": True, "status": "update", "message": f"已更新记忆(替换{deleted_count}条)", "deleted_count": deleted_count, "summarization_triggered": summarization_triggered, } else: return { "success": True, "status": "new", "message": "已存储新记忆", "deleted_count": 0, "summarization_triggered": summarization_triggered, } except Exception as e: error_msg = f"记忆存储失败: {e}" print(f"[ERROR] {error_msg}") return { "success": False, "status": "error", "message": error_msg, "deleted_count": deleted_count, "delete_errors": delete_errors, } @staticmethod def _clamp01(x: float) -> float: """将值限制在[0.0, 1.0]范围内""" return max(0.0, min(1.0, x)) async def _intelligent_deduplication( self, new_content: str, user ) -> Tuple[str, list]: """智能判断:重复、更新还是新信息 返回: (动作, 要删除的记忆ID列表) 动作: "store", "skip" """ # 查找相关记忆,使用比配置更低的阈值来捕获更多潜在相关记忆 search_threshold = self._clamp01( max(0.6, self.valves.consolidation_threshold - 0.05) ) related = await self._query_memories_by_similarity( new_content, user, k=self.valves.intelligent_dedup_k, threshold=search_threshold, ) if not related: return "store", [] # 全新信息 # 构建分析提示词 - 简化为三种情况 related_contents = [] for r in related: content_without_timestamp, _ = self._parse_memory_content(r["content"]) related_contents.append(content_without_timestamp) system_prompt = "你是一个严格的记忆关系判别器。请只输出一个单词:duplicate / update / new。不要包含解释、标点或任何额外字符。" user_prompt = f"""分析新信息与已有记忆的关系类型: 新信息: {new_content} 相关记忆: {chr(10).join([f"- {content}" for content in related_contents])} 请判断关系类型,只返回以下之一: - duplicate: 信息重复,无需存储 - update: 信息更新,应替换旧信息 - new: 新信息,直接存储""" try: relationship = await self._call_llm(system_prompt, user_prompt) relationship = relationship.strip().lower() # 鲁棒解析:容忍标点、前后缀等 label = "" if "duplicate" in relationship: label = "duplicate" elif "update" in relationship: label = "update" elif "new" in relationship: label = "new" if label == "update": # 删除相关的旧记忆,存储新信息 return "store", [mem["id"] for mem in related] elif label == "duplicate": return "skip", [] else: # "new" 或其他情况都当作新信息处理 return "store", [] except Exception as e: print(f"智能去重分析失败: {e},回退到简单去重") # 回退到简单相似度判断,使用更保守的阈值 fallback_threshold = self._clamp01( max(0.85, self.valves.consolidation_threshold + 0.1) ) if related and related[0]["similarity"] > fallback_threshold: return "skip", [] else: return "store", [] # 改造三(续):修正解析和剥离时间戳的逻辑 def _parse_memory_content(self, content: str) -> (str, float): match = re.match(r"^(\d{4}年\d{2}月\d{2}日\d{2}点\d{2}分):(.*)", content) if match: try: # 先解析成朴素的datetime对象 dt_obj = datetime.datetime.strptime( match.group(1), "%Y年%m月%d日%H点%M分" ) # 获取用户配置的时区 try: target_tz = pytz.timezone(self.valves.timezone) except pytz.UnknownTimeZoneError: target_tz = pytz.utc # 将朴素的datetime对象本地化为具有时区信息的对象 dt_obj_with_tz = target_tz.localize(dt_obj) return match.group(2), dt_obj_with_tz.timestamp() except ValueError: return match.group(2), 0 return content, 0 async def _query_memories_by_similarity( self, text: str, user, k: int, threshold: float ) -> List[dict]: return await self._query_memories_semantic(text, user, k, threshold) async def _query_memories_semantic( self, text: str, user, k: int = 1, threshold: float = 0.0 ) -> List[dict]: # 获取所有原始记忆 - 直接使用内置的query_memory函数 query_result = await query_memory( self._get_dummy_request(), QueryMemoryForm(content=" ", k=self.valves.max_memory_query_k), user, ) # 安全校验,避免 IndexError docs = getattr(query_result, "documents", None) if query_result else None metas = getattr(query_result, "metadatas", None) if query_result else None dists = getattr(query_result, "distances", None) if query_result else None ids = getattr(query_result, "ids", None) if query_result else None if ( not docs or len(docs) == 0 or not isinstance(docs[0], list) or len(docs[0]) == 0 ): return [] doc0 = docs[0] meta0 = metas[0] if metas and len(metas) > 0 else [{} for _ in doc0] dist0 = dists[0] if dists and len(dists) > 0 else [1.0 for _ in doc0] ids0 = ids[0] if ids and len(ids) > 0 else [None for _ in doc0] # 长度对齐,防止后续索引越界 if len(meta0) != len(doc0): meta0 = [meta0[i] if i < len(meta0) else {} for i in range(len(doc0))] if len(dist0) != len(doc0): dist0 = [dist0[i] if i < len(dist0) else 1.0 for i in range(len(doc0))] if len(ids0) != len(doc0): ids0 = [ids0[i] if i < len(ids0) else None for i in range(len(doc0))] # 转换为标准格式 all_memories_raw = [] for i, doc_content in enumerate(doc0): content_without_timestamp, _ = self._parse_memory_content(doc_content) all_memories_raw.append( { "id": ids0[i], "content": doc_content, "content_without_timestamp": content_without_timestamp, "similarity": 1 - dist0[i], } ) if not all_memories_raw: return [] # 获取查询文本的嵌入向量 query_embedding = await self._get_embedding(text) if not query_embedding: print( "Failed to get query embedding, skip semantic search for this run (return empty)." ) return [] # 批量获取所有记忆的嵌入向量以提高效率 memory_texts = [mem["content_without_timestamp"] for mem in all_memories_raw] memory_embeddings = await self._get_embeddings_batch(memory_texts, ids=ids0, user=user) if len(memory_embeddings) != len(all_memories_raw): print("Embedding count mismatch, processing individually") # 回退到逐个处理 results = [] for mem in all_memories_raw: mem_embedding = await self._get_embedding( mem["content_without_timestamp"] ) if mem_embedding: similarity = self._compute_cosine_similarity( query_embedding, mem_embedding ) if similarity >= threshold: mem["similarity"] = float(similarity) results.append(mem) else: # 批量处理相似度计算 results = [] for i, mem in enumerate(all_memories_raw): similarity = self._compute_cosine_similarity( query_embedding, memory_embeddings[i] ) if similarity >= threshold: mem["similarity"] = float(similarity) results.append(mem) # 按相似度排序并返回前k个结果 results.sort(key=lambda x: x["similarity"], reverse=True) return results[:k] async def _call_llm(self, system_prompt: str, user_prompt: str = "") -> str: messages = [{"role": "system", "content": system_prompt}] if user_prompt: messages.append({"role": "user", "content": user_prompt}) url = self.valves.api_url headers = { "Authorization": f"Bearer {self.valves.api_key}", "Content-Type": "application/json", } payload = {"model": self.valves.model, "messages": messages, "temperature": 0.0} async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=payload) as response: response.raise_for_status() data = await response.json() return data["choices"][0]["message"]["content"].strip() async def _get_embedding_via_api(self, texts: List[str]) -> List[List[float]]: """通过外部API获取文本嵌入向量""" api_key = self.valves.embedding_api_key or self.valves.api_key if not api_key: print("ERROR: No API key provided for embedding service") return [] url = self.valves.embedding_api_url headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } # 构建请求负载 payload = {"model": self.valves.embedding_model, "input": texts} try: async with aiohttp.ClientSession() as session: async with session.post(url, headers=headers, json=payload) as response: response.raise_for_status() data = await response.json() # 提取嵌入向量 embeddings = [] for item in data["data"]: embeddings.append(item["embedding"]) return embeddings except Exception as e: print(f"ERROR: Failed to get embeddings via API: {e}") return [] async def _get_embedding(self, text: str, user=None, memory_id: str = None) -> Optional[List[float]]: """获取嵌入向量,支持L1内存缓存和L2磁盘缓存""" v = self._valves() cache_key = self._generate_cache_key(text) # 1. L1内存缓存查询 if v.enable_embedding_cache: cached_embedding = self._get_cached_embedding(cache_key) if cached_embedding is not None: return cached_embedding # 2. L2磁盘缓存查询 if v.persist_embeddings_to_disk: disk_path = self._get_path_for_embedding(text, user, memory_id) embedding_data = self._read_embedding_from_disk(disk_path) if embedding_data: # L2命中,回填L1并返回 self._cache_embedding(cache_key, embedding_data['embedding'], text) return embedding_data['embedding'] # 3. 缓存未命中,通过API计算 embedding = None try: embeddings = await self._get_embedding_via_api([text]) if embeddings: embedding = embeddings[0] except Exception as e: print(f"API embedding error: {e}") return None if embedding is None: print("API embedding failed to return a vector.") return None # 4. 缓存结果到L1和L2 self._cache_embedding(cache_key, embedding, text) if v.persist_embeddings_to_disk: disk_path = self._get_path_for_embedding(text, user, memory_id) self._write_embedding_to_disk(disk_path, text, embedding) return embedding async def _get_embeddings_batch( self, texts: List[str], ids: List[str] = None, user=None, disk_only: bool = False ) -> List[List[float]]: """批量获取嵌入向量,支持L1/L2缓存,并能区分mem/adhoc路径""" v = self._valves() final_results = [[] for _ in texts] # 待处理项 uncached_indices = list(range(len(texts))) # 1. L1 内存缓存查询 if v.enable_embedding_cache: next_uncached = [] for i in uncached_indices: cache_key = self._generate_cache_key(texts[i]) cached = self._get_cached_embedding(cache_key) if cached: final_results[i] = cached else: next_uncached.append(i) uncached_indices = next_uncached if not uncached_indices: return final_results # 2. L2 磁盘缓存查询 if v.persist_embeddings_to_disk: next_uncached = [] for i in uncached_indices: memory_id = ids[i] if ids and i < len(ids) else None disk_path = self._get_path_for_embedding(texts[i], user, memory_id) embedding_data = self._read_embedding_from_disk(disk_path) if embedding_data: final_results[i] = embedding_data['embedding'] # 回填 L1 if v.enable_embedding_cache: cache_key = self._generate_cache_key(texts[i]) self._cache_embedding(cache_key, embedding_data['embedding'], texts[i]) else: next_uncached.append(i) uncached_indices = next_uncached if not uncached_indices or disk_only: return final_results # 3. API 批量计算 texts_to_fetch = [texts[i] for i in uncached_indices] new_embeddings = [] try: new_embeddings = await self._get_embedding_via_api(texts_to_fetch) if len(new_embeddings) != len(texts_to_fetch): print("API batch embedding returned incomplete results.") new_embeddings = [] except Exception as e: print(f"API batch embedding error: {e}") new_embeddings = [] # 4. 缓存新结果到 L1 和 L2 if new_embeddings: for i, embedding in enumerate(new_embeddings): original_index = uncached_indices[i] text = texts[original_index] final_results[original_index] = embedding # L1 缓存 if v.enable_embedding_cache: cache_key = self._generate_cache_key(text) self._cache_embedding(cache_key, embedding, text) # L2 缓存 if v.persist_embeddings_to_disk: memory_id = ids[original_index] if ids and original_index < len(ids) else None disk_path = self._get_path_for_embedding(text, user, memory_id) self._write_embedding_to_disk(disk_path, text, embedding) return final_results def _compute_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: """计算余弦相似度""" try: import numpy as np vec1 = np.array(vec1) vec2 = np.array(vec2) denom = (np.linalg.norm(vec1) * np.linalg.norm(vec2)) if denom == 0: return 0.0 return float(np.dot(vec1, vec2) / denom) except Exception as e: print(f"Error computing similarity: {e}") return 0.0 def _generate_cache_key(self, text: str) -> str: """生成缓存键,包含文本内容和嵌入模型""" import hashlib # 组合文本内容和嵌入模型 config_str = self.valves.embedding_model combined = f"{text}|{config_str}" # 生成MD5哈希作为缓存键 return hashlib.md5(combined.encode("utf-8")).hexdigest() @classmethod def _get_cached_embedding( cls, cache_key: str, enable_cache: bool = True ) -> Optional[List[float]]: """从L1内存缓存获取嵌入向量""" if not enable_cache: return None if cache_key in cls._embedding_cache: cls._cache_stats["hits"] += 1 # 更新时间戳,用于LRU清理 cls._embedding_cache[cache_key]['timestamp'] = time.time() return cls._embedding_cache[cache_key]["embedding"].copy() else: cls._cache_stats["misses"] += 1 return None @classmethod def _cache_embedding( cls, cache_key: str, embedding: List[float], text: str, enable_cache: bool = True, size_limit: int = 10000, ): """缓存嵌入向量到L1内存""" v = cls._global_valves enable_cache = v.enable_embedding_cache size_limit = v.embedding_cache_size_limit if not enable_cache: return # LRU 缓存清理 if len(cls._embedding_cache) >= size_limit: # 移除最老的条目 try: oldest_key = min(cls._embedding_cache, key=lambda k: cls._embedding_cache[k]['timestamp']) del cls._embedding_cache[oldest_key] cls._cache_stats["cleanups"] += 1 except ValueError: pass # 缓存为空 cls._embedding_cache[cache_key] = { "embedding": embedding, "text": text, "timestamp": time.time(), } @classmethod def get_cache_stats(cls) -> dict: """获取缓存统计信息""" total = cls._cache_stats["hits"] + cls._cache_stats["misses"] hit_rate = cls._cache_stats["hits"] / total if total > 0 else 0 return { "cache_size": len(cls._embedding_cache), "cache_hits": cls._cache_stats["hits"], "cache_misses": cls._cache_stats["misses"], "cache_cleanups": cls._cache_stats["cleanups"], "hit_rate": f"{hit_rate:.2%}", } @classmethod def clear_embedding_cache(cls): """清理所有缓存""" cls._embedding_cache.clear() cls._cache_stats = {"hits": 0, "misses": 0, "cleanups": 0} print("Embedding cache cleared") async def _semantic_deduplication(self, facts: List[str], user) -> List[str]: """使用语义相似度进行智能去重""" if not facts: return [] if len(facts) == 1: return facts # 使用本地缓存避免重复计算嵌入向量 emb_cache = {} async def get_emb(text: str): if text not in emb_cache: emb_cache[text] = await self._get_embedding(text) return emb_cache[text] unique_facts = [] for fact in facts: is_duplicate = False fact_embedding = await get_emb(fact) if not fact_embedding: unique_facts.append(fact) continue # 与已选择的唯一事实进行语义相似度比较 for unique_fact in unique_facts: unique_embedding = await get_emb(unique_fact) if unique_embedding: similarity = self._compute_cosine_similarity( fact_embedding, unique_embedding ) # 使用配置的阈值来判断重复 if similarity >= self.valves.semantic_dedup_threshold: is_duplicate = True print( f"发现重复事实,相似度{similarity:.3f}: '{fact}' ≈ '{unique_fact}'" ) break if not is_duplicate: unique_facts.append(fact) print(f"语义去重完成: {len(facts)} -> {len(unique_facts)}") return unique_facts async def _semantic_deduplication_with_status(self, facts: List[str], fact_status_map: dict, user) -> List[tuple]: """使用语义相似度进行智能去重,同时保留状态信息""" if not facts: return [] if len(facts) == 1: return [(facts[0], fact_status_map.get(facts[0], False))] # 使用本地缓存避免重复计算嵌入向量 emb_cache = {} async def get_emb(text: str): if text not in emb_cache: emb_cache[text] = await self._get_embedding(text) return emb_cache[text] unique_facts_with_status = [] for fact in facts: is_duplicate = False fact_embedding = await get_emb(fact) if not fact_embedding: unique_facts_with_status.append((fact, fact_status_map.get(fact, False))) continue # 与已选择的唯一事实进行语义相似度比较 for unique_fact, _ in unique_facts_with_status: unique_embedding = await get_emb(unique_fact) if unique_embedding: similarity = self._compute_cosine_similarity( fact_embedding, unique_embedding ) # 使用配置的阈值来判断重复 if similarity >= self.valves.semantic_dedup_threshold: is_duplicate = True print( f"发现重复事实,相似度{similarity:.3f}: '{fact}' ≈ '{unique_fact}'" ) break if not is_duplicate: unique_facts_with_status.append((fact, fact_status_map.get(fact, False))) print(f"语义去重完成: {len(facts)} -> {len(unique_facts_with_status)}") return unique_facts_with_status async def _call_llm_for_json( self, system_prompt: str, user_prompt: str = "" ) -> List: content = await self._call_llm(system_prompt, user_prompt) try: if content.startswith("```json"): content = content[7:-3].strip() result = json.loads(content) return result if isinstance(result, list) else [] except json.JSONDecodeError: return [] def _stringify_conversation(self, messages: List[dict]) -> str: count = min(self.valves.messages_to_consider, len(messages)) return "\n".join( [f"- {msg['role']}: {msg['content']}" for msg in messages[-count:]] ) # 修改统计函数,加入首字时间 def _calculate_stats(self, body: dict, conversation_end_time: float) -> dict: # 使用对话结束时间而非记忆处理结束时间 elapsed = conversation_end_time - self.start_time response_msg = get_last_assistant_message(body.get("messages", [])) or "" tokens = len(response_msg) // 3 tps = tokens / elapsed if elapsed > 0 else 0 return { "elapsed": f"{elapsed:.1f}s", "tokens": tokens, "tps": f"{tps:.0f}", "ttft": ( f"{self.time_to_first_token:.2f}s" if self.time_to_first_token is not None else "N/A" ), } async def _get_total_memory_count(self, user, retries: int = 2, delay: float = 0.25) -> int: """获取用户的总记忆条数(带重试,缓解向量索引未及时刷新的“少一条”现象)""" counts = [] for i in range(retries + 1): try: query_result = await query_memory( self._get_dummy_request(), QueryMemoryForm(content=" ", k=self.valves.max_memory_query_k), user, ) if ( query_result and hasattr(query_result, "documents") and query_result.documents and isinstance(query_result.documents[0], list) ): counts.append(len(query_result.documents[0])) else: counts.append(0) except Exception: counts.append(0) if i < retries: await asyncio.sleep(delay * (i + 1)) return max(counts) if counts else 0 # 调整顺序并增加首字时间与数量校正 async def _show_status( self, event_emitter, memory_result: dict, stats_result: dict, user=None ): memory_part = [] if memory_result: status = memory_result.get("status", "skipped") message = memory_result.get("message", "") # 简化的状态图标 if status == "success": if memory_result.get("summarization_triggered"): icon = "🧠" message = "触发摘要" elif "更新" in message: icon = "🔄" elif "新增" in message: icon = "✨" elif "重复" in message: icon = "💤" elif "无" in message: icon = "💤" else: icon = "✅" elif status == "error": icon = "❌" else: icon = "💤" memory_part.append(f" {icon} {message}") stats_parts = [ f"首字{stats_result['ttft']}", f"{stats_result['elapsed']}", f"{stats_result['tps']}TPS", f"≈{stats_result['tokens']}T", ] # 添加记忆统计信息(带净增量校正) if user: try: base_total = await self._get_total_memory_count(user) delta = 0 if memory_result and memory_result.get("status") == "success": delta = int(memory_result.get("net_count_delta", 0)) display_total = max(0, base_total + delta) stats_parts.append(f"共{display_total}条") # 显示最后记忆整合任务状态 last_summarization = self._last_summarization_status if last_summarization["timestamp"]: # 使用用户配置的时区来显示时间 try: target_tz = pytz.timezone(self.valves.timezone) except pytz.UnknownTimeZoneError: target_tz = pytz.utc last_time = datetime.datetime.fromtimestamp( last_summarization["timestamp"], tz=target_tz ) time_str = last_time.strftime("%m-%d %H:%M") result_str = last_summarization["result"] or "完成" stats_parts.append(f"🔄整合:{time_str}({result_str})") except Exception as e: print(f"获取记忆统计信息时出错: {e}") # 按照优化后的顺序组合 final_parts = memory_part + stats_parts if final_parts: await event_emitter( { "type": "status", "data": {"description": " | ".join(final_parts), "done": True}, } ) def _get_dummy_request(self) -> Request: return Request(scope={"type": "http", "app": webui_app}) # ==================== L2缓存辅助函数 ==================== def _get_base_cache_dir(self) -> str: """获取L2缓存的统一基目录""" return os.path.join(".", "data", "super_memory_cache") def _get_safe_model_name(self) -> str: """获取一个可用作目录名的安全模型名称""" model_name = self._valves().embedding_model return re.sub(r'[^a-zA-Z0-9_.-]', '_', model_name) def _get_path_for_embedding(self, text: str, user=None, memory_id: str = None) -> str: """根据有无memory_id决定使用mem路径还是adhoc路径""" import hashlib v = self._valves() safe_model = self._get_safe_model_name() base_dir = self._get_base_cache_dir() if user and memory_id: # mem 路径 path = os.path.join(base_dir, safe_model, "users", user.id, "mem", f"{memory_id}.json") else: # adhoc 路径 text_hash = hashlib.md5(text.encode("utf-8")).hexdigest() path = os.path.join(base_dir, safe_model, "adhoc", f"{text_hash}.json") return path def _ensure_dir(self, file_path: str): """确保文件路径所在的目录存在""" try: directory = os.path.dirname(file_path) if not os.path.exists(directory): os.makedirs(directory, exist_ok=True) except Exception as e: print(f"[ERROR] Failed to create directory for {file_path}: {e}") def _read_embedding_from_disk(self, path: str) -> Optional[dict]: """从磁盘读取单个嵌入向量文件""" if not self._valves().persist_embeddings_to_disk or not os.path.exists(path): return None try: with open(path, "r", encoding="utf-8") as f: return json.load(f) except (IOError, json.JSONDecodeError) as e: print(f"[WARNING] Failed to read or decode L2 cache file {path}: {e}") return None def _write_embedding_to_disk(self, path: str, text: str, embedding: List[float]): """向磁盘原子化地写入单个嵌入向量文件""" if not self._valves().persist_embeddings_to_disk: return self._ensure_dir(path) data = { "v": "1.0", "text": text, "embedding_model": self._valves().embedding_model, "embedding": embedding, "dim": len(embedding), "ts": time.time(), } temp_path = f"{path}.{os.getpid()}.tmp" try: with open(temp_path, "w", encoding="utf-8") as f: json.dump(data, f) os.replace(temp_path, path) print(f"[INFO] Successfully wrote L2 cache to {path}") except Exception as e: print(f"[ERROR] Failed to write L2 cache file to {path}: {e}") if os.path.exists(temp_path): try: os.remove(temp_path) except OSError: pass