Workaround openai issue with temperature: 0 being omitted from request #342

Merged
merged 1 commit on May 10, 2024
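For context on the change below: go-openai's ChatCompletionRequest tags its Temperature field with `omitempty`, so an explicit temperature of 0 is dropped entirely when the request is marshaled, and the OpenAI API then falls back to its default temperature. A minimal sketch of that behaviour (illustrative only, not code from this PR; the model constant is arbitrary):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

func main() {
	// The caller explicitly asks for deterministic sampling.
	req := openai.ChatCompletionRequest{
		Model:       openai.GPT3Dot5Turbo,
		Temperature: 0,
	}
	b, _ := json.Marshal(req)
	// The marshaled JSON has no "temperature" key at all, so the API applies
	// its own default instead of the requested 0.
	fmt.Println(string(b))
}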
27 changes: 27 additions & 0 deletions packages/grafana-llm-app/pkg/plugin/llm_provider.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"math/rand"
"strings"

@@ -67,6 +68,32 @@ type ChatCompletionRequest struct {
Model Model `json:"model"`
}

// UnmarshalJSON implements json.Unmarshaler.
// We have a custom implementation here to check whether temperature is being
// explicitly set to `0` in the incoming request, because the `openai.ChatCompletionRequest`
// struct has `omitempty` on the Temperature field and would omit it when marshaling.
// If there is an explicit 0 value in the request, we set it to `math.SmallestNonzeroFloat32`,
// a workaround mentioned in https://github.com/sashabaranov/go-openai/issues/9#issuecomment-894845206.
func (c *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
// Create a wrapper type alias to avoid recursion, otherwise the
// subsequent call to UnmarshalJSON would call this method forever.
type Alias ChatCompletionRequest
var a Alias
if err := json.Unmarshal(data, &a); err != nil {
return err
}
// Also unmarshal to a map to check if temperature is being set explicitly in the request.
r := map[string]any{}
if err := json.Unmarshal(data, &r); err != nil {
return err
}
if t, ok := r["temperature"].(float64); ok && t == 0 {
a.ChatCompletionRequest.Temperature = math.SmallestNonzeroFloat32
}
*c = ChatCompletionRequest(a)
return nil
}

type ChatCompletionStreamResponse struct {
openai.ChatCompletionStreamResponse
// Random padding used to mitigate side channel attacks.
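To make the intent of the new UnmarshalJSON concrete, here is a rough usage sketch (assumed to live alongside this code in the plugin package; not part of the PR): an incoming "temperature": 0 is preserved as math.SmallestNonzeroFloat32, so the field survives when the embedded go-openai request is marshaled again.

package plugin

import (
	"encoding/json"
	"fmt"
)

// exampleTemperatureRoundTrip is an illustrative sketch, not part of this PR.
func exampleTemperatureRoundTrip() error {
	incoming := []byte(`{"model":"base","temperature":0}`)

	var req ChatCompletionRequest
	if err := json.Unmarshal(incoming, &req); err != nil {
		return err
	}
	// The custom UnmarshalJSON above rewrote the explicit 0 to the sentinel,
	// so req.Temperature is now math.SmallestNonzeroFloat32.

	out, err := json.Marshal(req.ChatCompletionRequest)
	if err != nil {
		return err
	}
	// The re-marshaled request still carries a "temperature" key whose value is
	// so small that the API treats it as 0, rather than omitting the field.
	fmt.Println(string(out))
	return nil
}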
46 changes: 45 additions & 1 deletion packages/grafana-llm-app/pkg/plugin/llm_provider_test.go
@@ -2,9 +2,11 @@ package plugin

import (
"encoding/json"
"math"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

func TestModelFromString(t *testing.T) {
@@ -83,7 +85,7 @@
}
}

func TestUnmarshalJSON(t *testing.T) {
func TestModelUnmarshalJSON(t *testing.T) {
tests := []struct {
input []byte
expected Model
@@ -164,6 +166,48 @@
}
}

func TestChatCompletionRequestUnmarshalJSON(t *testing.T) {
for _, tt := range []struct {
input []byte
expected ChatCompletionRequest
}{
{
input: []byte(`{"model":"base"}`),
expected: ChatCompletionRequest{
Model: ModelBase,
ChatCompletionRequest: openai.ChatCompletionRequest{
Temperature: 0,
},
},
},
{
input: []byte(`{"model":"base", "temperature":0.5}`),
expected: ChatCompletionRequest{
Model: ModelBase,
ChatCompletionRequest: openai.ChatCompletionRequest{
Temperature: 0.5,
},
},
},
{
input: []byte(`{"model":"base", "temperature":0}`),
expected: ChatCompletionRequest{
Model: ModelBase,
ChatCompletionRequest: openai.ChatCompletionRequest{
Temperature: math.SmallestNonzeroFloat32,
},
},
},
} {
t.Run(string(tt.input), func(t *testing.T) {
var req ChatCompletionRequest
err := json.Unmarshal(tt.input, &req)
assert.NoError(t, err)
assert.Equal(t, tt.expected, req)
})
}
}

func TestChatCompletionStreamResponseMarshalJSON(t *testing.T) {
resp := ChatCompletionStreamResponse{
ChatCompletionStreamResponse: openai.ChatCompletionStreamResponse{
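A side effect of the sentinel worth noting when reading these tests: after unmarshaling, an explicitly requested temperature of 0 is represented as math.SmallestNonzeroFloat32, so any downstream check for a zero temperature has to account for it. A hypothetical helper (not part of this PR) could make that explicit:

package plugin

import "math"

// isExplicitZeroTemperature reports whether an incoming request asked for
// temperature 0 and was rewritten to the sentinel by UnmarshalJSON.
// Hypothetical helper for downstream consumers; not part of this PR.
func isExplicitZeroTemperature(req ChatCompletionRequest) bool {
	// math.SmallestNonzeroFloat32 is exactly representable as a float32,
	// so direct equality is safe here.
	return req.Temperature == math.SmallestNonzeroFloat32
}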