Skip to content

Commit

Permalink
Implement muxing routing cache
Browse files Browse the repository at this point in the history
This allows us to have a parsed and in-memory representation of the
routing rule engine. The intent is to reduce calls to the database.

This is more expensive to maintain since we need to refresh the cache on
every operation pertaining to models, endpoints, workspaces, and muxes themselves.

Signed-off-by: Juan Antonio Osorio <[email protected]>
  • Loading branch information
JAORMX committed Jan 30, 2025
1 parent d24c989 commit bc42187
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from codegate.providers.copilot.provider import CopilotProvider
from codegate.server import init_app
from codegate.storage.utils import restore_storage_backup
from codegate.workspaces import crud as wscrud


class UvicornServer:
Expand Down Expand Up @@ -341,6 +342,8 @@ def serve( # noqa: C901

registry = app.provider_registry
loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry))
wsc = wscrud.WorkspaceCrud()
loop.run_until_complete(wsc.initialize_mux_registry())

# Run the server
try:
Expand Down
16 changes: 16 additions & 0 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,22 @@ async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[Provid
)
return provider[0] if provider else None

async def get_auth_material_by_provider_id(
self, provider_id: str
) -> Optional[ProviderAuthMaterial]:
sql = text(
"""
SELECT id as provider_endpoint_id, auth_type, auth_blob
FROM provider_endpoints
WHERE id = :provider_endpoint_id
"""
)
conditions = {"provider_endpoint_id": provider_id}
auth_material = await self._exec_select_conditions_to_pydantic(
ProviderAuthMaterial, sql, conditions, should_raise=True
)
return auth_material[0] if auth_material else None

async def get_provider_endpoints(self) -> List[ProviderEndpoint]:
sql = text(
"""
Expand Down
Empty file added src/codegate/muxing/__init__.py
Empty file.
118 changes: 118 additions & 0 deletions src/codegate/muxing/rulematcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import copy
from abc import ABC, abstractmethod
from collections import UserDict
from threading import Lock, RLock
from typing import List, Optional

from codegate.db import models as db_models

_muxrules_sgtn = None

_singleton_lock = Lock()


def get_muxing_rules_registry():
"""Returns a singleton instance of the muxing rules registry."""

global _muxrules_sgtn

if _muxrules_sgtn is None:
with _singleton_lock:
if _muxrules_sgtn is None:
_muxrules_sgtn = MuxingRulesinWorkspaces()

return _muxrules_sgtn


class ModelRoute:
"""A route for a model."""

def __init__(
self,
model: db_models.ProviderModel,
endpoint: db_models.ProviderEndpoint,
auth_material: db_models.ProviderAuthMaterial,
):
self.model = model
self.endpoint = endpoint
self.auth_material = auth_material


class MuxingRuleMatcher(ABC):
"""Base class for matching muxing rules."""

def __init__(self, route: ModelRoute):
self._route = route

@abstractmethod
def match(self, thing_to_match) -> bool:
"""Return True if the rule matches the thing_to_match."""
pass

def destination(self) -> ModelRoute:
"""Return the destination of the rule."""

return self._route


class MuxingMatcherFactory:
"""Factory for creating muxing matchers."""

@staticmethod
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
"""Create a muxing matcher for the given endpoint and model."""

factory = {
"catch_all": CatchAllMuxingRuleMatcher,
}

try:
return factory[mux_rule.matcher_type](route)
except KeyError:
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")


class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""

def match(self, thing_to_match) -> bool:
return True


class MuxingRulesinWorkspaces(UserDict):
"""A thread safe dictionary to store the muxing rules in workspaces."""

def __init__(self):
super().__init__()
self._lock = RLock()
self._active_workspace = ""

def __getitem__(self, key: str) -> List[MuxingRuleMatcher]:
with self._lock:
# We return a copy so concurrent modifications don't affect the original
return copy.deepcopy(super().__getitem__(key))

def __setitem__(self, key: str, value: List[MuxingRuleMatcher]):
with self._lock:
super().__setitem__(key, value)

def __delitem__(self, key: str):
with self._lock:
super().__delitem__(key)

def set_active_workspace(self, workspace_id: str):
"""Set the active workspace."""
self._active_workspace = workspace_id

def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]:
"""Get the first match for the given thing_to_match."""

# We iterate over all the rules and return the first match
# Since we already do a deepcopy in __getitem__, we don't need to lock here
try:
for rule in self[self._active_workspace]:
if rule.match(thing_to_match):
return rule.destination()
return None
except KeyError:
raise RuntimeError("No rules found for the active workspace")
122 changes: 107 additions & 15 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import datetime
from typing import List, Optional, Tuple
from uuid import uuid4 as uuid
Expand All @@ -11,6 +10,7 @@
WorkspaceRow,
WorkspaceWithSessionInfo,
)
from codegate.muxing import rulematcher


class WorkspaceCrudError(Exception):
Expand All @@ -37,8 +37,12 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError):

class WorkspaceCrud:

def __init__(self):
def __init__(
self,
mux_registry: rulematcher.MuxingRulesinWorkspaces = rulematcher.get_muxing_rules_registry(),
):
self._db_reader = DbReader()
self._mux_registry = mux_registry

async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
"""
Expand Down Expand Up @@ -135,6 +139,9 @@ async def activate_workspace(self, workspace_name: str):
session.last_update = datetime.datetime.now(datetime.timezone.utc)
db_recorder = DbRecorder()
await db_recorder.update_session(session)

# Ensure the mux registry is updated
self._mux_registry.set_active_workspace(workspace.id)
return

async def recover_workspace(self, workspace_name: str):
Expand Down Expand Up @@ -189,6 +196,9 @@ async def soft_delete_workspace(self, workspace_name: str):
_ = await db_recorder.soft_delete_workspace(selected_workspace)
except Exception:
raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}")

