Skip to content

Commit

Permalink
Config: Isolate to a separate file
Browse files Browse the repository at this point in the history
Reduce the dependency on globals in main to simplify the code a bit.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Dec 24, 2023
1 parent 0d2e726 commit c9126c3
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 34 deletions.
46 changes: 46 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import yaml
import pathlib

from logger import init_logger
from utils import unwrap

logger = init_logger(__name__)

GLOBAL_CONFIG: dict = {}

def read_config_from_file(config_path: pathlib.Path):
    """Sets the global config by parsing the YAML file at config_path.

    On any failure (missing file, malformed YAML, ...) the error is logged
    and GLOBAL_CONFIG is reset to an empty dict so the server can still
    start without a config file.
    """
    global GLOBAL_CONFIG

    try:
        # open() accepts path-like objects directly; no str() conversion needed
        with open(config_path, "r", encoding="utf8") as config_file:
            # An empty YAML file parses to None; normalize that to {}
            GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
    except Exception as exc:
        logger.error(
            "The YAML config couldn't load because of the following error: "
            f"\n\n{exc}"
            "\n\nTabbyAPI will start anyway and not parse this config file."
        )

        # Best-effort: fall back to an empty config so downstream getters work
        GLOBAL_CONFIG = {}

def get_model_config():
    """Returns the model config from the global config"""
    model_section = GLOBAL_CONFIG.get("model")
    return unwrap(model_section, {})

def get_draft_model_config():
    """Returns the draft model config from the global config"""
    # Reuse get_model_config() instead of re-deriving the "model" section
    return unwrap(get_model_config().get("draft"), {})

def get_lora_config():
    """Returns the lora config from the global config"""
    # Reuse get_model_config() instead of re-deriving the "model" section
    return unwrap(get_model_config().get("lora"), {})

def get_network_config():
    """Returns the network config from the global config"""
    network_section = GLOBAL_CONFIG.get("network")
    return unwrap(network_section, {})

def get_gen_logging_config():
    """Returns the generation logging config from the global config"""
    logging_section = GLOBAL_CONFIG.get("logging")
    return unwrap(logging_section, {})
55 changes: 21 additions & 34 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import pathlib
import uvicorn
import yaml
from asyncio import CancelledError
from typing import Optional
from uuid import uuid4
Expand All @@ -14,6 +13,14 @@

import gen_logging
from auth import check_admin_key, check_api_key, load_auth_keys
from config import (
read_config_from_file,
get_gen_logging_config,
get_model_config,
get_draft_model_config,
get_lora_config,
get_network_config
)
from generators import call_with_semaphore, generate_with_semaphore
from model import ModelContainer
from OAI.types.completion import CompletionRequest
Expand Down Expand Up @@ -48,7 +55,6 @@

# Globally scoped variables. Undefined until initialized in main
MODEL_CONTAINER: Optional[ModelContainer] = None
config: dict = {}


def _check_model_container():
Expand All @@ -71,12 +77,11 @@ def _check_model_container():
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
"""Lists all models in the model directory."""
model_config = unwrap(config.get("model"), {})
model_config = get_model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)

draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = draft_config.get("draft_model_dir")
draft_model_dir = get_draft_model_config().get("draft_model_dir")

models = get_model_list(model_path.resolve(), draft_model_dir)
if unwrap(model_config.get("use_dummy_models"), False):
Expand Down Expand Up @@ -127,9 +132,7 @@ async def get_current_model():
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
"""Lists all draft models in the model directory."""
model_config = unwrap(config.get("model"), {})
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)

models = get_model_list(draft_model_path.resolve())
Expand All @@ -149,21 +152,19 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "model_name not found.")

model_config = unwrap(config.get("model"), {})
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
model_path = model_path / data.name

load_data = data.model_dump()

draft_config = unwrap(model_config.get("draft"), {})
if data.draft:
if not data.draft.draft_model_name:
raise HTTPException(
400, "draft_model_name was not found inside the draft object."
)

load_data["draft"]["draft_model_dir"] = unwrap(
draft_config.get("draft_model_dir"), "models"
get_draft_model_config().get("draft_model_dir"), "models"
)

if not model_path.exists():
Expand Down Expand Up @@ -240,10 +241,7 @@ async def unload_model():
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))

lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())

return loras
Expand Down Expand Up @@ -281,9 +279,7 @@ async def load_lora(data: LoraLoadRequest):
if not data.loras:
raise HTTPException(400, "List of loras to load is not found.")

model_config = unwrap(config.get("model"), {})
lora_config = unwrap(model_config.get("lora"), {})
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras"))
if not lora_dir.exists():
raise HTTPException(
400,
Expand Down Expand Up @@ -478,33 +474,24 @@ async def generator():


if __name__ == "__main__":
# Load from YAML config. Possibly add a config -> kwargs conversion function
try:
with open("config.yml", "r", encoding="utf8") as config_file:
config = unwrap(yaml.safe_load(config_file), {})
except Exception as exc:
logger.error(
"The YAML config couldn't load because of the following error: "
f"\n\n{exc}"
"\n\nTabbyAPI will start anyway and not parse this config file."
)
config = {}
# Load from YAML config
read_config_from_file(pathlib.Path("config.yml"))

network_config = unwrap(config.get("network"), {})
network_config = get_network_config()

# Initialize auth keys
load_auth_keys(unwrap(network_config.get("disable_auth"), False))

# Override the generation log options if given
log_config = unwrap(config.get("logging"), {})
log_config = get_gen_logging_config()
if log_config:
gen_logging.update_from_dict(log_config)

gen_logging.broadcast_status()

# If an initial model name is specified, create a container
# and load the model
model_config = unwrap(config.get("model"), {})
model_config = get_model_config()
if "model_name" in model_config:
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_config.get("model_name")
Expand All @@ -521,7 +508,7 @@ async def generator():
loading_bar.next()

# Load loras
lora_config = unwrap(model_config.get("lora"), {})
lora_config = get_lora_config()
if "loras" in lora_config:
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
Expand Down

0 comments on commit c9126c3

Please sign in to comment.