@pessini
LangGraph with human-in-the-loop
Last Updated: 3 months ago
Created: 3 months ago
Function: pipe
Version: v0.1.0
Name: LangGraph with human-in-the-loop
Downloads: 159+
Description: A custom Pipe for LangGraph with real-time human-in-the-loop control.
Function Code
""" title: LangGraph Pipe with human-in-the-loop author: Leandro Pessini - https://github.com/pessini version: 0.1.0 requires: langgraph-sdk v0.2.6 license: MIT """ import logging from typing import AsyncGenerator, Dict, Optional, Union from langgraph_sdk.client import get_client from langgraph_sdk.schema import Command from pydantic import BaseModel, Field logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) # "LANGGRAPH_SERVER_URL", "http://langgraph-api:8000" # When running via docker # When running via langgraph dev and open webui is installed via pip # "LANGGRAPH_SERVER_URL", "http://localhost:2024" # "LANGGRAPH_SERVER_URL", "http://host.docker.internal:2024", # When running via langgraph dev class Pipe: class Valves(BaseModel): LANGGRAPH_SERVER_URL: str = Field(default="http://localhost:2024") MODEL_ID: str = Field(default="agent") VERSION: str = Field(default="v1.0") DEBUG_LOGGING: bool = Field(default=False) SHOW_TOOL_CALLS: bool = Field(default=False) class Metadata(BaseModel): user: str user_id: str chat_id: str message_id: str group_id: str user_message: str model_name: str def __init__(self): self.name = "LangGraph Pipe" self.valves = self.Valves() self.client = None def _initialize_client(self): """Initialize the LangGraph client.""" try: if self.client is None: if self.valves.DEBUG_LOGGING: logger.info( f"Initializing LangGraph client with URL: {self.valves.LANGGRAPH_SERVER_URL}" ) self.client = get_client(url=self.valves.LANGGRAPH_SERVER_URL) if self.valves.DEBUG_LOGGING: logger.info("LangGraph client initialized successfully") return True except Exception as e: logger.error(f"Failed to initialize LangGraph client: {e}") self.client = None return False def on_startup(self): """Called when the server is started.""" logger.info(f"on_startup: {__name__}") self._initialize_client() def on_shutdown(self): """Called when the server is shutdown.""" logger.info(f"on_shutdown: {__name__}") if self.client: self.client = None def _extract_metadata( self, body: dict, __metadata__: Optional[Dict] = None, __user__: Optional[Dict] = None, ) -> Metadata: """ Extract and validate metadata from OpenWebUI request. 
Returns: Metadata object containing validated metadata Raises: ValueError: If required metadata is missing """ if __metadata__ is None: raise ValueError("Missing metadata") # Extract required fields user = __user__.get("email") if __user__ else "anonymous" user_id = __metadata__.get("user_id") chat_id = __metadata__.get("chat_id") message_id = __metadata__.get("message_id") if not user_id or not chat_id: raise ValueError("Missing user_id or chat_id in metadata") messages = body.get("messages", []) if not messages: raise ValueError("No messages provided") # Extract group_id safely try: group_id = ( __metadata__.get("model") .get("info") .get("access_control") .get("read") .get("group_ids")[0] ) except (AttributeError, KeyError, IndexError): group_id = "default" user_message = messages[-1]["content"] # Get model name for status messages model_name = "AI" if __metadata__ and "model" in __metadata__: model_name = __metadata__["model"].get("name", "AI") return self.Metadata( user=user, user_id=user_id, chat_id=chat_id, message_id=message_id, group_id=group_id, user_message=user_message, model_name=model_name, ) def _status(self, message: str = "", done: bool = False) -> Dict: """Helper to create status events.""" return { "event": {"type": "status", "data": {"description": message, "done": done}} } # threads are created independently and then associated with assistants when you start a run. async def create_thread(self, user_id: str, thread_id: str) -> str: """Create or get existing thread.""" if self.client is None: if self.valves.DEBUG_LOGGING: logger.info("LangGraph client is not initialized. Initializing...") if not self._initialize_client(): raise RuntimeError("Failed to initialize LangGraph client.") try: thread = await self.client.threads.create( thread_id=thread_id, metadata={"user_id": user_id}, if_exists="do_nothing", ) return thread["thread_id"] except Exception as e: logger.error(f"Failed to create thread {thread_id}: {e}") raise async def _stream_run( self, thread_id: str, assistant_id: str, input_data: Dict | Command, metadata: Dict = None, config: Dict = None, ) -> AsyncGenerator[Union[str, Dict], None]: """ Common streaming method for both new runs and interrupt handling. 
""" try: # Get the current state of the thread thread_state = await self.client.threads.get_state(thread_id) # Check for interrupts in tasks interrupts = ( thread_state.get("interrupts", []) if isinstance(thread_state, dict) else getattr(thread_state, "interrupts", []) ) has_interrupts = bool(interrupts) # Check if input_data is a user message is_user_message = isinstance(input_data, dict) and "messages" in input_data is_command_input = False if has_interrupts and is_user_message: if self.valves.DEBUG_LOGGING: logger.info( f"Converting user message to resume command for thread {thread_id}" ) input_data = Command(resume=input_data["messages"]) is_command_input = True except Exception as e: if self.valves.DEBUG_LOGGING: logger.info( f"Could not check thread state for interrupts: {e}, proceeding with normal run" ) # Build base arguments stream_args = { "thread_id": thread_id, "assistant_id": assistant_id, "metadata": metadata or {}, "config": config, "stream_mode": ["messages-tuple", "updates"], } input_key = "command" if is_command_input else "input" stream_args[input_key] = input_data if self.valves.DEBUG_LOGGING: logger.info(f"Starting stream_run with args: {stream_args}") try: async for chunk in self.client.runs.stream(**stream_args): if not chunk or not chunk.data: continue if self.valves.DEBUG_LOGGING: self._pretty_print_chunk(chunk) event_type = chunk.event # Handle different event types if event_type == "updates": # Handle human-in-the-loop interrupts if "__interrupt__" in chunk.data: interrupt_data = chunk.data["__interrupt__"][0] interrupt_id = interrupt_data["id"] interrupt_value = interrupt_data["value"] if self.valves.DEBUG_LOGGING: logger.info( f"Interrupt encountered in thread {thread_id}: id={interrupt_id}, value={interrupt_value!r}" ) yield interrupt_value continue # Stop streaming here, wait for next user message # Show tool calls if not hidden if self.valves.SHOW_TOOL_CALLS: try: # Get the first node data (whatever the node name is) node_data = next(iter(chunk.data.values()), None) if node_data and node_data.get("messages"): tool_calls = node_data["messages"][0].get( "tool_calls", [] ) if tool_calls: tool_name = tool_calls[0].get("name") if tool_name: yield self._status( f"๐ ๏ธ Tool {tool_name} requested...", done=False, ) except (IndexError, KeyError, AttributeError, TypeError) as e: logger.error(f"Error extracting tool call info: {e}") elif event_type == "messages": message, metadata_chunk, *_ = chunk.data if isinstance(message, dict): content = message.get("content") if not content: continue # Skip if this is from a tools node if ( isinstance(metadata_chunk, dict) and metadata_chunk.get("langgraph_node") == "tools" ): continue yield self._status() # Clear status yield content # Yield the actual message content except Exception as e: logger.error(f"Streaming error: {e}") yield self._status(f"โ Streaming error: {str(e)}", done=True) async def pipe( self, body: dict, __metadata__: Optional[Dict] = None, __user__: Optional[Dict] = None, ) -> AsyncGenerator[Union[str, Dict], None]: """ Main pipe function for OpenWebUI integration. """ try: # Extract and validate metadata metadata = self._extract_metadata(body, __metadata__, __user__) # Show initial status yield self._status( f"๐ง {metadata.model_name} is thinking... Please wait a moment." 
) # Create thread thread_id = await self.create_thread(metadata.user_id, metadata.chat_id) # Prepare metadata and input data run_metadata = { "langfuse_user_id": metadata.user, "langfuse_tags": [metadata.group_id], "langfuse_session_id": thread_id, "message_id": metadata.message_id, } # Always use the same flow - _stream_run will handle interrupts internally input_data = {"messages": metadata.user_message} first_chunk = True async for chunk in self._stream_run( thread_id=thread_id, assistant_id=self.valves.MODEL_ID, input_data=input_data, metadata=run_metadata, ): # Clear status on first chunk if first_chunk: yield self._status() # Empty message clears status first_chunk = False yield chunk except ValueError as e: logger.error(f"Validation error: {e}") yield self._status(f"โ Error: {str(e)}", done=True) return except Exception as e: logger.error(f"LangGraph pipe error: {e}", exc_info=True) yield self._status(f"โ Error: {str(e)}", done=True) def _pretty_print_chunk(self, chunk): """Basic pretty printing with pprint""" if chunk.event.upper() == "MESSAGES": return import pprint print(f"\n{'=' * 60}") print(f"EVENT TYPE: {chunk.event.upper()}") print(f"{'=' * 60}") # Convert chunk data to dict for better formatting if hasattr(chunk, "data") and chunk.data: pprint.pprint(chunk.data, indent=2, width=120, depth=10) else: print("No data available")
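The interrupt handling in _stream_run only works if the graph deployed on the LangGraph server actually pauses itself. For reference, a minimal sketch of such a graph, using interrupt() from the langgraph package (a server-side dependency this Pipe does not itself import). The node and state names are hypothetical, and the Pipe's default "agent" assistant would typically carry a messages-based state; a simpler state keeps the sketch short.

# Minimal sketch of a server-side graph that pauses for human input.
# Hypothetical node/state names; assumes langgraph >= 0.2, where
# interrupt() / Command(resume=...) implement human-in-the-loop.
from typing import TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.types import interrupt


class State(TypedDict):
    draft: str
    approved: bool


def ask_human(state: State) -> dict:
    # interrupt() pauses the run and surfaces this value to the client --
    # the Pipe above yields it to the chat as `interrupt_value`.
    answer = interrupt(f"Approve this draft? {state['draft']}")
    # On resume the node re-runs and `answer` is whatever the client
    # sent back via Command(resume=...).
    return {"approved": str(answer).lower().startswith("y")}


builder = StateGraph(State)
builder.add_node("ask_human", ask_human)
builder.add_edge(START, "ask_human")
builder.add_edge("ask_human", END)

# A checkpointer is required to pause/resume a thread when running
# standalone; a LangGraph server deployment supplies its own persistence.
graph = builder.compile(checkpointer=MemorySaver())

When a run hits interrupt(), the stream emits an "__interrupt__" update, which the Pipe surfaces to the chat; the next user message is converted into Command(resume=...) and the paused node re-runs with the answer filled in.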
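To see the same pause/resume cycle outside Open WebUI, the flow that _stream_run automates can be driven directly with langgraph-sdk. A sketch against the hypothetical graph above, assuming a local langgraph dev server on port 2024 with an assistant registered as "agent" (the same default as the MODEL_ID valve):

# Standalone sketch of the pause/resume cycle the Pipe automates.
# URL, assistant ID, and input keys are placeholders for a local setup.
import asyncio

from langgraph_sdk.client import get_client
from langgraph_sdk.schema import Command


async def main() -> None:
    client = get_client(url="http://localhost:2024")
    thread = await client.threads.create()

    # First run: stream until the graph hits interrupt() and pauses.
    async for chunk in client.runs.stream(
        thread["thread_id"],
        "agent",
        input={"draft": "Refund approved, replacement shipped."},
        stream_mode=["messages-tuple", "updates"],
    ):
        if chunk.event == "updates" and "__interrupt__" in (chunk.data or {}):
            print("Paused:", chunk.data["__interrupt__"][0]["value"])

    # Second run: resume the paused thread with the human's answer,
    # exactly like the Pipe's Command(resume=...) conversion.
    async for chunk in client.runs.stream(
        thread["thread_id"],
        "agent",
        command=Command(resume="yes"),
        stream_mode=["messages-tuple", "updates"],
    ):
        print(chunk.event, chunk.data)


asyncio.run(main())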