"""
title: Chat with your Data
author: Marius Raileanu
description: This tool allows users to interact with data stored in Databricks through Genie APIs and leverage your language model to quickly analyze and interpret results.
required_open_webui_version: 0.4.0
requirements:requests>=2.31.0,pydantic>=2.0.0, asyncio
version: 0.1.0
licence: MIT
"""
import json
import requests
import asyncio
from typing import Dict, Optional, Callable, Awaitable, Any, Tuple
from pydantic import BaseModel, Field
from requests.exceptions import RequestException, ConnectionError, Timeout, HTTPError
from dataclasses import dataclass
from enum import Enum
class DatabricksError(Exception):
    """Root of the exception hierarchy raised by this Databricks tool."""
class APIError(DatabricksError):
    """Signals a failure reported by, or while talking to, the API."""
class ConfigurationError(DatabricksError):
    """Signals an invalid or unusable tool configuration."""
# NOTE: this definition rebinds the module-level name ``ConnectionError``,
# hiding both the builtin of the same name and the
# ``requests.exceptions.ConnectionError`` imported at the top of the file.
# Code that needs the requests exception must use its qualified name.
class ConnectionError(DatabricksError):
    """Signals a connection-level failure when reaching Databricks."""
class ConversationStatus(Enum):
    """Status values of a conversation message that this module names."""

    # Each member's value equals its name.
    COMPLETED = "COMPLETED"
    IN_PROGRESS = "IN_PROGRESS"
    FAILED = "FAILED"
@dataclass
class APIResponse:
    """
    Outcome of a single API call.

    Attributes:
        success: True when the call succeeded, False otherwise.
        data: Payload of a successful call; defaults to None.
        error: Human-readable failure message; defaults to None.
    """

    success: bool
    data: Any = None
    error: Optional[str] = None
