diff --git a/src/codegate/cli.py b/src/codegate/cli.py index ba3016eb..be5096f6 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -21,6 +21,7 @@ from codegate.providers.copilot.provider import CopilotProvider from codegate.server import init_app from codegate.storage.utils import restore_storage_backup +from codegate.workspaces import crud as wscrud class UvicornServer: @@ -341,6 +342,8 @@ def serve( # noqa: C901 registry = app.provider_registry loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry)) + wsc = wscrud.WorkspaceCrud() + loop.run_until_complete(wsc.initialize_mux_registry()) # Run the server try: diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index cdeef763..6fb8bc69 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -710,6 +710,22 @@ async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[Provid ) return provider[0] if provider else None + async def get_auth_material_by_provider_id( + self, provider_id: str + ) -> Optional[ProviderAuthMaterial]: + sql = text( + """ + SELECT id as provider_endpoint_id, auth_type, auth_blob + FROM provider_endpoints + WHERE id = :provider_endpoint_id + """ + ) + conditions = {"provider_endpoint_id": provider_id} + auth_material = await self._exec_select_conditions_to_pydantic( + ProviderAuthMaterial, sql, conditions, should_raise=True + ) + return auth_material[0] if auth_material else None + async def get_provider_endpoints(self) -> List[ProviderEndpoint]: sql = text( """ diff --git a/src/codegate/muxing/__init__.py b/src/codegate/muxing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py new file mode 100644 index 00000000..3d220a74 --- /dev/null +++ b/src/codegate/muxing/rulematcher.py @@ -0,0 +1,118 @@ +import copy +from abc import ABC, abstractmethod +from collections import UserDict +from threading import Lock, RLock +from typing import List, Optional + +from codegate.db import models as db_models + +_muxrules_sgtn = None + +_singleton_lock = Lock() + + +def get_muxing_rules_registry(): + """Returns a singleton instance of the muxing rules registry.""" + + global _muxrules_sgtn + + if _muxrules_sgtn is None: + with _singleton_lock: + if _muxrules_sgtn is None: + _muxrules_sgtn = MuxingRulesinWorkspaces() + + return _muxrules_sgtn + + +class ModelRoute: + """A route for a model.""" + + def __init__( + self, + model: db_models.ProviderModel, + endpoint: db_models.ProviderEndpoint, + auth_material: db_models.ProviderAuthMaterial, + ): + self.model = model + self.endpoint = endpoint + self.auth_material = auth_material + + +class MuxingRuleMatcher(ABC): + """Base class for matching muxing rules.""" + + def __init__(self, route: ModelRoute): + self._route = route + + @abstractmethod + def match(self, thing_to_match) -> bool: + """Return True if the rule matches the thing_to_match.""" + pass + + def destination(self) -> ModelRoute: + """Return the destination of the rule.""" + + return self._route + + +class MuxingMatcherFactory: + """Factory for creating muxing matchers.""" + + @staticmethod + def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: + """Create a muxing matcher for the given endpoint and model.""" + + factory = { + "catch_all": CatchAllMuxingRuleMatcher, + } + + try: + return factory[mux_rule.matcher_type](route) + except KeyError: + raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}") + + +class CatchAllMuxingRuleMatcher(MuxingRuleMatcher): + """A catch all muxing rule matcher.""" + + def match(self, thing_to_match) -> bool: + return True + + +class MuxingRulesinWorkspaces(UserDict): + """A thread safe dictionary to store the muxing rules in workspaces.""" + + def __init__(self): + super().__init__() + self._lock = RLock() + self._active_workspace = "" + + def __getitem__(self, key: str) -> List[MuxingRuleMatcher]: + with self._lock: + # We return a copy so concurrent modifications don't affect the original + return copy.deepcopy(super().__getitem__(key)) + + def __setitem__(self, key: str, value: List[MuxingRuleMatcher]): + with self._lock: + super().__setitem__(key, value) + + def __delitem__(self, key: str): + with self._lock: + super().__delitem__(key) + + def set_active_workspace(self, workspace_id: str): + """Set the active workspace.""" + self._active_workspace = workspace_id + + def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]: + """Get the first match for the given thing_to_match.""" + + # We iterate over all the rules and return the first match + # Since we already do a deepcopy in __getitem__, we don't need to lock here + try: + for rule in self[self._active_workspace]: + if rule.match(thing_to_match): + return rule.destination() + return None + except KeyError: + raise RuntimeError("No rules found for the active workspace") diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 39c23e86..3eb202d2 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -1,4 +1,3 @@ -import asyncio import datetime from typing import List, Optional, Tuple from uuid import uuid4 as uuid @@ -11,6 +10,7 @@ WorkspaceRow, WorkspaceWithSessionInfo, ) +from codegate.muxing import rulematcher class WorkspaceCrudError(Exception): @@ -37,8 +37,12 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError): class WorkspaceCrud: - def __init__(self): + def __init__( + self, + mux_registry: rulematcher.MuxingRulesinWorkspaces = rulematcher.get_muxing_rules_registry(), + ): self._db_reader = DbReader() + self._mux_registry = mux_registry async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow: """ @@ -135,6 +139,9 @@ async def activate_workspace(self, workspace_name: str): session.last_update = datetime.datetime.now(datetime.timezone.utc) db_recorder = DbRecorder() await db_recorder.update_session(session) + + # Ensure the mux registry is updated + self._mux_registry.set_active_workspace(workspace.id) return async def recover_workspace(self, workspace_name: str): @@ -189,6 +196,9 @@ async def soft_delete_workspace(self, workspace_name: str): _ = await db_recorder.soft_delete_workspace(selected_workspace) except Exception: raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") + + # Remove the muxes from the registry + del self._mux_registry[workspace_name] return async def hard_delete_workspace(self, workspace_name: str): @@ -243,6 +253,8 @@ async def get_muxes(self, workspace_name: str): # Can't use type hints since the models are not yet defined async def set_muxes(self, workspace_name: str, muxes): + from codegate.api import v1_models + # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not workspace: @@ -252,23 +264,19 @@ async def set_muxes(self, workspace_name: str, muxes): db_recorder = DbRecorder() await db_recorder.delete_muxes_by_workspace(workspace.id) - tasks = set() - # Add the new muxes priority = 0 + muxes_with_routes: List[Tuple[v1_models.MuxRule, rulematcher.ModelRoute]] = [] + # Verify all models are valid for mux in muxes: - dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( - mux.provider_id, - mux.model, - ) - if not dbm: - raise WorkspaceCrudError( - f"Model {mux.model} does not exist for provider {mux.provider_id}" - ) + route = await self.get_routing_for_mux(mux) + muxes_with_routes.append((mux, route)) - for mux in muxes: + matchers: List[rulematcher.MuxingRuleMatcher] = [] + + for mux, route in muxes_with_routes: new_mux = MuxRule( id=str(uuid()), provider_endpoint_id=mux.provider_id, @@ -278,8 +286,92 @@ async def set_muxes(self, workspace_name: str, muxes): matcher_blob=mux.matcher if mux.matcher else "", priority=priority, ) - tasks.add(db_recorder.add_mux(new_mux)) + dbmux = await db_recorder.add_mux(new_mux) + + matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route)) priority += 1 - await asyncio.gather(*tasks) + # Set routing list for the workspace + self._mux_registry[workspace_name] = matchers + + async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute: + """Get the routing for a mux + + Note that this particular mux object is the API model, not the database model. + It's only not annotated because of a circular import issue. + """ + dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_id) + if not dbprov: + raise WorkspaceCrudError(f"Provider {mux.provider_id} does not exist") + + dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( + mux.provider_id, + mux.model, + ) + if not dbm: + raise WorkspaceCrudError( + f"Model {mux.model} does not exist for provider {mux.provider_id}" + ) + dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_id) + if not dbauth: + raise WorkspaceCrudError(f"Auth material for provider {mux.provider_id} does not exist") + + return rulematcher.ModelRoute( + provider=dbprov, + model=dbm, + auth=dbauth, + ) + + async def get_routing_for_db_mux(self, mux: MuxRule) -> rulematcher.ModelRoute: + """Get the routing for a mux + + Note that this particular mux object is the database model, not the API model. + It's only not annotated because of a circular import issue. + """ + dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_endpoint_id) + if not dbprov: + raise WorkspaceCrudError(f"Provider {mux.provider_endpoint_id} does not exist") + + dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( + mux.provider_endpoint_id, + mux.provider_model_name, + ) + if not dbm: + raise WorkspaceCrudError( + f"Model {mux.provider_model_name} does not " + "exist for provider {mux.provider_endpoint_id}" + ) + dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_endpoint_id) + if not dbauth: + raise WorkspaceCrudError( + f"Auth material for provider {mux.provider_endpoint_id} does not exist" + ) + + return rulematcher.ModelRoute( + model=dbm, + endpoint=dbprov, + auth_material=dbauth, + ) + + async def initialize_mux_registry(self): + """Initialize the mux registry with all workspaces in the database""" + + active_ws = await self.get_active_workspace() + if active_ws: + self._mux_registry.set_active_workspace(active_ws.name) + + # Get all workspaces + workspaces = await self.get_workspaces() + + # For each workspace, get the muxes and set them in the registry + for ws in workspaces: + muxes = await self._db_reader.get_muxes_by_workspace(ws.id) + + matchers: List[rulematcher.MuxingRuleMatcher] = [] + + for mux in muxes: + route = await self.get_routing_for_db_mux(mux) + matchers.append(rulematcher.MuxingMatcherFactory.create(mux, route)) + + self._mux_registry[ws.name] = matchers