From 30ab8e04b9fabdfb675148821736412d3ec4ab5a Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Mon, 10 Feb 2025 23:14:22 -0500 Subject: [PATCH] Args: Add subcommands to run actions Migrate OpenAPI and sample config export to subcommands "export-openapi" and "export-config". Also add a "download" subcommand that passes args to the TabbyAPI downloader. This allows models to be downloaded via the API and CLI args. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- .github/workflows/pages.yml | 4 +-- common/actions.py | 60 ++++++++++++++++++++++++++----------- common/args.py | 53 ++++++++++++++++++++++++++++++++ common/downloader.py | 4 +-- main.py | 18 +++++++---- start.py | 4 +-- 6 files changed, 112 insertions(+), 31 deletions(-) diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index c3a61f21..d0306dfb 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -48,8 +48,8 @@ jobs: npm install @redocly/cli -g - name: Export OpenAPI docs run: | - python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold - python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI + python main.py export-openapi --export-path "openapi-kobold.json" + python main.py export-openapi --export-path "openapi-oai.json" - name: Build and store Redocly site run: | mkdir static diff --git a/common/actions.py b/common/actions.py index c7f0a717..771f127b 100644 --- a/common/actions.py +++ b/common/actions.py @@ -1,27 +1,51 @@ +import argparse +import asyncio import json from loguru import logger -from common.tabby_config import config, generate_config_file +from common.downloader import hf_repo_download +from common.tabby_config import generate_config_file +from common.utils import unwrap from endpoints.server import export_openapi -def branch_to_actions() -> bool: - """Checks if a optional action needs to be run.""" +def download_action(args: argparse.Namespace): + asyncio.run( + hf_repo_download( + repo_id=args.repo_id, + folder_name=args.folder_name, + revision=args.revision, + token=args.token, + include=args.include, + exclude=args.exclude, + ) + ) - if config.actions.export_openapi: - openapi_json = export_openapi() - with open(config.actions.openapi_export_path, "w") as f: - f.write(json.dumps(openapi_json)) - logger.info( - "Successfully wrote OpenAPI spec to " - + f"{config.actions.openapi_export_path}" - ) - elif config.actions.export_config: - generate_config_file(filename=config.actions.config_export_path) - else: - # did not branch - return False +def config_export_action(args: argparse.Namespace): + export_path = unwrap(args.export_path, "config_sample.yml") + generate_config_file(filename=export_path) - # branched and ran an action - return True + +def openapi_export_action(args: argparse.Namespace): + export_path = unwrap(args.export_path, "openapi.json") + openapi_json = export_openapi() + + with open(export_path, "w") as f: + f.write(json.dumps(openapi_json)) + logger.info("Successfully wrote OpenAPI spec to " + f"{export_path}") + + +def run_subcommand(args: argparse.Namespace) -> bool: + match args.actions: + case "download": + download_action(args) + return True + case "export-config": + config_export_action(args) + return True + case "export-openapi": + openapi_export_action(args) + return True + case _: + return False diff --git a/common/args.py b/common/args.py index 39ba2936..56f9659e 100644 --- a/common/args.py +++ b/common/args.py @@ -37,6 +37,8 @@ def init_argparser( existing_parser, argparse.ArgumentParser(description="TabbyAPI server") ) + add_subcommands(parser) + # Loop through each top-level field in the config for field_name, field_info in TabbyConfigModel.model_fields.items(): field_type = unwrap_optional_type(field_info.annotation) @@ -59,6 +61,57 @@ def init_argparser( return parser +def add_subcommands(parser: argparse.ArgumentParser): + """Adds subcommands to an existing argparser""" + + actions_subparsers = parser.add_subparsers( + dest="actions", help="Extra actions that can be run instead of the main server." + ) + + # Calls download action + download_parser = actions_subparsers.add_parser( + "download", help="Calls the model downloader" + ) + download_parser.add_argument("repo_id", type=str, help="HuggingFace repo ID") + download_parser.add_argument( + "--folder-name", + type=str, + help="Folder name where the model should be downloaded", + ) + download_parser.add_argument( + "--revision", + type=str, + help="Branch name in HuggingFace repo", + ) + download_parser.add_argument( + "--token", type=str, help="HuggingFace access token for private repos" + ) + download_parser.add_argument( + "--include", type=str, help="Glob pattern of files to include" + ) + download_parser.add_argument( + "--exclude", type=str, help="Glob pattern of files to exclude" + ) + + # Calls openapi action + openapi_export_parser = actions_subparsers.add_parser( + "export-openapi", help="Exports an OpenAPI compliant JSON schema" + ) + openapi_export_parser.add_argument( + "--export-path", + help="Path to export the generated OpenAPI JSON (default: openapi.json)", + ) + + # Calls config export action + config_export_parser = actions_subparsers.add_parser( + "export-config", help="Generates and exports a sample config YAML file" + ) + config_export_parser.add_argument( + "--export-path", + help="Path to export the generated sample config (default: config_sample.yml)", + ) + + def convert_args_to_dict( args: argparse.Namespace, parser: argparse.ArgumentParser ) -> dict: diff --git a/common/downloader.py b/common/downloader.py index 6813e0d8..37e6434b 100644 --- a/common/downloader.py +++ b/common/downloader.py @@ -98,10 +98,10 @@ async def hf_repo_download( folder_name: Optional[str], revision: Optional[str], token: Optional[str], - chunk_limit: Optional[float], include: Optional[List[str]], exclude: Optional[List[str]], - timeout: Optional[int], + chunk_limit: Optional[float] = None, + timeout: Optional[int] = None, repo_type: Optional[str] = "model", ): """Gets a repo's information from HuggingFace and downloads it locally.""" diff --git a/main.py b/main.py index c7981d58..df4e472c 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ """The main tabbyAPI module. Contains the FastAPI server and endpoints.""" +import argparse import asyncio import os import pathlib @@ -11,7 +12,7 @@ from common import gen_logging, sampling, model from common.args import convert_args_to_dict, init_argparser from common.auth import load_auth_keys -from common.actions import branch_to_actions +from common.actions import run_subcommand from common.logger import setup_logger from common.networking import is_port_in_use from common.signals import signal_handler @@ -99,7 +100,10 @@ async def entrypoint_async(): await start_api(host, port) -def entrypoint(arguments: Optional[dict] = None): +def entrypoint( + args: Optional[argparse.Namespace] = None, + parser: Optional[argparse.ArgumentParser] = None, +): setup_logger() # Set up signal aborting @@ -115,15 +119,17 @@ def entrypoint(arguments: Optional[dict] = None): install() # Parse and override config from args - if arguments is None: + if args is None: parser = init_argparser() - arguments = convert_args_to_dict(parser.parse_args(), parser) + args = parser.parse_args() + + dict_args = convert_args_to_dict(args, parser) # load config - config.load(arguments) + config.load(dict_args) # branch to default paths if required - if branch_to_actions(): + if run_subcommand(args): return # Check exllamav2 version and give a descriptive error if it's too old diff --git a/start.py b/start.py index 731616c6..8e958298 100644 --- a/start.py +++ b/start.py @@ -275,8 +275,6 @@ def migrate_gpu_lib(): from common.args import convert_args_to_dict from main import entrypoint - converted_args = convert_args_to_dict(args, parser) - # Create a config if it doesn't exist # This is not necessary to run TabbyAPI, but is new user proof config_path = ( @@ -292,7 +290,7 @@ def migrate_gpu_lib(): ) print("Starting TabbyAPI...") - entrypoint(converted_args) + entrypoint(args, parser) except (ModuleNotFoundError, ImportError): traceback.print_exc() print(