From f01d03e24058c99688ec9697fa3a6e5891f2c52a Mon Sep 17 00:00:00 2001
From: Juan Antonio Osorio <ozz@stacklok.com>
Date: Thu, 30 Jan 2025 16:20:10 +0200
Subject: [PATCH] Add logic to repopulate mux cache in case a model changes.

Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>
---
 src/codegate/providers/crud/crud.py | 7 +++++++
 src/codegate/workspaces/crud.py     | 5 +++++
 2 files changed, 12 insertions(+)

diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py
index 024825fb0..d09f2cf5c 100644
--- a/src/codegate/providers/crud/crud.py
+++ b/src/codegate/providers/crud/crud.py
@@ -12,6 +12,7 @@
 from codegate.db.connection import DbReader, DbRecorder
 from codegate.providers.base import BaseProvider
 from codegate.providers.registry import ProviderRegistry, get_provider_registry
+from codegate.workspaces import crud as workspace_crud
 
 logger = structlog.get_logger("codegate")
 
@@ -32,6 +33,7 @@ class ProviderCrud:
     def __init__(self):
         self._db_reader = DbReader()
         self._db_writer = DbRecorder()
+        self._ws_crud = workspace_crud.WorkspaceCrud()
 
     async def list_endpoints(self) -> List[apimodelsv1.ProviderEndpoint]:
         """List all the endpoints."""
@@ -176,6 +178,9 @@ async def update_endpoint(
                 )
             )
 
+        # a model might have been deleted, let's repopulate the cache
+        await self._ws_crud.repopulate_mux_cache()
+
         return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
 
     async def configure_auth_material(
@@ -208,6 +213,8 @@ async def delete_endpoint(self, provider_id: UUID):
 
         await self._db_writer.delete_provider_endpoint(dbendpoint)
 
+        await self._ws_crud.repopulate_mux_cache()
+
     async def models_by_provider(self, provider_id: UUID) -> List[apimodelsv1.ModelByProvider]:
         """Get the models by provider."""
 
diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py
index 4081350ab..70ac2b187 100644
--- a/src/codegate/workspaces/crud.py
+++ b/src/codegate/workspaces/crud.py
@@ -361,6 +361,11 @@ async def initialize_mux_registry(self):
         if active_ws:
             self._mux_registry.set_active_workspace(active_ws.name)
 
+        return self.repopulate_mux_cache()
+
+    async def repopulate_mux_cache(self):
+        """Repopulate the mux cache with all muxes in the database"""
+
         # Get all workspaces
         workspaces = await self.get_workspaces()