Skip to content

Commit

Permalink
Neutralize samplers (#59)
Browse files Browse the repository at this point in the history
* Update sample_preset.yml

Neutralized the samplers.

* Sampling: Fix dynatemp defaults

Default max temp and min temp is 1.0

* Sampling: Fix TFS defaults

Default is 1.0

---------

Co-authored-by: AliCat <[email protected]>
Co-authored-by: kingbri <[email protected]>
  • Loading branch information
alicat22 and kingbri1 authored Feb 8, 2024
1 parent 321c9a1 commit bb48f77
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ def generate_gen(self, prompt: str, **kwargs):
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)

# DynaTemp settings
max_temp = unwrap(kwargs.get("max_temp"), 0.0)
min_temp = unwrap(kwargs.get("min_temp"), 0.0)
max_temp = unwrap(kwargs.get("max_temp"), 1.0)
min_temp = unwrap(kwargs.get("min_temp"), 1.0)

if max_temp > min_temp:
gen_settings.max_temp = max_temp
Expand All @@ -574,7 +574,7 @@ def generate_gen(self, prompt: str, **kwargs):
# Warn if max/min temp values are > 0
# and if they're less than or equal to each other
if max_temp < min_temp or (
0 not in {min_temp, max_temp} and max_temp == min_temp
1 not in {min_temp, max_temp} and max_temp == min_temp
):
logger.warning(
"Max temp is less than or equal to min temp, skipping DynaTemp."
Expand Down
6 changes: 3 additions & 3 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class BaseSamplerRequest(BaseModel):
)

tfs: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("tfs", 0.0)
default_factory=lambda: get_default_sampler_value("tfs", 1.0)
)

frequency_penalty: Optional[float] = Field(
Expand Down Expand Up @@ -142,13 +142,13 @@ class BaseSamplerRequest(BaseModel):
)

max_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("max_temp", 0.0),
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
)

min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 0.0),
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
)
Expand Down
8 changes: 4 additions & 4 deletions sampler_overrides/sample_preset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ temperature_last:
override: false
force: false
min_temp:
override: 0.0
override: 1.0
force: false
max_temp:
override: 0.0
override: 1.0
force: false
temp_exponent:
override: 0.0
override: 1.0
force: false
smoothing_factor:
override: 0.0
Expand All @@ -57,7 +57,7 @@ min_p:
override: 0.0
force: false
tfs:
override: 0.0
override: 1.0
force: false
typical:
override: 1.0
Expand Down

0 comments on commit bb48f77

Please sign in to comment.