From 55d81b9c171c5d1a9df08d964820bbd597dc80b6 Mon Sep 17 00:00:00 2001
From: ogzhanolguncu
Date: Wed, 22 May 2024 15:46:10 +0300
Subject: [PATCH] feat: rename template to prompt and allow getting history
 from ragchat directly

---
 src/config.ts        |  4 ++--
 src/prompts.ts       |  2 +-
 src/rag-chat-base.ts |  8 ++++----
 src/rag-chat.test.ts |  2 +-
 src/rag-chat.ts      | 15 +++++++++++----
 src/types.ts         |  2 +-
 6 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/config.ts b/src/config.ts
index af00035..e4787d2 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -11,7 +11,7 @@ export class Config {
   public readonly ratelimit?: Ratelimit;
   public readonly model?: BaseLanguageModelInterface;
 
-  public readonly template?: PromptTemplate;
+  public readonly prompt?: PromptTemplate;
 
   constructor(config: RAGChatConfig) {
     this.vector = config.vector;
@@ -20,6 +20,6 @@
     this.ratelimit = config.ratelimit;
     this.model = config.model;
 
-    this.template = config.template;
+    this.prompt = config.prompt;
   }
 }
diff --git a/src/prompts.ts b/src/prompts.ts
index 7e55632..cd9d959 100644
--- a/src/prompts.ts
+++ b/src/prompts.ts
@@ -1,6 +1,6 @@
 import { PromptTemplate } from "@langchain/core/prompts";
 
-export const QA_TEMPLATE =
+export const QA_PROMPT_TEMPLATE =
   PromptTemplate.fromTemplate(`You are a friendly AI assistant augmented with an Upstash Vector Store.
   To help you answer the questions, a context will be provided. This context is generated by querying the vector store with the user question.
   Answer the question at the end using only the information available in the context and chat history.
diff --git a/src/rag-chat-base.ts b/src/rag-chat-base.ts
index 0b83e9f..cc8b766 100644
--- a/src/rag-chat-base.ts
+++ b/src/rag-chat-base.ts
@@ -17,18 +17,18 @@ export class RAGChatBase {
   protected historyService: HistoryService;
 
   #model: BaseLanguageModelInterface;
-  #template: PromptTemplate;
+  #prompt: PromptTemplate;
 
   constructor(
     retrievalService: RetrievalService,
     historyService: HistoryService,
-    config: { model: BaseLanguageModelInterface; template: PromptTemplate }
+    config: { model: BaseLanguageModelInterface; prompt: PromptTemplate }
   ) {
     this.retrievalService = retrievalService;
     this.historyService = historyService;
 
     this.#model = config.model;
-    this.#template = config.template;
+    this.#prompt = config.prompt;
   }
 
   protected async prepareChat({
@@ -69,7 +69,7 @@
       question: (input) => input.question,
       context: (input) => input.context,
     },
-    this.#template,
+    this.#prompt,
     this.#model,
   ]);
 
diff --git a/src/rag-chat.test.ts b/src/rag-chat.test.ts
index 743f5e3..5603970 100644
--- a/src/rag-chat.test.ts
+++ b/src/rag-chat.test.ts
@@ -129,7 +129,7 @@ describe("RAG Chat with custom template", () => {
         token: process.env.UPSTASH_REDIS_REST_TOKEN!,
         url: process.env.UPSTASH_REDIS_REST_URL!,
       }),
-      template: PromptTemplate.fromTemplate("Just say `I'm a cookie monster`. Nothing else."),
+      prompt: PromptTemplate.fromTemplate("Just say `I'm a cookie monster`. Nothing else."),
       model: new ChatOpenAI({
         modelName: "gpt-3.5-turbo",
         streaming: false,
diff --git a/src/rag-chat.ts b/src/rag-chat.ts
index 0fda26a..637cf40 100644
--- a/src/rag-chat.ts
+++ b/src/rag-chat.ts
@@ -1,7 +1,7 @@
 import type { AIMessage } from "@langchain/core/messages";
 import type { StreamingTextResponse } from "ai";
 
-import { QA_TEMPLATE } from "./prompts";
+import { QA_PROMPT_TEMPLATE } from "./prompts";
 
 import { UpstashModelError } from "./error/model";
 import { RatelimitUpstashError } from "./error/ratelimit";
@@ -29,7 +29,7 @@
     }
     super(retrievalService, historyService, {
       model: config.model,
-      template: config.template ?? QA_TEMPLATE,
+      prompt: config.prompt ?? QA_PROMPT_TEMPLATE,
     });
     this.#ratelimitService = ratelimitService;
   }
@@ -38,7 +38,7 @@
     // Adds chat session id and ratelimit session id if not provided.
     const options_ = appendDefaultsIfNeeded(options);
 
-    //Checks ratelimit of the user. If not enabled `success` will be always true.
+    // Checks the user's ratelimit. If ratelimiting is not enabled, `success` is always true.
     const { success, resetTime } = await this.#ratelimitService.checkLimit(
       options_.ratelimitSessionId
     );
@@ -50,7 +50,7 @@
       });
     }
 
-    //Sanitizes the given input by stripping all the newline chars then queries vector db with sanitized question.
+    // Sanitizes the given input by stripping all newline chars, then queries the vector DB with the sanitized question.
     const { question, facts } = await this.prepareChat({
       question: input,
       similarityThreshold: options_.similarityThreshold,
@@ -58,6 +58,8 @@
       topK: options_.topK,
     });
 
+    // Calls the LLM with the composed prompt, which holds the chat history, the facts gathered from the vector DB, and the sanitized question.
+    // Allows either a streaming call via the Vercel AI SDK or a plain (non-streaming) call.
     return options.stream
       ? this.streamingChainCall(options_, question, facts)
       : this.chainCall(options_, question, facts);
@@ -71,4 +73,9 @@
     );
     return retrievalServiceStatus === "Success" ? "OK" : "NOT-OK";
   }
+
+  /** Gets the history of messages used in the RAG Chat. */
+  getHistory() {
+    return this.historyService;
+  }
 }
diff --git a/src/types.ts b/src/types.ts
index dad6e01..dfe99a8 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -84,7 +84,7 @@ type RAGChatConfigCommon = {
    Question: {question}
    Helpful answer:`)
   */
-  template?: PromptTemplate;
+  prompt?: PromptTemplate;
   /**
    * Ratelimit instance
    * @example new Ratelimit({
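
---

Usage sketch (editor's note, not part of the commit): after this patch, the config
option is `prompt` rather than `template`, and the message history can be read back
via the new `getHistory()` method. Below is a minimal TypeScript sketch, assuming
RAGChat is constructed with the same config shape as in rag-chat.test.ts above and
that the public chat method is `chat(input, options)`; the import path, environment
variable names, and the sample prompt text are illustrative assumptions.

  import { ChatOpenAI } from "@langchain/openai";
  import { PromptTemplate } from "@langchain/core/prompts";
  import { Redis } from "@upstash/redis";
  import { Index } from "@upstash/vector";
  import { RAGChat } from "./rag-chat"; // assumed entry point

  const ragChat = new RAGChat({
    vector: new Index({
      url: process.env.UPSTASH_VECTOR_REST_URL!,
      token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
    }),
    redis: new Redis({
      url: process.env.UPSTASH_REDIS_REST_URL!,
      token: process.env.UPSTASH_REDIS_REST_TOKEN!,
    }),
    model: new ChatOpenAI({ modelName: "gpt-3.5-turbo", streaming: false }),
    // Renamed from `template` by this patch; QA_PROMPT_TEMPLATE is the fallback when omitted.
    prompt: PromptTemplate.fromTemplate(
      "Answer using only the context.\n{context}\nQuestion: {question}\nHelpful answer:"
    ),
  });

  // Sanitizes the input, queries the vector store, then calls the model (non-streaming here).
  const answer = await ragChat.chat("What does Upstash Vector store?", { stream: false });

  // New in this patch: read the history of messages back from RAGChat directly.
  const history = ragChat.getHistory();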