From c9126c3145f09abd41e884e54058fae545cb0dde Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 23 Dec 2023 23:02:37 -0500 Subject: [PATCH] Config: Isolate to a separate file Reduce dependency of globals in main to simplify code a bit. Signed-off-by: kingbri --- config.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ main.py | 55 +++++++++++++++++++++---------------------------------- 2 files changed, 67 insertions(+), 34 deletions(-) create mode 100644 config.py diff --git a/config.py b/config.py new file mode 100644 index 00000000..b6694b29 --- /dev/null +++ b/config.py @@ -0,0 +1,46 @@ +import yaml +import pathlib + +from logger import init_logger +from utils import unwrap + +logger = init_logger(__name__) + +GLOBAL_CONFIG: dict = {} + +def read_config_from_file(config_path: pathlib.Path): + """Sets the global config from a given file path""" + global GLOBAL_CONFIG + + try: + with open(str(config_path), "r", encoding="utf8") as config_file: + GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {}) + except Exception as exc: + logger.error( + "The YAML config couldn't load because of the following error: " + f"\n\n{exc}" + "\n\nTabbyAPI will start anyway and not parse this config file." + ) + GLOBAL_CONFIG = {} + +def get_model_config(): + """Returns the model config from the global config""" + return unwrap(GLOBAL_CONFIG.get("model"), {}) + +def get_draft_model_config(): + """Returns the draft model config from the global config""" + model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) + return unwrap(model_config.get("draft"), {}) + +def get_lora_config(): + """Returns the lora config from the global config""" + model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) + return unwrap(model_config.get("lora"), {}) + +def get_network_config(): + """Returns the network config from the global config""" + return unwrap(GLOBAL_CONFIG.get("network"), {}) + +def get_gen_logging_config(): + """Returns the generation logging config from the global config""" + return unwrap(GLOBAL_CONFIG.get("logging"), {}) diff --git a/main.py b/main.py index a12bb2be..46baa471 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,6 @@ """The main tabbyAPI module. Contains the FastAPI server and endpoints.""" import pathlib import uvicorn -import yaml from asyncio import CancelledError from typing import Optional from uuid import uuid4 @@ -14,6 +13,14 @@ import gen_logging from auth import check_admin_key, check_api_key, load_auth_keys +from config import ( + read_config_from_file, + get_gen_logging_config, + get_model_config, + get_draft_model_config, + get_lora_config, + get_network_config +) from generators import call_with_semaphore, generate_with_semaphore from model import ModelContainer from OAI.types.completion import CompletionRequest @@ -48,7 +55,6 @@ # Globally scoped variables. 
Undefined until initalized in main MODEL_CONTAINER: Optional[ModelContainer] = None -config: dict = {} def _check_model_container(): @@ -71,12 +77,11 @@ def _check_model_container(): @app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(): """Lists all models in the model directory.""" - model_config = unwrap(config.get("model"), {}) + model_config = get_model_config() model_dir = unwrap(model_config.get("model_dir"), "models") model_path = pathlib.Path(model_dir) - draft_config = unwrap(model_config.get("draft"), {}) - draft_model_dir = draft_config.get("draft_model_dir") + draft_model_dir = get_draft_model_config().get("draft_model_dir") models = get_model_list(model_path.resolve(), draft_model_dir) if unwrap(model_config.get("use_dummy_models"), False): @@ -127,9 +132,7 @@ async def get_current_model(): @app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) async def list_draft_models(): """Lists all draft models in the model directory.""" - model_config = unwrap(config.get("model"), {}) - draft_config = unwrap(model_config.get("draft"), {}) - draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models") + draft_model_dir = unwrap(get_draft_model_config().get("draft_model_dir"), "models") draft_model_path = pathlib.Path(draft_model_dir) models = get_model_list(draft_model_path.resolve()) @@ -149,13 +152,11 @@ async def load_model(request: Request, data: ModelLoadRequest): if not data.name: raise HTTPException(400, "model_name not found.") - model_config = unwrap(config.get("model"), {}) - model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) + model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models")) model_path = model_path / data.name load_data = data.model_dump() - draft_config = unwrap(model_config.get("draft"), {}) if data.draft: if not data.draft.draft_model_name: raise HTTPException( @@ -163,7 +164,7 @@ async def load_model(request: Request, data: ModelLoadRequest): ) load_data["draft"]["draft_model_dir"] = unwrap( - draft_config.get("draft_model_dir"), "models" + get_draft_model_config().get("draft_model_dir"), "models" ) if not model_path.exists(): @@ -240,10 +241,7 @@ async def unload_model(): @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) async def get_all_loras(): """Lists all LoRAs in the lora directory.""" - model_config = unwrap(config.get("model"), {}) - lora_config = unwrap(model_config.get("lora"), {}) - lora_path = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) - + lora_path = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras")) loras = get_lora_list(lora_path.resolve()) return loras @@ -281,9 +279,7 @@ async def load_lora(data: LoraLoadRequest): if not data.loras: raise HTTPException(400, "List of loras to load is not found.") - model_config = unwrap(config.get("model"), {}) - lora_config = unwrap(model_config.get("lora"), {}) - lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) + lora_dir = pathlib.Path(unwrap(get_lora_config().get("lora_dir"), "loras")) if not lora_dir.exists(): raise HTTPException( 400, @@ -478,25 +474,16 @@ async def generator(): if __name__ == "__main__": - # Load from YAML config. 
Possibly add a config -> kwargs conversion function - try: - with open("config.yml", "r", encoding="utf8") as config_file: - config = unwrap(yaml.safe_load(config_file), {}) - except Exception as exc: - logger.error( - "The YAML config couldn't load because of the following error: " - f"\n\n{exc}" - "\n\nTabbyAPI will start anyway and not parse this config file." - ) - config = {} + # Load from YAML config + read_config_from_file(pathlib.Path("config.yml")) - network_config = unwrap(config.get("network"), {}) + network_config = get_network_config() # Initialize auth keys load_auth_keys(unwrap(network_config.get("disable_auth"), False)) # Override the generation log options if given - log_config = unwrap(config.get("logging"), {}) + log_config = get_gen_logging_config() if log_config: gen_logging.update_from_dict(log_config) @@ -504,7 +491,7 @@ async def generator(): # If an initial model name is specified, create a container # and load the model - model_config = unwrap(config.get("model"), {}) + model_config = get_model_config() if "model_name" in model_config: model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) model_path = model_path / model_config.get("model_name") @@ -521,7 +508,7 @@ async def generator(): loading_bar.next() # Load loras - lora_config = unwrap(model_config.get("lora"), {}) + lora_config = get_lora_config() if "loras" in lora_config: lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras")) MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
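
Reviewer note (sketch only, not part of this patch): the comment removed from main.py mentioned possibly adding a config -> kwargs conversion function. If that follow-up lands, a minimal version could live in the new config.py. The helper name, the allowed_keys parameter, and the key names in the usage comment below are assumptions for illustration, not existing TabbyAPI API:

def config_to_kwargs(config: dict, allowed_keys: set) -> dict:
    """Convert a config section into a kwargs dict.

    Keeps only keys listed in allowed_keys and drops None values so the
    result can be splatted safely into a constructor or function call.
    """
    return {
        key: value
        for key, value in config.items()
        if key in allowed_keys and value is not None
    }

# Illustrative usage (key names and the ModelContainer signature are assumed
# for the example, not taken from this patch):
# kwargs = config_to_kwargs(get_model_config(), {"max_seq_len", "gpu_split"})
# MODEL_CONTAINER = ModelContainer(model_path.resolve(), **kwargs)

This would keep the unwrap/get pattern introduced here intact while letting callers splat a whole config section instead of reading keys one by one.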