
Commit 703a114

Tree: Format
Signed-off-by: kingbri <[email protected]>
kingbri1 committed Dec 24, 2023
1 parent c9126c3 commit 703a114
Showing 3 changed files with 8 additions and 4 deletions.
6 changes: 6 additions & 0 deletions config.py
@@ -8,6 +8,7 @@
 
 GLOBAL_CONFIG: dict = {}
 
+
 def read_config_from_file(config_path: pathlib.Path):
     """Sets the global config from a given file path"""
     global GLOBAL_CONFIG
@@ -23,24 +24,29 @@ def read_config_from_file(config_path: pathlib.Path):
         )
         GLOBAL_CONFIG = {}
 
+
 def get_model_config():
     """Returns the model config from the global config"""
     return unwrap(GLOBAL_CONFIG.get("model"), {})
 
+
 def get_draft_model_config():
     """Returns the draft model config from the global config"""
     model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
     return unwrap(model_config.get("draft"), {})
 
+
 def get_lora_config():
     """Returns the lora config from the global config"""
     model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
     return unwrap(model_config.get("lora"), {})
 
+
 def get_network_config():
     """Returns the network config from the global config"""
     return unwrap(GLOBAL_CONFIG.get("network"), {})
 
+
 def get_gen_logging_config():
     """Returns the generation logging config from the global config"""
     return unwrap(GLOBAL_CONFIG.get("logging"), {})
2 changes: 1 addition & 1 deletion main.py
@@ -19,7 +19,7 @@
     get_model_config,
     get_draft_model_config,
     get_lora_config,
-    get_network_config
+    get_network_config,
 )
 from generators import call_with_semaphore, generate_with_semaphore
 from model import ModelContainer
4 changes: 1 addition & 3 deletions model.py
@@ -523,9 +523,7 @@ def generate_gen(self, prompt: str, **kwargs):
                 "installed ExLlamaV2 version."
             )
 
-        if (unwrap(kwargs.get("top_a"), False)) and not hasattr (
-            gen_settings, "top_a"
-        ):
+        if (unwrap(kwargs.get("top_a"), False)) and not hasattr(gen_settings, "top_a"):
             logger.warning(
                 "Top-A is not supported by the currently "
                 "installed ExLlamaV2 version."
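
The single-line condition above is a feature probe: it warns only when the caller requested Top-A but the installed ExLlamaV2 sampler settings object has no top_a attribute. A self-contained sketch of that hasattr pattern; GenSettings and the values below are stand-ins, not the real ExLlamaV2 API:

class GenSettings:
    # Stand-in for ExLlamaV2's sampler settings; older releases lack `top_a`.
    top_k = 50


gen_settings = GenSettings()
requested_top_a = 0.1  # value a client might pass in kwargs

if requested_top_a and not hasattr(gen_settings, "top_a"):
    print(
        "Top-A is not supported by the currently "
        "installed ExLlamaV2 version."
    )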