# Remove the muxes from the registry
del self._mux_registry[workspace_name]
return

async def hard_delete_workspace(self, workspace_name: str):
Expand Down Expand Up @@ -243,6 +253,8 @@ async def get_muxes(self, workspace_name: str):

# Can't use type hints since the models are not yet defined
async def set_muxes(self, workspace_name: str, muxes):
from codegate.api import v1_models

# Verify if workspace exists
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
if not workspace:
Expand All @@ -252,23 +264,19 @@ async def set_muxes(self, workspace_name: str, muxes):
db_recorder = DbRecorder()
await db_recorder.delete_muxes_by_workspace(workspace.id)

tasks = set()

# Add the new muxes
priority = 0

muxes_with_routes: List[Tuple[v1_models.MuxRule, rulematcher.ModelRoute]] = []

# Verify all models are valid
for mux in muxes:
dbm = await self._db_reader.get_provider_model_by_provider_id_and_name(
mux.provider_id,
mux.model,
)
if not dbm:
raise WorkspaceCrudError(
f"Model {mux.model} does not exist for provider {mux.provider_id}"
)
route = await self.get_routing_for_mux(mux)
muxes_with_routes.append((mux, route))

for mux in muxes:
matchers: List[rulematcher.MuxingRuleMatcher] = []

for mux, route in muxes_with_routes:
new_mux = MuxRule(
id=str(uuid()),
provider_endpoint_id=mux.provider_id,
Expand All @@ -278,8 +286,92 @@ async def set_muxes(self, workspace_name: str, muxes):
matcher_blob=mux.matcher if mux.matcher else "",
priority=priority,
)
tasks.add(db_recorder.add_mux(new_mux))
dbmux = await db_recorder.add_mux(new_mux)

matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route))

priority += 1

await asyncio.gather(*tasks)
# Set routing list for the workspace
self._mux_registry[workspace_name] = matchers

async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
"""Get the routing for a mux
Note that this particular mux object is the API model, not the database model.
It's only not annotated because of a circular import issue.
"""
dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_id)
if not dbprov:
raise WorkspaceCrudError(f"Provider {mux.provider_id} does not exist")

dbm = await self._db_reader.get_provider_model_by_provider_id_and_name(
mux.provider_id,
mux.model,
)
if not dbm:
raise WorkspaceCrudError(
f"Model {mux.model} does not exist for provider {mux.provider_id}"
)
dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_id)
if not dbauth:
raise WorkspaceCrudError(f"Auth material for provider {mux.provider_id} does not exist")

return rulematcher.ModelRoute(
provider=dbprov,
model=dbm,
auth=dbauth,
)

async def get_routing_for_db_mux(self, mux: MuxRule) -> rulematcher.ModelRoute:
"""Get the routing for a mux
Note that this particular mux object is the database model, not the API model.
It's only not annotated because of a circular import issue.
"""
dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_endpoint_id)
if not dbprov:
raise WorkspaceCrudError(f"Provider {mux.provider_endpoint_id} does not exist")

dbm = await self._db_reader.get_provider_model_by_provider_id_and_name(
mux.provider_endpoint_id,
mux.provider_model_name,
)
if not dbm:
raise WorkspaceCrudError(
f"Model {mux.provider_model_name} does not "
"exist for provider {mux.provider_endpoint_id}"
)
dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_endpoint_id)
if not dbauth:
raise WorkspaceCrudError(
f"Auth material for provider {mux.provider_endpoint_id} does not exist"
)

return rulematcher.ModelRoute(
model=dbm,
endpoint=dbprov,
auth_material=dbauth,
)

async def initialize_mux_registry(self):
"""Initialize the mux registry with all workspaces in the database"""

active_ws = await self.get_active_workspace()
if active_ws:
self._mux_registry.set_active_workspace(active_ws.name)

# Get all workspaces
workspaces = await self.get_workspaces()

# For each workspace, get the muxes and set them in the registry
for ws in workspaces:
muxes = await self._db_reader.get_muxes_by_workspace(ws.id)

matchers: List[rulematcher.MuxingRuleMatcher] = []

for mux in muxes:
route = await self.get_routing_for_db_mux(mux)
matchers.append(rulematcher.MuxingMatcherFactory.create(mux, route))

self._mux_registry[ws.name] = matchers

0 comments on commit bc42187

Please sign in to comment.