Skip to content

Commit

Permalink
don't delete models on provider update (#836)
Browse files Browse the repository at this point in the history
This makes sure that the foreign key references stay intact in the
muxing table.

Signed-off-by: Juan Antonio Osorio <[email protected]>
  • Loading branch information
JAORMX authored Jan 30, 2025
1 parent 16d525f commit d24c989
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
7 changes: 4 additions & 3 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,15 @@ async def add_provider_model(self, model: ProviderModel) -> ProviderModel:
added_model = await self._execute_update_pydantic_model(model, sql, should_raise=True)
return added_model

async def delete_provider_models(self, provider_id: str):
async def delete_provider_model(self, provider_id: str, model: str) -> Optional[ProviderModel]:
sql = text(
"""
DELETE FROM provider_models
WHERE provider_endpoint_id = :provider_endpoint_id
WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name
"""
)
conditions = {"provider_endpoint_id": provider_id}

conditions = {"provider_endpoint_id": provider_id, "name": model}
await self._execute_with_no_return(sql, conditions)

async def delete_muxes_by_workspace(self, workspace_id: str):
Expand Down
32 changes: 23 additions & 9 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,26 +141,40 @@ async def update_endpoint(
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))

# Reset all provider models.
await self._db_writer.delete_provider_models(str(endpoint.id))
models_set = set(models)

for model in models:
# Get the models from the provider
models_in_db = await self._db_reader.get_provider_models_by_provider_id(str(endpoint.id))

models_in_db_set = set(model.name for model in models_in_db)

# Add the models that are in the provider but not in the DB
for model in models_set - models_in_db_set:
await self._db_writer.add_provider_model(
dbmodels.ProviderModel(
provider_endpoint_id=founddbe.id,
name=model,
)
)

# Remove the models that are in the DB but not in the provider
for model in models_in_db_set - models_set:
await self._db_writer.delete_provider_model(
founddbe.id,
model,
)

dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
# If an API key was provided or we've changed the auth type, we update the auth material
if endpoint.auth_type != founddbe.auth_type or endpoint.api_key:
await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
)
)
)

return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

Expand Down

0 comments on commit d24c989

Please sign in to comment.