Function
pipe
v3.1.1
CostTrackingPipe
A pipe function to change API keys on the fly as well as optionally remove 'thinking' blocks
Function ID
costtrackingpipe
Downloads
36+

Function Content
python
"""
title: CostTrackingPipe
author: thiswillbeyourgithub
author_url: https://github.com/thiswillbeyourgithub/openwebui_custom_pipes_filters/
funding_url: https://github.com/thiswillbeyourgithub/openwebui_custom_pipes_filters/
version: 3.1.1
date: 2024-08-21
license: GPLv3
description: A pipe function to track user costs and remove 'thinking' blocks
"""

from typing import List, Union, Generator, Iterator, Callable, Any, Optional
from pydantic import BaseModel, Field
import requests
import os
import re
import time
import json

# Default base URL of the LiteLLM server; used as the default of
# Pipe.Valves.LITELLM_BASE_URL.
DEFAULT_BASE_URL = "http://127.0.0.1:4000"
# Default value of Pipe.UserValves.chat_model.
DEFAULT_CHAT_MODEL = "litellm_sonnet-3.5"
# Default value of Pipe.UserValves.title_chat_model (the model used to
# generate titles, per that field's description).
DEFAULT_TITLE_CHAT_MODEL = "litellm_gpt-4o-mini"


