Skip to content

Commit

Permalink
Sampling: Add top-a support
Browse files Browse the repository at this point in the history
Currently in exllamav2 dev, but will be in the next release.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Dec 23, 2023
1 parent 6a5bbd2 commit 80ef379
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
2 changes: 2 additions & 0 deletions OAI/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class CommonCompletionRequest(BaseModel):
temperature_last: Optional[bool] = False
top_k: Optional[int] = 0
top_p: Optional[float] = 1.0
top_a: Optional[float] = 0.0
typical: Optional[float] = 1.0
min_p: Optional[float] = 0.0
tfs: Optional[float] = 1.0
Expand Down Expand Up @@ -111,6 +112,7 @@ def to_gen_params(self):
"temperature_last": self.temperature_last,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
Expand Down
9 changes: 9 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,20 @@ def generate_gen(self, prompt: str, **kwargs):
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("top_a"), False)) and not hasattr (
gen_settings, "top_a"
):
logger.warning(
"Top-A is not supported by the currently "
"installed ExLlamaV2 version."
)

# Apply settings
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
gen_settings.top_a = unwrap(kwargs.get("top_a"), 0.0)
gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
Expand Down

0 comments on commit 80ef379

Please sign in to comment.