Skip to content

Commit

Permalink
Args: Add subcommands to run actions
Browse files Browse the repository at this point in the history
Migrate OpenAPI and sample config export to subcommands "export-openapi"
and "export-config".

Also add a "download" subcommand that passes args to the TabbyAPI
downloader. This allows models to be downloaded via the API and
CLI args.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Feb 11, 2025
1 parent 30f02e5 commit 30ab8e0
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 31 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ jobs:
npm install @redocly/cli -g
- name: Export OpenAPI docs
run: |
python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
python main.py export-openapi --export-path "openapi-kobold.json"
python main.py export-openapi --export-path "openapi-oai.json"
- name: Build and store Redocly site
run: |
mkdir static
Expand Down
60 changes: 42 additions & 18 deletions common/actions.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,51 @@
import argparse
import asyncio
import json
from loguru import logger

from common.tabby_config import config, generate_config_file
from common.downloader import hf_repo_download
from common.tabby_config import generate_config_file
from common.utils import unwrap
from endpoints.server import export_openapi


def branch_to_actions() -> bool:
"""Checks if a optional action needs to be run."""
def download_action(args: argparse.Namespace):
asyncio.run(
hf_repo_download(
repo_id=args.repo_id,
folder_name=args.folder_name,
revision=args.revision,
token=args.token,
include=args.include,
exclude=args.exclude,
)
)

if config.actions.export_openapi:
openapi_json = export_openapi()

with open(config.actions.openapi_export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info(
"Successfully wrote OpenAPI spec to "
+ f"{config.actions.openapi_export_path}"
)
elif config.actions.export_config:
generate_config_file(filename=config.actions.config_export_path)
else:
# did not branch
return False
def config_export_action(args: argparse.Namespace):
export_path = unwrap(args.export_path, "config_sample.yml")
generate_config_file(filename=export_path)

# branched and ran an action
return True

def openapi_export_action(args: argparse.Namespace):
export_path = unwrap(args.export_path, "openapi.json")
openapi_json = export_openapi()

with open(export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info("Successfully wrote OpenAPI spec to " + f"{export_path}")


def run_subcommand(args: argparse.Namespace) -> bool:
match args.actions:
case "download":
download_action(args)
return True
case "export-config":
config_export_action(args)
return True
case "export-openapi":
openapi_export_action(args)
return True
case _:
return False
53 changes: 53 additions & 0 deletions common/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def init_argparser(
existing_parser, argparse.ArgumentParser(description="TabbyAPI server")
)

add_subcommands(parser)

# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional_type(field_info.annotation)
Expand All @@ -59,6 +61,57 @@ def init_argparser(
return parser


def add_subcommands(parser: argparse.ArgumentParser):
"""Adds subcommands to an existing argparser"""

actions_subparsers = parser.add_subparsers(
dest="actions", help="Extra actions that can be run instead of the main server."
)

# Calls download action
download_parser = actions_subparsers.add_parser(
"download", help="Calls the model downloader"
)
download_parser.add_argument("repo_id", type=str, help="HuggingFace repo ID")
download_parser.add_argument(
"--folder-name",
type=str,
help="Folder name where the model should be downloaded",
)
download_parser.add_argument(
"--revision",
type=str,
help="Branch name in HuggingFace repo",
)
download_parser.add_argument(
"--token", type=str, help="HuggingFace access token for private repos"
)
download_parser.add_argument(
"--include", type=str, help="Glob pattern of files to include"
)
download_parser.add_argument(
"--exclude", type=str, help="Glob pattern of files to exclude"
)

# Calls openapi action
openapi_export_parser = actions_subparsers.add_parser(
"export-openapi", help="Exports an OpenAPI compliant JSON schema"
)
openapi_export_parser.add_argument(
"--export-path",
help="Path to export the generated OpenAPI JSON (default: openapi.json)",
)

# Calls config export action
config_export_parser = actions_subparsers.add_parser(
"export-config", help="Generates and exports a sample config YAML file"
)
config_export_parser.add_argument(
"--export-path",
help="Path to export the generated sample config (default: config_sample.yml)",
)


def convert_args_to_dict(
args: argparse.Namespace, parser: argparse.ArgumentParser
) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions common/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ async def hf_repo_download(
folder_name: Optional[str],
revision: Optional[str],
token: Optional[str],
chunk_limit: Optional[float],
include: Optional[List[str]],
exclude: Optional[List[str]],
timeout: Optional[int],
chunk_limit: Optional[float] = None,
timeout: Optional[int] = None,
repo_type: Optional[str] = "model",
):
"""Gets a repo's information from HuggingFace and downloads it locally."""
Expand Down
18 changes: 12 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""

import argparse
import asyncio
import os
import pathlib
Expand All @@ -11,7 +12,7 @@
from common import gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys
from common.actions import branch_to_actions
from common.actions import run_subcommand
from common.logger import setup_logger
from common.networking import is_port_in_use
from common.signals import signal_handler
Expand Down Expand Up @@ -99,7 +100,10 @@ async def entrypoint_async():
await start_api(host, port)


def entrypoint(arguments: Optional[dict] = None):
def entrypoint(
args: Optional[argparse.Namespace] = None,
parser: Optional[argparse.ArgumentParser] = None,
):
setup_logger()

# Set up signal aborting
Expand All @@ -115,15 +119,17 @@ def entrypoint(arguments: Optional[dict] = None):
install()

# Parse and override config from args
if arguments is None:
if args is None:
parser = init_argparser()
arguments = convert_args_to_dict(parser.parse_args(), parser)
args = parser.parse_args()

dict_args = convert_args_to_dict(args, parser)

# load config
config.load(arguments)
config.load(dict_args)

# branch to default paths if required
if branch_to_actions():
if run_subcommand(args):
return

# Check exllamav2 version and give a descriptive error if it's too old
Expand Down
4 changes: 1 addition & 3 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,6 @@ def migrate_gpu_lib():
from common.args import convert_args_to_dict
from main import entrypoint

converted_args = convert_args_to_dict(args, parser)

# Create a config if it doesn't exist
# This is not necessary to run TabbyAPI, but is new user proof
config_path = (
Expand All @@ -292,7 +290,7 @@ def migrate_gpu_lib():
)

print("Starting TabbyAPI...")
entrypoint(converted_args)
entrypoint(args, parser)
except (ModuleNotFoundError, ImportError):
traceback.print_exc()
print(
Expand Down

0 comments on commit 30ab8e0

Please sign in to comment.