"""
title: Image OCR Filter
author: jojo112211
author_url: nope
funding_url: stillnope
version: 0.2 加入三次错误尝试 差点把自己的API泄露了QWQ
"""
from pydantic import BaseModel, Field
from typing import Callable, Awaitable, Any, Optional
import asyncio
import httpx
import base64
import time
class Filter:
    """Replace the first user-supplied image in a chat request with OCR text.

    ``inlet`` looks for the first ``image_url`` content part in a user
    message, sends it to the Baidu general OCR endpoint and swaps that part
    for a text part containing the recognised text. ``outlet`` returns the
    body unchanged.
    """

    _TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    _OCR_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic"
    # Baidu error codes for "access token invalid" (110) and "access token
    # expired" (111). Compared as strings so int and str codes both match.
    _TOKEN_ERROR_CODES = frozenset({"110", "111"})

    class Valves(BaseModel):
        """User-configurable settings for the filter."""

        priority: int = Field(
            default=0, description="Priority level for filtering operations."
        )
        API_KEY: str = Field(default="", description="Baidu OCR API Key")
        SECRET_KEY: str = Field(
            default="",
            description="Baidu OCR Secret Key",
        )
        # Maximum number of attempts for the token request and the OCR request.
        MAX_RETRIES: int = Field(default=3, description="最大重试次数")

    def __init__(self):
        self.valves = self.Valves()
        # Cached access token; cleared when Baidu reports it invalid/expired.
        self.access_token = None
        # Serialises token requests so concurrent calls share one fetch.
        self.token_lock = asyncio.Lock()

    def _max_attempts(self) -> int:
        """Return the configured attempt count, never less than one."""
        return max(1, self.valves.MAX_RETRIES)

    @staticmethod
    def _image_payload(image_ref: str) -> dict:
        """Map an image reference to the matching Baidu OCR form field.

        * ``http(s)://`` links are sent as ``url``.
        * ``data:...;base64,<payload>`` URLs are reduced to the bare payload,
          because the ``image`` field must not include the data-URL header.
        * Anything else is assumed to already be bare base64 and is sent
          unchanged as ``image``.
        """
        if image_ref.startswith(("http://", "https://")):
            return {"url": image_ref}
        if image_ref.startswith("data:"):
            _header, separator, encoded = image_ref.partition(",")
            if separator:
                image_ref = encoded
        return {"image": image_ref}

    async def _get_access_token(self):
        """Return a Baidu access token, requesting one if none is cached.

        The request is retried with exponential backoff up to the configured
        number of attempts.

        Raises:
            Exception: if credentials are not configured, or if no token
                could be obtained after all attempts.
        """
        async with self.token_lock:
            # Re-check under the lock: another task may have fetched it.
            if self.access_token:
                return self.access_token

            # Fail fast instead of running the whole backoff sequence with
            # credentials that cannot possibly work.
            if not self.valves.API_KEY or not self.valves.SECRET_KEY:
                raise Exception("未配置百度OCR的API_KEY或SECRET_KEY")

            max_attempts = self._max_attempts()
            params = {
                "grant_type": "client_credentials",
                "client_id": self.valves.API_KEY,
                "client_secret": self.valves.SECRET_KEY,
            }
            last_error = None
            for attempt in range(1, max_attempts + 1):
                try:
                    async with httpx.AsyncClient() as client:
                        response = await client.post(self._TOKEN_URL, params=params)
                    if response.status_code != 200:
                        raise Exception(
                            f"获取token失败,状态码: {response.status_code},响应: {response.text}"
                        )
                    result = response.json()
                    token = result.get("access_token")
                    if not token:
                        raise Exception(f"响应中未找到access_token: {result}")
                    self.access_token = token
                    return token
                except Exception as e:
                    last_error = e
                    if attempt < max_attempts:
                        # Exponential backoff: 2s, 4s, 8s, ...
                        wait_time = 2**attempt
                        print(
                            f"获取token失败,{attempt}/{max_attempts}次重试,等待{wait_time}秒: {str(e)}"
                        )
                        await asyncio.sleep(wait_time)

            # Every attempt failed.
            raise Exception(
                f"获取access_token失败,已重试{max_attempts}次: {str(last_error)}"
            ) from last_error

    async def _perform_ocr(self, image_base64: str, event_emitter):
        """Run Baidu general OCR on one image and return the recognised text.

        Args:
            image_base64: Bare base64 image data, a ``data:`` URL, or an
                http(s) image link.
            event_emitter: Async callable used to publish retry status events.

        Returns:
            The recognised lines joined with newlines (empty string if the
            response contains no ``words_result`` entries).

        Raises:
            Exception: if the token cannot be obtained, or if OCR still fails
                after all attempts.
        """
        max_attempts = self._max_attempts()
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        payload = {
            **self._image_payload(image_base64),
            "detect_direction": "false",
            "detect_language": "false",
            "paragraph": "false",
            "probability": "false",
        }

        last_error = None
        for attempt in range(1, max_attempts + 1):
            # Deliberately outside the try: the token fetch has its own retry
            # loop, so a token failure must not be retried again here.
            token = await self._get_access_token()
            try:
                async with httpx.AsyncClient() as client:
                    response = await client.post(
                        self._OCR_URL,
                        params={"access_token": token},
                        headers=headers,
                        data=payload,
                    )
                result = response.json()
                if "error_code" in result:
                    # Drop a rejected token so the next attempt fetches a new
                    # one; skip if another task has already replaced it.
                    if (
                        str(result["error_code"]) in self._TOKEN_ERROR_CODES
                        and self.access_token == token
                    ):
                        self.access_token = None
                    raise Exception(
                        f"OCR API错误: {result.get('error_msg', '未知错误')}"
                    )
                words = [item["words"] for item in result.get("words_result", [])]
                return "\n".join(words)
            except Exception as e:
                last_error = e
                if attempt < max_attempts:
                    # Tell the user a retry is about to happen.
                    await event_emitter(
                        {
                            "type": "status",
                            "data": {
                                "description": f"⚠️ OCR识别失败,正在进行第{attempt}次重试...",
                                "done": False,
                            },
                        }
                    )
                    # Exponential backoff: 2s, 4s, 8s, ...
                    await asyncio.sleep(2**attempt)

        # Every attempt failed.
        raise Exception(
            f"OCR识别失败,已重试{max_attempts}次: {str(last_error)}"
        ) from last_error

    def _find_image_in_messages(self, messages):
        """Locate the first image part among the user messages.

        Malformed messages or content parts are skipped rather than raising.

        Returns:
            ``(message_index, content_index, url)`` for the first
            ``image_url`` part with a non-empty string URL, or ``None`` if
            there is none.
        """
        for m_index, message in enumerate(messages):
            if not isinstance(message, dict) or message.get("role") != "user":
                continue
            parts = message.get("content")
            if not isinstance(parts, list):
                continue
            for c_index, part in enumerate(parts):
                if not isinstance(part, dict) or part.get("type") != "image_url":
                    continue
                image = part.get("image_url")
                url = image.get("url") if isinstance(image, dict) else None
                if isinstance(url, str) and url:
                    return m_index, c_index, url
        return None

    async def inlet(
        self,
        body: dict,
        __event_emitter__: Callable[[Any], Awaitable[None]],
        __user__: Optional[dict] = None,
        __model__: Optional[dict] = None,
    ) -> dict:
        """Replace the first image in the request with its OCR text.

        Progress and failures are reported through ``__event_emitter__``.
        On failure the body is returned unmodified.
        """
        # ``or []`` also covers an explicit ``"messages": None``.
        messages = body.get("messages") or []

        image_info = self._find_image_in_messages(messages)
        if not image_info:
            return body
        message_index, content_index, image_ref = image_info

        try:
            # Announce that recognition has started.
            await __event_emitter__(
                {
                    "type": "status",
                    "data": {
                        "description": "✨正在进行文字识别,请稍候...",
                        "done": False,
                    },
                }
            )
            # Run OCR (retries are handled inside).
            ocr_result = await self._perform_ocr(image_ref, __event_emitter__)
            # Announce completion.
            await __event_emitter__(
                {
                    "type": "status",
                    "data": {
                        "description": "✅文字识别完成!",
                        "done": True,
                    },
                }
            )
            # Swap the image part for a text part carrying the OCR output.
            messages[message_index]["content"][content_index] = {
                "type": "text",
                "text": f"用户上传了一张图片,以下是图片转文字的结果:\n\n{ocr_result}",
            }
            body["messages"] = messages
        except Exception as e:
            # Best effort: report the failure and pass the request through.
            print(f"OCR错误: {str(e)}")
            await __event_emitter__(
                {
                    "type": "status",
                    "data": {
                        "description": f"❌识别失败: {str(e)}",
                        "done": True,
                    },
                }
            )
        return body

    async def outlet(
        self,
        body: dict,
        __event_emitter__: Callable[[Any], Awaitable[None]],
        __user__: Optional[dict] = None,
        __model__: Optional[dict] = None,
    ) -> dict:
        """Return the response body unchanged."""
        return body