class Pipe:
    """Pipe that relays chat requests to a LiteLLM server with per-user API keys.

    The requesting user's name is looked up in a JSON mapping of
    ``{user name: api key}`` (with a mandatory ``"default"`` entry) and the
    matching key is sent as the bearer token. For streamed answers, the first
    "thinking" block can be stripped from the text before it is yielded.
    """

    class Valves(BaseModel):
        # Global settings of the pipe.
        LITELLM_BASE_URL: str = DEFAULT_BASE_URL
        api_keys: Optional[str] = Field(
            default=None,
            description="Dict where keys are litellm users and values are their virtual api keys (a string that will be json loaded as a dict). Leave to None if you want to load from env 'COSTTRACKINGPIPE_API_KEYS'",
        )

    class UserValves(BaseModel):
        # Per-user settings; read from __user__["valves"] when the host
        # provides an instance, otherwise the defaults below apply.
        enabled: bool = Field(default=True, description="True to enable price counting")
        chat_model: str = Field(
            default=DEFAULT_CHAT_MODEL, description="Chat model to use"
        )
        title_chat_model: str = Field(
            default=DEFAULT_TITLE_CHAT_MODEL,
            description="Model to use to generate titles",
        )

        remove_thoughts: bool = Field(
            default=True, description="True to remove the thoughts block"
        )
        start_thoughts: str = Field(
            default="^``` ?thinking", description="Start of thought block"
        )
        stop_thoughts: str = Field(default="```", description="End of thought block")
        debug: bool = Field(
            default=False,
            description="Set to True to print more info to the docker logs, also to not remove the last emitter message.",
        )

    def __init__(self):
        """Set the pipe identity, default valves and default thought regexes."""
        # You can also set the pipelines that are available in this pipeline.
        # Set manifold to True if you want to use this pipeline as a manifold.
        # Manifold pipelines can have multiple pipelines.
        self.type = "manifold"

        # Optionally, you can set the id and name of the pipeline.
        # The identifier must be unique across all pipelines and may only
        # contain alphanumerics, underscores or hyphens.
        self.id = "cost_tracking_pipe"

        # Optionally, you can set the name of the manifold pipeline.
        self.name = "CostTrackingPipe"

        # Default valve instances; self.uvalves is the fallback used when a
        # request carries no user valves.
        self.valves = self.Valves()
        self.uvalves = self.UserValves()

        # Regexes built from the default user valves. self.pattern is also
        # the fallback when a user-supplied pattern does not compile.
        self.start_thought = re.compile(self.uvalves.start_thoughts)
        self.stop_thought = re.compile(self.uvalves.stop_thoughts)
        self.pattern = self._compile_thought_pattern(
            self.uvalves.start_thoughts, self.uvalves.stop_thoughts
        )

    @staticmethod
    def _compile_thought_pattern(start: str, stop: str) -> "re.Pattern":
        """Compile a regex matching one thought block from *start* to *stop*.

        The middle part is lazy so the match ends at the first *stop* marker.
        Raises ``re.error`` if the combined expression is invalid.
        """
        return re.compile(start + "(.*?)" + stop, flags=re.DOTALL | re.MULTILINE)

    def _load_api_keys(self) -> dict:
        """Return the ``{user name: api key}`` mapping after validating it.

        The JSON string comes from the ``api_keys`` valve, or from the
        ``COSTTRACKINGPIPE_API_KEYS`` environment variable when the valve is
        ``None``.

        Raises:
            AssertionError: the source is missing, is not a string, or the
                mapping has no ``"default"`` entry. Raised explicitly (not
                via ``assert``) so the checks survive ``python -O``.
            Exception: the string is not valid JSON or is not a JSON object.
        """
        raw = self.valves.api_keys
        if raw is None:
            if "COSTTRACKINGPIPE_API_KEYS" not in os.environ:
                raise AssertionError(
                    "You left the valve api_keys to None but didn't set an env variable COSTTRACKINGPIPE_API_KEYS"
                )
            raw = os.environ["COSTTRACKINGPIPE_API_KEYS"]
        if not isinstance(raw, str):
            raise AssertionError(
                f"Expected api_keys to be a str at this point, not {type(raw)}"
            )

        try:
            parsed = json.loads(raw)
            if not isinstance(parsed, dict):
                raise AssertionError(
                    f"Expected api_keys to be a dict at this point, not {type(parsed)}"
                )
        except Exception as parse_error:
            raise Exception(
                f"Error when casting api_keys from str to dict: '{parse_error}'"
            ) from parse_error

        if "default" not in parsed:
            # Only the user names are reported: the values are secret keys
            # and must not end up in logs or error messages.
            raise AssertionError(
                f"No 'default' key found in dict, known users: {list(parsed)}"
            )
        return parsed

    def _resolve_user_valves(self, user: dict) -> "Pipe.UserValves":
        """Return the user valves attached to *user*, else the defaults."""
        candidate = user.get("valves") if isinstance(user, dict) else None
        if isinstance(candidate, self.UserValves):
            return candidate
        return self.uvalves

    @staticmethod
    def _parse_stream_line(raw_line) -> "tuple[bool, str]":
        """Decode one line of the streamed response.

        Returns ``(done, content)``: *done* is True on the ``[DONE]`` marker,
        *content* is the text delta of the first choice, or ``""`` when the
        line is blank, is not JSON, or carries no text (for example a chunk
        whose ``choices`` list is empty).
        """
        if not raw_line:
            return False, ""
        text = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
        if text.startswith("data: "):
            text = text[6:]  # Remove "data: " prefix
        if text.strip() == "[DONE]":
            return True, ""
        try:
            content = json.loads(text)["choices"][0]["delta"].get("content", "")
        except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
            return False, ""
        return False, content or ""

    async def on_valves_updated(self):
        """This function is called when the valves are updated.

        It only checks that the api_keys configuration is valid; see
        ``_load_api_keys`` for the exceptions raised.
        """
        self._load_api_keys()

    async def pipe(
        self,
        body: dict,
        __user__: dict,
        __event_emitter__: Callable[[dict], Any] = None,
        # this is just to debug if there are breaking changes etc
        *args,
        **kwargs,
    ) -> Union[str, Generator, Iterator]:
        """Forward *body* to the chat completions endpoint and yield the answer.

        Args:
            body: request payload; ``body["user"]`` is set to the user name
                and ``"model"`` is overridden in the payload actually sent.
            __user__: user record; ``"name"`` selects the API key and an
                optional ``"valves"`` entry supplies the user's valves.
            __event_emitter__: optional async callback receiving status events.
            *args, **kwargs: unexpected extra arguments, only logged.

        Yields:
            Raw response lines when thought removal is off, text deltas (with
            the first thought block removed) when it is on, or the whole
            message content for non-streamed requests.

        Raises:
            AssertionError, Exception: invalid api_keys configuration.
            Exception: any request/parsing failure, re-raised after an error
                status has been emitted.
        """
        api_keys = self._load_api_keys()
        uvalves = self._resolve_user_valves(__user__)

        # prints and emitter to show progress
        def pprint(message: str) -> str:
            print(f"CostTrackingPipe of '{__user__['name']}': " + str(message))
            return message

        emitter = EventEmitter(__event_emitter__)

        async def prog(message: str) -> None:
            await emitter.progress_update(pprint(message))

        async def succ(message: str) -> None:
            await emitter.success_update(pprint(message))

        async def err(message: str) -> None:
            await emitter.error_update(pprint(message))

        # to know in the future if there are new arguments I could use
        if args:
            pprint("Received args:" + str(args))
        if kwargs:
            pprint("Received kwargs:" + str(kwargs))

        if uvalves.debug:
            pprint(body.keys())
            pprint(body)

        # match the api key
        await prog("Start")
        username = __user__["name"]
        if not uvalves.enabled:
            await prog("Disabled api key matching, will use the default key")
            apikey = api_keys["default"]
        elif username in api_keys:
            apikey = api_keys[username]
            await prog(f"Will use key for {username}")
        else:
            # Unknown users silently share the default key.
            apikey = api_keys["default"]
        headers = {"Authorization": f"Bearer {apikey}"}
        body["user"] = username

        response = None
        try:
            streaming = bool(body.get("stream", False))
            # stream disabled is only used for the summary title creator AFAIK
            model = uvalves.chat_model if streaming else uvalves.title_chat_model
            payload = {**body, "model": model, "user": username}

            await prog("Waiting for response")
            # NOTE(review): requests is synchronous, so this coroutine blocks
            # while waiting for the server.
            response = requests.post(
                url=f"{self.valves.LITELLM_BASE_URL}/v1/chat/completions",
                json=payload,
                headers=headers,
                stream=True,
            )

            response.raise_for_status()
            if response.status_code != 200:
                raise AssertionError(f"Invalid status code: {response.status_code}")

            if not streaming:  # return the whole text directly
                await prog("Returning directly")
                answer = response.json()
                yield answer["choices"][0]["message"].get("content", "")

            elif (not uvalves.remove_thoughts) or (not uvalves.enabled):
                # Pass-through: hand the raw lines to the caller untouched.
                await prog("Receiving chunks")
                for line in response.iter_lines():
                    yield line

            else:
                await prog("Receiving chunks")
                try:
                    thought_pattern = self._compile_thought_pattern(
                        uvalves.start_thoughts, uvalves.stop_thoughts
                    )
                except re.error as regex_error:
                    pprint(
                        f"Invalid thought pattern ({regex_error}), using the default one"
                    )
                    thought_pattern = self.pattern

                buffer = ""
                thought_removed = False
                # Time at which the "Removed thought block" status is cleared.
                hide_status_at = None

                for raw_line in response.iter_lines():
                    if hide_status_at is not None and time.monotonic() > hide_status_at:
                        # remove this print after 1s, and only once
                        await succ("")
                        hide_status_at = None

                    done, content = self._parse_stream_line(raw_line)
                    if done:
                        break
                    if not content:
                        continue

                    if thought_removed:
                        yield content
                        continue

                    # Text is held back until the thought block is complete.
                    buffer += content
                    match = thought_pattern.search(buffer)
                    if match:
                        # Remove the thought block
                        yield buffer[: match.start()] + buffer[match.end() :]
                        buffer = ""
                        thought_removed = True
                        await succ("Removed thought block")
                        if not uvalves.debug:
                            hide_status_at = time.monotonic() + 1

                if not thought_removed:
                    # model didn't produce a thought (for example can happen for the chat title)
                    await succ("Thought block never found")
                    yield buffer

            if not uvalves.debug:
                await succ("")  # hides it

        except Exception as e:
            await err(f"Error: {e}")
            raise
        finally:
            # Release the connection even if the consumer stops early or an
            # error occurred mid-stream.
            if response is not None:
                response.close()


class EventEmitter:
    """Async helper that wraps a callback and sends it "status" event dicts."""

    def __init__(self, event_emitter: Callable[[dict], Any] = None):
        # Awaitable callback receiving each event; when falsy, emit() does nothing.
        self.event_emitter = event_emitter

    async def progress_update(self, description):
        """Send *description* with the default in-progress, not-done status."""
        await self.emit(description=description)

    async def error_update(self, description):
        """Send *description* as a finished "error" status."""
        await self.emit(description=description, status="error", done=True)

    async def success_update(self, description):
        """Send *description* as a finished "success" status."""
        await self.emit(description=description, status="success", done=True)

    async def emit(self, description="Unknown State", status="in_progress", done=False):
        """Build the status event and await the callback with it, if one is set."""
        callback = self.event_emitter
        if not callback:
            return

        details = {
            "status": status,
            "description": description,
            "done": done,
        }
        event = {"type": "status", "data": details}
        await callback(event)