diff --git a/src/providers/azure/api.ts b/src/providers/azure/api.ts
new file mode 100644
index 00000000..644f606d
--- /dev/null
+++ b/src/providers/azure/api.ts
@@ -0,0 +1,28 @@
+export interface AzureFetchPayload {
+  apiKey: string
+  baseUrl: string
+  body: Record<string, any>
+  model?: string
+  signal?: AbortSignal
+}
+
+export const fetchChatCompletion = async(payload: AzureFetchPayload) => {
+  const { baseUrl, apiKey, body, model, signal } = payload || {}
+  const initOptions = {
+    headers: { 'Content-Type': 'application/json', 'api-key': apiKey },
+    method: 'POST',
+    body: JSON.stringify({ ...body }),
+    signal,
+  }
+  return fetch(`${baseUrl}/openai/deployments/${model}/chat/completions?api-version=2023-08-01-preview`, initOptions)
+}
+
+export const fetchImageGeneration = async(payload: AzureFetchPayload) => {
+  const { baseUrl, apiKey, body } = payload || {}
+  const initOptions = {
+    headers: { 'Content-Type': 'application/json', 'api-key': apiKey },
+    method: 'POST',
+    body: JSON.stringify(body),
+  }
+  return fetch(`${baseUrl}/openai/images/generations:submit?api-version=2023-08-01-preview`, initOptions)
+}
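Note for reviewers: the two helpers above are thin `fetch` wrappers over the Azure OpenAI REST endpoints (both are now built from the same `baseUrl`, since the handler sanitizes one shared endpoint setting for both calls). A minimal sketch of calling the chat helper directly — the key, resource URL, and deployment name below are placeholders, not values from this PR:

```ts
import { fetchChatCompletion } from '@/providers/azure/api'

// Placeholder credentials and names, for illustration only.
const demo = async () => {
  const response = await fetchChatCompletion({
    apiKey: 'YOUR-AZURE-OPENAI-KEY',
    baseUrl: 'https://my-resource.openai.azure.com', // no trailing slash
    model: 'my-gpt-35-deployment', // the Azure deployment name, not a model family
    body: {
      messages: [{ role: 'user', content: 'Hello!' }],
      stream: false,
    },
  })
  const json = await response.json()
  return json.choices[0].message.content as string
}
```

`model` is kept separate from `body` because Azure bakes the deployment name into the URL path rather than accepting it in the request payload.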
diff --git a/src/providers/azure/handler.ts b/src/providers/azure/handler.ts
new file mode 100644
index 00000000..860db143
--- /dev/null
+++ b/src/providers/azure/handler.ts
@@ -0,0 +1,100 @@
+import { fetchChatCompletion, fetchImageGeneration } from './api'
+import { parseStream } from './parser'
+import type { Message } from '@/types/message'
+import type { HandlerPayload, Provider } from '@/types/provider'
+
+export const handlePrompt: Provider['handlePrompt'] = async(payload, signal?: AbortSignal) => {
+  if (payload.botId === 'chat_continuous')
+    return handleChatCompletion(payload, signal)
+  if (payload.botId === 'chat_single')
+    return handleChatCompletion(payload, signal)
+  if (payload.botId === 'image_generation')
+    return handleImageGeneration(payload)
+}
+
+export const handleRapidPrompt: Provider['handleRapidPrompt'] = async(prompt, globalSettings) => {
+  const rapidPromptPayload = {
+    conversationId: 'temp',
+    conversationType: 'chat_single',
+    botId: 'temp',
+    globalSettings: {
+      ...globalSettings,
+      temperature: 0.4,
+      maxTokens: 2048,
+      topP: 1,
+      stream: false,
+    },
+    botSettings: {},
+    prompt,
+    messages: [{ role: 'user', content: prompt }],
+  } as HandlerPayload
+  const result = await handleChatCompletion(rapidPromptPayload)
+  if (typeof result === 'string') return result
+  return ''
+}
+
+const handleChatCompletion = async(payload: HandlerPayload, signal?: AbortSignal) => {
+  // Messages that fit the token budget, in chronological order
+  const messages: Message[] = []
+
+  let maxTokens = payload.globalSettings.maxTokens as number
+  let messageHistorySize = payload.globalSettings.messageHistorySize as number
+
+  // Walk the history from newest to oldest, keeping messages while budget remains
+  while (messageHistorySize > 0) {
+    messageHistorySize--
+    // Take the most recent remaining message
+    const m = payload.messages.pop()
+    if (m === undefined)
+      break
+
+    if (maxTokens - m.content.length < 0)
+      break
+
+    maxTokens -= m.content.length
+    messages.unshift(m)
+  }
+
+  const response = await fetchChatCompletion({
+    apiKey: payload.globalSettings.apiKey as string,
+    baseUrl: (payload.globalSettings.baseUrl as string).trim().replace(/\/$/, ''),
+    body: {
+      messages,
+      max_tokens: maxTokens,
+      temperature: payload.globalSettings.temperature as number,
+      top_p: payload.globalSettings.topP as number,
+      stream: payload.globalSettings.stream as boolean ?? true,
+    },
+    model: payload.globalSettings.model as string,
+    signal,
+  })
+  if (!response.ok) {
+    const responseJson = await response.json()
+    console.log('responseJson', responseJson)
+    const errMessage = responseJson.error?.message || response.statusText || 'Unknown error'
+    throw new Error(errMessage, { cause: responseJson.error })
+  }
+  const isStream = response.headers.get('content-type')?.includes('text/event-stream')
+  if (isStream) {
+    return parseStream(response)
+  } else {
+    const resJson = await response.json()
+    return resJson.choices[0].message.content as string
+  }
+}
+
+const handleImageGeneration = async(payload: HandlerPayload) => {
+  const prompt = payload.prompt
+  const response = await fetchImageGeneration({
+    apiKey: payload.globalSettings.apiKey as string,
+    baseUrl: (payload.globalSettings.baseUrl as string).trim().replace(/\/$/, ''),
+    body: { prompt, n: 1, size: '512x512' },
+  })
+  if (!response.ok) {
+    const responseJson = await response.json()
+    const errMessage = responseJson.error?.message || response.statusText || 'Unknown error'
+    throw new Error(errMessage)
+  }
+  const resJson = await response.json()
+  return resJson.data[0].url
+}
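A reviewer note on `handleChatCompletion`: `maxTokens` does double duty. The loop spends the budget on history (using each message's character count as a rough stand-in for its token count, newest first), and whatever remains is sent to the API as `max_tokens` for the reply. A standalone sketch of the same strategy, with hypothetical names:

```ts
interface Msg { role: string; content: string }

// Keep at most `keep` of the newest messages whose combined character
// length fits inside `budget`, preserving chronological order.
// Note: pops from `history`, mutating it, just like the handler above
// mutates payload.messages.
const truncateHistory = (history: Msg[], budget: number, keep: number): Msg[] => {
  const kept: Msg[] = []
  while (keep-- > 0) {
    const m = history.pop()
    if (m === undefined || budget - m.content.length < 0)
      break
    budget -= m.content.length
    kept.unshift(m)
  }
  return kept
}

// With budget 2048 and keep 5, at most the 5 newest messages that fit
// within 2048 characters survive; the leftover budget caps the reply.
```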
diff --git a/src/providers/azure/index.ts b/src/providers/azure/index.ts
new file mode 100644
index 00000000..0bcd4fbb
--- /dev/null
+++ b/src/providers/azure/index.ts
@@ -0,0 +1,97 @@
+import {
+  handlePrompt,
+  handleRapidPrompt,
+} from './handler'
+import type { Provider } from '@/types/provider'
+
+const providerAzure = () => {
+  const provider: Provider = {
+    id: 'provider-azure',
+    icon: 'i-simple-icons:microsoftazure', // @unocss-include
+    name: 'Azure OpenAI',
+    globalSettings: [
+      {
+        key: 'apiKey',
+        name: 'API Key',
+        type: 'api-key',
+      },
+      {
+        key: 'baseUrl',
+        name: 'Endpoint',
+        description: 'Azure OpenAI resource endpoint, e.g. https://<resource-name>.openai.azure.com.',
+        type: 'input',
+      },
+      {
+        key: 'model',
+        name: 'Azure deployment name',
+        description: 'The name of your Azure OpenAI model deployment.',
+        type: 'input',
+      },
+      {
+        key: 'maxTokens',
+        name: 'Max Tokens',
+        description: 'The maximum number of tokens to generate in the completion.',
+        type: 'slider',
+        min: 0,
+        max: 32768,
+        default: 2048,
+        step: 1,
+      },
+      {
+        key: 'messageHistorySize',
+        name: 'Max History Message Size',
+        description: 'The maximum number of historical messages to retain. Older messages are dropped once their combined length exceeds the Max Tokens budget.',
+        type: 'slider',
+        min: 1,
+        max: 24,
+        default: 5,
+        step: 1,
+      },
+      {
+        key: 'temperature',
+        name: 'Temperature',
+        type: 'slider',
+        description: 'What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.',
+        min: 0,
+        max: 2,
+        default: 0.7,
+        step: 0.01,
+      },
+      {
+        key: 'topP',
+        name: 'Top P',
+        description: 'An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.',
+        type: 'slider',
+        min: 0,
+        max: 1,
+        default: 1,
+        step: 0.01,
+      },
+    ],
+    bots: [
+      {
+        id: 'chat_continuous',
+        type: 'chat_continuous',
+        name: 'Continuous Chat',
+        settings: [],
+      },
+      {
+        id: 'chat_single',
+        type: 'chat_single',
+        name: 'Single Chat',
+        settings: [],
+      },
+      {
+        id: 'image_generation',
+        type: 'image_generation',
+        name: 'DALL·E',
+        settings: [],
+      },
+    ],
+    handlePrompt,
+    handleRapidPrompt,
+  }
+  return provider
+}
+
+export default providerAzure
diff --git a/src/providers/azure/parser.ts b/src/providers/azure/parser.ts
new file mode 100644
index 00000000..c1ade6b3
--- /dev/null
+++ b/src/providers/azure/parser.ts
@@ -0,0 +1,44 @@
+import { createParser } from 'eventsource-parser'
+import type { ParsedEvent, ReconnectInterval } from 'eventsource-parser'
+
+export const parseStream = (rawResponse: Response) => {
+  const encoder = new TextEncoder()
+  const decoder = new TextDecoder()
+  const rb = rawResponse.body as ReadableStream
+
+  return new ReadableStream({
+    async start(controller) {
+      const streamParser = (event: ParsedEvent | ReconnectInterval) => {
+        if (event.type === 'event') {
+          const data = event.data
+          if (data === '[DONE]') {
+            // Stop the read loop too, so the controller is not closed twice
+            done = true
+            controller.close()
+            return
+          }
+          try {
+            const json = JSON.parse(data)
+            const text = (json.choices?.[0]?.delta?.content) || ''
+            const queue = encoder.encode(text)
+            controller.enqueue(queue)
+          } catch (e) {
+            controller.error(e)
+          }
+        }
+      }
+      const reader = rb.getReader()
+      const parser = createParser(streamParser)
+      let done = false
+      while (!done) {
+        const { done: isDone, value } = await reader.read()
+        if (isDone) {
+          done = true
+          controller.close()
+          return
+        }
+        parser.feed(decoder.decode(value))
+      }
+    },
+  })
+}
diff --git a/src/stores/provider.ts b/src/stores/provider.ts
index 4efa68ac..c5f85cee 100644
--- a/src/stores/provider.ts
+++ b/src/stores/provider.ts
@@ -1,10 +1,12 @@
 import providerOpenAI from '@/providers/openai'
+import providerAzure from '@/providers/azure'
 import providerReplicate from '@/providers/replicate'
 import { allConversationTypes } from '@/types/conversation'
 import type { BotMeta } from '@/types/app'

 export const providerList = [
   providerOpenAI(),
+  providerAzure(),
   providerReplicate(),
 ]
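Since `handlePrompt` resolves to either a plain string or the `ReadableStream` produced by `parseStream`, a caller drains the stream along these lines — a sketch only, where `onToken` is a hypothetical callback and the real wiring lives in the app's conversation handling:

```ts
// Drain the UTF-8 text stream that parseStream enqueues chunk by chunk.
const readTextStream = async (
  stream: ReadableStream<Uint8Array>,
  onToken: (text: string) => void,
) => {
  const reader = stream.getReader()
  const decoder = new TextDecoder()
  while (true) {
    const { done, value } = await reader.read()
    if (done)
      break
    // `stream: true` keeps multi-byte characters split across chunks intact.
    onToken(decoder.decode(value, { stream: true }))
  }
}
```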