class Helper:
    """Stateless helpers for shaping API payloads used by the tool."""

    @staticmethod
    def format_table(attachment_json: dict) -> str:
        """
        Format attachment JSON data into a Markdown table.

        Column names are read from
        ``statement_response.manifest.schema.columns[*].name`` and rows from
        ``statement_response.result.data_array``. Header and cell text is
        sanitized so it cannot break the table layout: pipe characters are
        escaped and line breaks are replaced by a single space.

        Args:
            attachment_json: JSON data of the attachment.

        Returns:
            String representing the formatted table.

        Raises:
            ValueError: If the attachment JSON is invalid or missing required fields.
        """

        def _cell(value: object) -> str:
            # A raw "|" would start a new column and a raw newline would end
            # the row, so neutralize both before joining cells.
            text = str(value)
            for newline in ("\r\n", "\n", "\r"):
                text = text.replace(newline, " ")
            return text.replace("|", "\\|")

        try:
            statement = attachment_json["statement_response"]
            columns = statement["manifest"]["schema"]["columns"]
            data_rows = statement["result"]["data_array"]
            headers = [_cell(col["name"]) for col in columns]
            header_row = "| " + " | ".join(headers) + " |"
            separator_row = "| " + " | ".join(["---"] * len(headers)) + " |"
            rows = [
                "| " + " | ".join(_cell(cell) for cell in row) + " |"
                for row in data_rows
            ]
            return "\n".join([header_row, separator_row] + rows)
        except KeyError as e:
            raise ValueError(f"Invalid attachment JSON structure: missing {e}") from e
        except Exception as e:
            raise ValueError(f"Error formatting table: {e}") from e

    @staticmethod
    def extract_message_info(
        response_data: dict,
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Extract conversation and message IDs from response data.

        The message ID is looked up in the nested ``message`` object when it
        is present and non-empty, otherwise in the top-level response. Within
        the chosen object ``id`` is preferred over ``message_id``.

        Args:
            response_data: Response data from the API.

        Returns:
            Tuple of (conversation_id, message_id). ``conversation_id`` is an
            empty string when the key is absent; ``message_id`` is None when
            neither key yields a truthy value.
        """
        conversation_id = response_data.get("conversation_id", "")
        message_info = response_data.get("message", {}) or response_data
        message_id = message_info.get("id") or message_info.get("message_id")
        return conversation_id, message_id
class Tools:
    """
    Tool class that drives a Databricks Genie space over its HTTP API.

    Configuration lives in ``self.valves`` (see ``Valves``). The last
    conversation ID seen for each user is kept in memory in
    ``self.conversations`` so follow-up questions can reuse it.
    """

    def __init__(self):
        """
        Initialize the tool with default valves and empty conversation state.

        Raises:
            ConfigurationError: If building or validating the default
                configuration fails for any reason.
        """
        try:
            self.valves = self.Valves()
            self.valves.validate()
            self.conversations: Dict[str, str] = {}
            self.helper = Helper()
        except Exception as e:
            raise ConfigurationError(f"Configuration error: {str(e)}") from e

    class Valves(BaseModel):
        """Configuration for the Databricks Genie tool."""

        base_url: str = Field(
            default="https://<your-databricks-instance>.azuredatabricks.net/api/2.0/genie/spaces/<your-space-id>",
            description="Base URL for the Databricks Genie API space.",
        )
        auth_token: str = Field(
            default="<your-valid-auth-token>",
            description="Authorization token.",
        )
        poll_max_attempts: int = Field(
            default=10, description="Maximum polling attempts for conversation flow."
        )
        poll_interval: int = Field(
            default=3, description="Polling interval in seconds."
        )

        # NOTE(review): pydantic's BaseModel also exposes a (deprecated)
        # ``validate`` classmethod; this instance method replaces it on this
        # model. The name is kept because existing callers invoke it.
        def validate(self) -> None:
            """
            Validate configuration values.

            Raises:
                ConfigurationError: If ``base_url`` lacks an http(s) scheme,
                    ``auth_token`` is empty, or a polling setting is below 1.
            """
            if not self.base_url.startswith(("http://", "https://")):
                raise ConfigurationError("base_url must start with http:// or https://")
            if not self.auth_token:
                raise ConfigurationError("auth_token cannot be empty")
            if self.poll_max_attempts < 1:
                raise ConfigurationError("poll_max_attempts must be greater than 0")
            if self.poll_interval < 1:
                raise ConfigurationError("poll_interval must be greater than 0")

    def _get_headers(self, include_content_type: bool = True) -> dict:
        """
        Create HTTP headers for API calls.

        Args:
            include_content_type: Whether to include the Content-Type header.

        Returns:
            Dictionary of HTTP headers; always carries the bearer token.
        """
        headers = {"Authorization": f"Bearer {self.valves.auth_token}"}
        if include_content_type:
            headers["Content-Type"] = "application/json"
        return headers

    async def _perform_http_request(
        self,
        method: str,
        url: str,
        content: Optional[str] = None,
        include_content_type: bool = True,
        max_retries: int = 3,
        retry_delay: int = 1,
    ) -> APIResponse:
        """
        Perform an HTTP request with retries.

        Connection errors, timeouts and other non-HTTP request failures are
        retried; HTTP error statuses are reported immediately. The blocking
        ``requests`` call runs in the event loop's default executor so the
        loop is not stalled while waiting on the network.

        Args:
            method: HTTP method ('GET' or 'POST').
            url: Endpoint URL.
            content: Request body content for POST requests; sent as
                ``{"content": content}`` when not None.
            include_content_type: Whether to include the Content-Type header.
            max_retries: Maximum number of attempts (at least one is made).
            retry_delay: Delay between retries in seconds.

        Returns:
            APIResponse object containing the response data or error information.

        Raises:
            ValueError: If ``method`` is neither GET nor POST.
        """
        verb = method.upper()
        if verb not in ("GET", "POST"):
            raise ValueError(f"Unsupported HTTP method: {method}")

        headers = self._get_headers(include_content_type)
        body = json.dumps({"content": content}) if content is not None else None
        http_error_messages = {
            401: "Authentication failed. Please check your credentials.",
            403: "Access denied. Please check your permissions.",
            404: "Resource not found. Please check the URL and try again.",
        }

        def _send():
            # Synchronous request; executed off the event loop below.
            if verb == "GET":
                return requests.get(url, headers=headers, timeout=30)
            return requests.post(url, headers=headers, data=body, timeout=30)

        loop = asyncio.get_running_loop()
        attempts = max(1, max_retries)
        failure = APIResponse(success=False, error="Request failed.")
        for attempt in range(attempts):
            if attempt:
                await asyncio.sleep(retry_delay)
            try:
                response = await loop.run_in_executor(None, _send)
                response.raise_for_status()
                # Only attempt to decode JSON if there's content.
                json_data = response.json() if response.text.strip() else {}
                return APIResponse(success=True, data=json_data)
            # Qualified name on purpose: the bare name ConnectionError refers
            # to this module's own exception class, not the requests one.
            except requests.exceptions.ConnectionError:
                failure = APIResponse(
                    success=False,
                    error="Could not connect to the server. Please check your connection and try again.",
                )
            except Timeout:
                failure = APIResponse(
                    success=False,
                    error="Request timed out. Please try again later.",
                )
            except HTTPError as e:
                # HTTP status errors are not retried.
                return APIResponse(
                    success=False,
                    error=http_error_messages.get(
                        e.response.status_code, f"Server error: {str(e)}"
                    ),
                )
            except RequestException as e:
                failure = APIResponse(success=False, error=f"Request failed: {str(e)}")
        return failure

    def get_conversation_id(self, user_id: str) -> str:
        """Retrieve a stored conversation ID for a user ("" if none)."""
        return self.conversations.get(user_id, "")

    def set_conversation_id(self, user_id: str, conversation_id: str):
        """Store a conversation ID for a user."""
        self.conversations[user_id] = conversation_id

    async def _emit_status(
        self,
        __event_emitter__: Callable[[dict], Awaitable[None]],
        description: str,
        status: str,
        done: bool,
    ):
        """Emit a status update using the provided event emitter, if any."""
        if __event_emitter__ is None:
            return
        await __event_emitter__(
            {
                "data": {"description": description, "status": status, "done": done},
                "type": "status",
            }
        )

    async def start_conversation(
        self,
        content: str,
        __event_emitter__: Callable[[dict], Awaitable[None]],
        __user__: dict = {},
    ) -> str:
        """
        Start a new conversation by sending initial content to the API.

        On success the returned conversation ID (if any) is remembered for
        the calling user.

        Returns:
            JSON string response from the API, or a string starting with
            "Error:" on failure.
        """
        await self._emit_status(
            __event_emitter__, "Starting conversation...", "in_progress", False
        )
        url = f"{self.valves.base_url.rstrip('/')}/start-conversation"
        try:
            response = await self._perform_http_request("POST", url, content)
            if not response.success:
                # Close the status so the UI does not stay "in progress".
                await self._emit_status(
                    __event_emitter__,
                    f"Error starting conversation: {response.error}",
                    "complete",
                    True,
                )
                return f"Error: {response.error}"
            data = response.data
            user_id = (__user__ or {}).get("id", "default_user")
            new_conv_id = data.get("conversation_id")
            if new_conv_id:
                self.set_conversation_id(user_id, new_conv_id)
            await self._emit_status(
                __event_emitter__,
                "Conversation started successfully.",
                "complete",
                True,
            )
            return json.dumps(data)
        except Exception as e:
            await self._emit_status(
                __event_emitter__, f"Error starting conversation: {e}", "complete", True
            )
            return f"Error: {e}"

    async def send_message(
        self,
        conversation_id: str,
        content: str,
        __event_emitter__: Callable[[dict], Awaitable[None]],
        __user__: dict = {},
    ) -> str:
        """
        Send a follow-up message to an existing conversation.

        Returns:
            JSON string response from the API, or a string starting with
            "Error:" on failure.
        """
        await self._emit_status(
            __event_emitter__, "Sending follow-up message...", "in_progress", False
        )
        base = self.valves.base_url.rstrip("/")
        url = f"{base}/conversations/{conversation_id}/messages"
        try:
            response = await self._perform_http_request("POST", url, content)
            if not response.success:
                await self._emit_status(
                    __event_emitter__,
                    f"Error sending follow-up message: {response.error}",
                    "complete",
                    True,
                )
                return f"Error: {response.error}"
            await self._emit_status(
                __event_emitter__,
                "Follow-up message sent successfully.",
                "complete",
                True,
            )
            return json.dumps(response.data)
        except Exception as e:
            await self._emit_status(
                __event_emitter__,
                f"Error sending follow-up message: {e}",
                "complete",
                True,
            )
            return f"Error: {e}"

    async def get_conversation_message(
        self,
        conversation_id: str,
        message_id: str,
        __event_emitter__: Callable[[dict], Awaitable[None]],
        __user__: dict = {},
    ) -> str:
        """
        Retrieve a specific conversation message.

        Returns:
            JSON string containing the message, or a string starting with
            "Error:" on failure.
        """
        await self._emit_status(
            __event_emitter__,
            "Retrieving conversation message...",
            "in_progress",
            False,
        )
        base = self.valves.base_url.rstrip("/")
        url = f"{base}/conversations/{conversation_id}/messages/{message_id}"
        try:
            response = await self._perform_http_request(
                "GET", url, include_content_type=False
            )
            if not response.success:
                await self._emit_status(
                    __event_emitter__,
                    f"Error retrieving conversation message: {response.error}",
                    "complete",
                    True,
                )
                return f"Error: {response.error}"
            await self._emit_status(
                __event_emitter__,
                "Conversation message retrieved successfully.",
                "complete",
                True,
            )
            return json.dumps(response.data)
        except Exception as e:
            await self._emit_status(
                __event_emitter__,
                f"Error retrieving conversation message: {e}",
                "complete",
                True,
            )
            return f"Error: {e}"

    async def get_query_result_attachment(
        self,
        conversation_id: str,
        message_id: str,
        attachment_id: str,
        __event_emitter__: Callable[[dict], Awaitable[None]],
        __user__: dict = {},
    ) -> str:
        """
        Retrieve the query result attachment for a conversation message.

        Returns:
            JSON string with the attachment data, or a string starting with
            "Error:" on failure.
        """
        await self._emit_status(
            __event_emitter__,
            "Retrieving query result attachment...",
            "in_progress",
            False,
        )
        base = self.valves.base_url.rstrip("/")
        url = (
            f"{base}/conversations/{conversation_id}/messages/{message_id}/"
            f"attachments/{attachment_id}/query-result"
        )
        try:
            response = await self._perform_http_request(
                "GET", url, include_content_type=False
            )
            if not response.success:
                await self._emit_status(
                    __event_emitter__,
                    f"Error retrieving query result attachment: {response.error}",
                    "complete",
                    True,
                )
                return f"Error: {response.error}"
            await self._emit_status(
                __event_emitter__,
                "Query result attachment retrieved successfully.",
                "complete",
                True,
            )
            return json.dumps(response.data)
        except Exception as e:
            await self._emit_status(
                __event_emitter__,
                f"Error retrieving query result attachment: {e}",
                "complete",
                True,
            )
            return f"Error: {e}"

    async def _poll_for_message(
        self,
        conversation_id: str,
        message_id: str,
        __event_emitter__: Callable[[dict], Awaitable[None]],
        __user__: dict = {},
    ) -> dict:
        """
        Poll until the conversation message status becomes 'COMPLETED'.

        Polls up to ``valves.poll_max_attempts`` times, waiting
        ``valves.poll_interval`` seconds between attempts, and stops early
        when the message reports status 'FAILED'.

        Returns:
            The complete message data as a dictionary.

        Raises:
            Exception: If fetching or parsing the message fails, the message
                reports 'FAILED', or it is not completed within the allowed
                number of attempts.
        """
        status = ""
        attempts = self.valves.poll_max_attempts
        for attempt in range(attempts):
            message_str = await self.get_conversation_message(
                conversation_id, message_id, __event_emitter__, __user__
            )
            try:
                message_data = json.loads(message_str)
            except Exception as e:
                # The fetch reports failures as a plain "Error: ..." string;
                # surface that instead of a misleading JSON parse error.
                if message_str.startswith("Error:"):
                    raise Exception(message_str) from e
                raise Exception(f"Error parsing message response: {e}") from e
            status = message_data.get("status") or ""
            normalized = status.upper()
            if normalized == ConversationStatus.COMPLETED.value:
                return message_data
            if normalized == ConversationStatus.FAILED.value:
                # Further polling cannot succeed once the message has failed.
                detail = message_data.get("error")
                suffix = f" Details: {detail}" if detail else ""
                raise Exception(f"Conversation failed. Current status: {status}.{suffix}")
            if attempt < attempts - 1:
                await asyncio.sleep(self.valves.poll_interval)
        raise Exception(
            f"Conversation not completed after waiting. Current status: {status}"
        )

    async def run_conversation_flow(
        self,
        content: str,
        __event_emitter__: Callable[[dict], Awaitable[None]],
        __user__: dict = {},
        conversation_id: str = "",
    ) -> str:
        """
        Run the entire conversation flow:
        - Start a new conversation or send a follow-up message.
        - Poll until the conversation message is 'COMPLETED'.
        - Retrieve and format any attachment if available.
        - Return a consolidated JSON string result.

        A stored conversation for the user is reused unless ``__user__``
        carries a truthy ``new_chat`` flag or ``conversation_id`` is given.

        Returns:
            JSON string with keys ``conversation_id``, ``message_data``,
            ``attachment_data`` and ``formatted_table``, or an error message
            string when any step fails.
        """
        user = __user__ or {}
        user_id = user.get("id", "default_user")
        if user.get("new_chat", False):
            self.conversations.pop(user_id, None)
            conversation_id = ""
        else:
            conversation_id = conversation_id or self.get_conversation_id(user_id)

        if conversation_id:
            response_str = await self.send_message(
                conversation_id, content, __event_emitter__, user
            )
        else:
            response_str = await self.start_conversation(
                content, __event_emitter__, user
            )
        # Failures come back as "Error: ..." text; report them unchanged
        # rather than as a JSON parse error.
        if response_str.startswith("Error:"):
            return response_str
        try:
            conv_response = json.loads(response_str)
        except Exception as e:
            return f"Error parsing conversation response: {e}"

        response_conv_id, message_id = self.helper.extract_message_info(conv_response)
        conversation_id = conversation_id or response_conv_id
        if not conversation_id:
            return "Error: Missing conversation ID in start response."
        if not message_id:
            return "Error: Missing message ID in response."

        try:
            conv_message = await self._poll_for_message(
                conversation_id, message_id, __event_emitter__, user
            )
        except Exception as e:
            return str(e)

        result: Dict[str, Any] = {
            "conversation_id": conversation_id,
            "message_data": conv_message,
            "attachment_data": None,
            "formatted_table": None,
        }
        attachments = conv_message.get("attachments", [])
        if attachments:
            # Only the first attachment is fetched and formatted.
            attachment_id = attachments[0].get("attachment_id", "")
            if attachment_id:
                attachment_str = await self.get_query_result_attachment(
                    conversation_id,
                    message_id,
                    attachment_id,
                    __event_emitter__,
                    user,
                )
                if attachment_str.startswith("Error:"):
                    return attachment_str
                try:
                    attachment_data = json.loads(attachment_str)
                    result["attachment_data"] = attachment_data
                    result["formatted_table"] = self.helper.format_table(
                        attachment_data
                    )
                except Exception as e:
                    return f"Error parsing query result attachment: {e}"
        return json.dumps(result)