From bb48f77ca1a871770cdd5319a1530f63274f67cc Mon Sep 17 00:00:00 2001 From: AliCat <86847834+alicat22@users.noreply.github.com> Date: Wed, 7 Feb 2024 22:23:09 -0700 Subject: [PATCH] Neutralize samplers (#59) * Update sample_preset.yml Neutralized the samplers. * Sampling: Fix dynatemp defaults Default max temp and min temp is 1.0 * Sampling: Fix TFS defaults Default is 1.0 --------- Co-authored-by: AliCat <86847834+alicat22@users.noreply.github.com> Co-authored-by: kingbri --- backends/exllamav2/model.py | 6 +++--- common/sampling.py | 6 +++--- sampler_overrides/sample_preset.yml | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 4b535ed7..590b47c4 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -558,8 +558,8 @@ def generate_gen(self, prompt: str, **kwargs): gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False) # DynaTemp settings - max_temp = unwrap(kwargs.get("max_temp"), 0.0) - min_temp = unwrap(kwargs.get("min_temp"), 0.0) + max_temp = unwrap(kwargs.get("max_temp"), 1.0) + min_temp = unwrap(kwargs.get("min_temp"), 1.0) if max_temp > min_temp: gen_settings.max_temp = max_temp @@ -574,7 +574,7 @@ def generate_gen(self, prompt: str, **kwargs): # Warn if max/min temp values are > 0 # and if they're less than or equal to each other if max_temp < min_temp or ( - 0 not in {min_temp, max_temp} and max_temp == min_temp + 1 not in {min_temp, max_temp} and max_temp == min_temp ): logger.warning( "Max temp is less than or equal to min temp, skipping DynaTemp." diff --git a/common/sampling.py b/common/sampling.py index a9ea9d32..5824dedd 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -64,7 +64,7 @@ class BaseSamplerRequest(BaseModel): ) tfs: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("tfs", 0.0) + default_factory=lambda: get_default_sampler_value("tfs", 1.0) ) frequency_penalty: Optional[float] = Field( @@ -142,13 +142,13 @@ class BaseSamplerRequest(BaseModel): ) max_temp: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("max_temp", 0.0), + default_factory=lambda: get_default_sampler_value("max_temp", 1.0), validation_alias=AliasChoices("max_temp", "dynatemp_high"), description="Aliases: dynatemp_high", ) min_temp: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("min_temp", 0.0), + default_factory=lambda: get_default_sampler_value("min_temp", 1.0), validation_alias=AliasChoices("min_temp", "dynatemp_low"), description="Aliases: dynatemp_low", ) diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index 3d1c42a4..da28dbd0 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -31,13 +31,13 @@ temperature_last: override: false force: false min_temp: - override: 0.0 + override: 1.0 force: false max_temp: - override: 0.0 + override: 1.0 force: false temp_exponent: - override: 0.0 + override: 1.0 force: false smoothing_factor: override: 0.0 @@ -57,7 +57,7 @@ min_p: override: 0.0 force: false tfs: - override: 0.0 + override: 1.0 force: false typical: override: 1.0