From f01d03e24058c99688ec9697fa3a6e5891f2c52a Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio <ozz@stacklok.com> Date: Thu, 30 Jan 2025 16:20:10 +0200 Subject: [PATCH] Add logic to repopulate mux cache in case a model changes. Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com> --- src/codegate/providers/crud/crud.py | 7 +++++++ src/codegate/workspaces/crud.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 024825fb0..d09f2cf5c 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -12,6 +12,7 @@ from codegate.db.connection import DbReader, DbRecorder from codegate.providers.base import BaseProvider from codegate.providers.registry import ProviderRegistry, get_provider_registry +from codegate.workspaces import crud as workspace_crud logger = structlog.get_logger("codegate") @@ -32,6 +33,7 @@ class ProviderCrud: def __init__(self): self._db_reader = DbReader() self._db_writer = DbRecorder() + self._ws_crud = workspace_crud.WorkspaceCrud() async def list_endpoints(self) -> List[apimodelsv1.ProviderEndpoint]: """List all the endpoints.""" @@ -176,6 +178,9 @@ async def update_endpoint( ) ) + # a model might have been deleted, let's repopulate the cache + await self._ws_crud.repopulate_mux_cache() + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) async def configure_auth_material( @@ -208,6 +213,8 @@ async def delete_endpoint(self, provider_id: UUID): await self._db_writer.delete_provider_endpoint(dbendpoint) + await self._ws_crud.repopulate_mux_cache() + async def models_by_provider(self, provider_id: UUID) -> List[apimodelsv1.ModelByProvider]: """Get the models by provider.""" diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 4081350ab..70ac2b187 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -361,6 +361,11 @@ async def initialize_mux_registry(self): if active_ws: self._mux_registry.set_active_workspace(active_ws.name) + return self.repopulate_mux_cache() + + async def repopulate_mux_cache(self): + """Repopulate the mux cache with all muxes in the database""" + # Get all workspaces workspaces = await self.get_workspaces()