diff --git a/bun.lockb b/bun.lockb
index 70ad316..68b7b2a 100755
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/package.json b/package.json
index 7a33ee7..6cfc1f0 100644
--- a/package.json
+++ b/package.json
@@ -49,6 +49,10 @@
     "vitest": "latest"
   },
   "dependencies": {
-    "@upstash/vector": "^1.0.7"
+    "@langchain/community": "^0.0.50",
+    "@langchain/core": "^0.1.58",
+    "@langchain/openai": "^0.0.28",
+    "@upstash/sdk": "0.0.19-alpha",
+    "ai": "^3.0.35"
   }
 }
diff --git a/src/clients/ratelimiter/index.ts b/src/clients/ratelimiter/index.ts
new file mode 100644
index 0000000..deb3846
--- /dev/null
+++ b/src/clients/ratelimiter/index.ts
@@ -0,0 +1,42 @@
+import { Ratelimit } from "@upstash/sdk";
+
+import type { Redis, Upstash } from "@upstash/sdk";
+import { InternalUpstashError } from "../../error/internal";
+
+const DEFAULT_RATELIMITER_NAME = "@upstash-rag-chat-ratelimit"; // passed as the limiter's `prefix`
+const MAX_ALLOWED_CHAT_REQUEST = 10; // used for both numeric tokenBucket arguments below
+
+export class RatelimiterClientConstructor { // lazily builds and caches a Ratelimit client
+  private redisClient?: Redis;
+  private ratelimiterClient?: Ratelimit; // cache; set by initializeRatelimiterClient
+  private sdkClient: Upstash;
+
+  constructor(sdkClient: Upstash, redisClient?: Redis) {
+    this.redisClient = redisClient;
+    this.sdkClient = sdkClient;
+  }
+
+  public async getRatelimiterClient(): Promise<Ratelimit | undefined> { // undefined when initialization fails
+    if (!this.ratelimiterClient) {
+      try {
+        await this.initializeRatelimiterClient();
+      } catch (error) {
+        console.error("Failed to initialize Ratelimiter client:", error); // best-effort: log, do not rethrow
+        return undefined;
+      }
+    }
+    return this.ratelimiterClient;
+  }
+
+  private initializeRatelimiterClient = async () => { // throws when no Redis client was supplied
+    if (!this.redisClient)
+      throw new InternalUpstashError("Redis client is missing in initializeRatelimiterClient!");
+
+    const ratelimiter = await this.sdkClient.newRatelimitClient(this.redisClient, {
+      limiter: Ratelimit.tokenBucket(MAX_ALLOWED_CHAT_REQUEST, "1d", MAX_ALLOWED_CHAT_REQUEST), // presumably (refillRate, interval, maxTokens) — confirm against the SDK
+      prefix: DEFAULT_RATELIMITER_NAME,
+    });
+
+    this.ratelimiterClient = ratelimiter;
+  };
+}
diff --git a/src/clients/redis/index.test.ts 
b/src/clients/redis/index.test.ts
new file mode 100644
index 0000000..c9d5117
--- /dev/null
+++ b/src/clients/redis/index.test.ts
@@ -0,0 +1,65 @@
+/* eslint-disable @typescript-eslint/no-non-null-assertion */
+import { Upstash } from "@upstash/sdk";
+import { describe, expect, test } from "bun:test";
+import { DEFAULT_REDIS_CONFIG, DEFAULT_REDIS_DB_NAME, RedisClientConstructor } from ".";
+
+const upstashSDK = new Upstash({ // credentials are read from the environment
+  email: process.env.UPSTASH_EMAIL!,
+  token: process.env.UPSTASH_TOKEN!,
+});
+
+describe("Redis Client", () => {
+  test(
+    "Initialize client without db name",
+    async () => {
+      const constructor = new RedisClientConstructor({
+        sdkClient: upstashSDK,
+      });
+      const redisClient = await constructor.getRedisClient(); // falls back to the default database name
+
+      expect(redisClient).toBeTruthy();
+
+      await upstashSDK.deleteRedisDatabase(DEFAULT_REDIS_DB_NAME); // cleanup; not reached if the expectation above throws
+    },
+    { timeout: 10_000 }
+  );
+
+  test(
+    "Initialize client with db name",
+    async () => {
+      const constructor = new RedisClientConstructor({
+        sdkClient: upstashSDK,
+        redisDbNameOrInstance: "test-name",
+      });
+      const redisClient = await constructor.getRedisClient();
+
+      expect(redisClient).toBeTruthy();
+
+      await upstashSDK.deleteRedisDatabase("test-name"); // cleanup; not reached if the expectation above throws
+    },
+    { timeout: 10_000 }
+  );
+
+  test(
+    "Initialize client with existing instance",
+    async () => {
+      const dbName = DEFAULT_REDIS_CONFIG.name + "suffix";
+      const redisInstance = await upstashSDK.createRedisDatabase({
+        ...DEFAULT_REDIS_CONFIG,
+        name: dbName,
+      });
+      const existingRedisClient = await upstashSDK.newRedisClient(redisInstance.database_name);
+
+      const constructor = new RedisClientConstructor({
+        sdkClient: upstashSDK,
+        redisDbNameOrInstance: existingRedisClient, // a Redis instance is used as-is by the constructor
+      });
+      const redisClient = await constructor.getRedisClient();
+
+      expect(redisClient).toBeTruthy();
+
+      await upstashSDK.deleteRedisDatabase(dbName); // cleanup; not reached if the expectation above throws
+    },
+    { timeout: 10_000 }
+  );
+});
diff --git a/src/clients/redis/index.ts b/src/clients/redis/index.ts
new file mode 100644
index 
0000000..d0876ac
--- /dev/null
+++ b/src/clients/redis/index.ts
@@ -0,0 +1,84 @@
+import type { CreateCommandPayload, Upstash } from "@upstash/sdk";
+
+import { Redis } from "@upstash/sdk";
+import type { PreferredRegions } from "../../types";
+
+export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis";
+
+export const DEFAULT_REDIS_CONFIG: CreateCommandPayload = { // base payload; name and region are overridden per call
+  name: DEFAULT_REDIS_DB_NAME,
+  tls: true,
+  region: "us-east-1",
+  eviction: false,
+};
+
+type Config = {
+  sdkClient: Upstash;
+  redisDbNameOrInstance?: string | Redis; // database name to look up/create, or a ready-made client
+  preferredRegion?: PreferredRegions;
+};
+
+export class RedisClientConstructor { // lazily resolves and caches a Redis client
+  private redisDbNameOrInstance?: string | Redis;
+  private preferredRegion?: PreferredRegions;
+  private sdkClient: Upstash;
+  private redisClient?: Redis; // cache; set by the initialize/create helpers
+
+  constructor({ sdkClient, preferredRegion, redisDbNameOrInstance }: Config) {
+    this.redisDbNameOrInstance = redisDbNameOrInstance;
+    this.sdkClient = sdkClient;
+    this.preferredRegion = preferredRegion ?? "us-east-1";
+  }
+
+  public async getRedisClient(): Promise<Redis | undefined> { // undefined when initialization fails
+    if (!this.redisClient) {
+      try {
+        await this.initializeRedisClient();
+      } catch (error) {
+        console.error("Failed to initialize Redis client:", error); // best-effort: log, do not rethrow
+        return undefined;
+      }
+    }
+    return this.redisClient;
+  }
+
+  private initializeRedisClient = async () => {
+    const { redisDbNameOrInstance } = this;
+
+    // Direct Redis instance provided: use it as-is
+    if (redisDbNameOrInstance instanceof Redis) {
+      this.redisClient = redisDbNameOrInstance;
+      return;
+    }
+
+    // Redis name provided: look the database up, creating it if the lookup fails
+    if (typeof redisDbNameOrInstance === "string") {
+      await this.createRedisClientByName(redisDbNameOrInstance);
+      return;
+    }
+
+    // No specific Redis information provided, using default configuration
+    await this.createRedisClientByDefaultConfig();
+  };
+
+  private createRedisClientByName = async (redisDbName: string) => {
+    try {
+      const redis = await this.sdkClient.getRedisDatabase(redisDbName);
+      this.redisClient = await 
this.sdkClient.newRedisClient(redis.database_name);
+    } catch { // any lookup failure falls back to get-or-create under the requested name
+      await this.createRedisClientByDefaultConfig(redisDbName);
+    }
+  };
+
+  private createRedisClientByDefaultConfig = async (redisDbName?: string) => {
+    const redisDatabase = await this.sdkClient.getOrCreateRedisDatabase({
+      ...DEFAULT_REDIS_CONFIG,
+      name: redisDbName ?? DEFAULT_REDIS_CONFIG.name,
+      region: this.preferredRegion ?? DEFAULT_REDIS_CONFIG.region,
+    });
+
+    if (redisDatabase?.database_name) { // leaves redisClient unset when no database came back
+      this.redisClient = await this.sdkClient.newRedisClient(redisDatabase.database_name);
+    }
+  };
+}
diff --git a/src/clients/vector/index.test.ts b/src/clients/vector/index.test.ts
new file mode 100644
index 0000000..250900b
--- /dev/null
+++ b/src/clients/vector/index.test.ts
@@ -0,0 +1,65 @@
+/* eslint-disable @typescript-eslint/no-non-null-assertion */
+import { Upstash } from "@upstash/sdk";
+import { describe, expect, test } from "bun:test";
+import { DEFAULT_VECTOR_DB_NAME, VectorClientConstructor, DEFAULT_VECTOR_CONFIG } from ".";
+
+const upstashSDK = new Upstash({ // credentials are read from the environment
+  email: process.env.UPSTASH_EMAIL!,
+  token: process.env.UPSTASH_TOKEN!,
+});
+
+describe("Vector Client", () => {
+  test(
+    "Initialize client without index name",
+    async () => {
+      const constructor = new VectorClientConstructor({
+        sdkClient: upstashSDK,
+      });
+      const vectorClient = await constructor.getVectorClient(); // falls back to the default index name
+
+      expect(vectorClient).toBeTruthy();
+
+      await upstashSDK.deleteVectorIndex(DEFAULT_VECTOR_DB_NAME); // cleanup; not reached if the expectation above throws
+    },
+    { timeout: 10_000 }
+  );
+
+  test(
+    "Initialize client with index name",
+    async () => {
+      const constructor = new VectorClientConstructor({
+        sdkClient: upstashSDK,
+        indexNameOrInstance: "test-name",
+      });
+      const vectorClient = await constructor.getVectorClient();
+
+      expect(vectorClient).toBeTruthy();
+
+      await upstashSDK.deleteVectorIndex("test-name"); // cleanup; not reached if the expectation above throws
+    },
+    { timeout: 10_000 }
+  );
+
+  test(
+    "Initialize client with existing instance",
+    async () => {
+      const indexName = DEFAULT_VECTOR_CONFIG.name + "suffix";
+      const vectorInstance = await upstashSDK.createVectorIndex({
+        ...DEFAULT_VECTOR_CONFIG,
+        name: indexName,
+      });
+      const existingVectorClient = await upstashSDK.newVectorClient(vectorInstance.name);
+
+      const constructor = new VectorClientConstructor({
+        sdkClient: upstashSDK,
+        indexNameOrInstance: existingVectorClient, // an Index instance is used as-is by the constructor
+      });
+      const vectorClient = await constructor.getVectorClient();
+
+      expect(vectorClient).toBeTruthy();
+
+      await upstashSDK.deleteVectorIndex(indexName); // cleanup; not reached if the expectation above throws
+    },
+    { timeout: 10_000 }
+  );
+});
diff --git a/src/clients/vector/index.ts b/src/clients/vector/index.ts
new file mode 100644
index 0000000..e9bd01a
--- /dev/null
+++ b/src/clients/vector/index.ts
@@ -0,0 +1,89 @@
+import type { CreateIndexPayload, Upstash } from "@upstash/sdk";
+import { Index } from "@upstash/sdk";
+
+import type { PreferredRegions } from "../../types";
+import { InternalUpstashError } from "../../error/internal";
+
+export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector";
+
+export const DEFAULT_VECTOR_CONFIG: CreateIndexPayload = { // base payload; name and region are overridden per call
+  name: DEFAULT_VECTOR_DB_NAME,
+  similarity_function: "EUCLIDEAN",
+  embedding_model: "MXBAI_EMBED_LARGE_V1",
+  region: "us-east-1",
+  type: "payg",
+};
+
+type Config = {
+  sdkClient: Upstash;
+  indexNameOrInstance?: string | Index; // index name to look up/create, or a ready-made client
+  preferredRegion?: PreferredRegions;
+};
+
+export class VectorClientConstructor { // lazily resolves and caches a vector Index client
+  private indexNameOrInstance?: string | Index;
+  private preferredRegion?: PreferredRegions;
+  private sdkClient: Upstash;
+  private vectorClient?: Index; // cache; set by the initialize/create helpers
+
+  constructor({ sdkClient, preferredRegion, indexNameOrInstance }: Config) {
+    this.indexNameOrInstance = indexNameOrInstance;
+    this.sdkClient = sdkClient;
+    this.preferredRegion = preferredRegion ?? "us-east-1";
+  }
+
+  public async getVectorClient(): Promise<Index | undefined> { // undefined when initialization fails
+    if (!this.vectorClient) {
+      try {
+        await this.initializeVectorClient();
+      } catch (error) {
+        console.error("Failed to initialize Vector client:", error); // best-effort: log, do not rethrow
+        return undefined;
+      }
+    }
+    return this.vectorClient;
+  }
+
+  private initializeVectorClient = async () => {
+    const { indexNameOrInstance } = this;
+
+    // Direct Vector instance provided: use it as-is
+    if (indexNameOrInstance instanceof Index) {
+      this.vectorClient = indexNameOrInstance;
+      return;
+    }
+
+    // Vector name provided: look the index up, creating it if the lookup fails
+    if (typeof indexNameOrInstance === "string") {
+      await this.createVectorClientByName(indexNameOrInstance);
+      return;
+    }
+
+    // No specific Vector information provided, using default configuration
+    await this.createVectorClientByDefaultConfig();
+  };
+
+  private createVectorClientByName = async (indexName: string) => {
+    try {
+      const index = await this.sdkClient.getVectorIndexByName(indexName);
+      if (!index) throw new InternalUpstashError("Index is missing!"); // caught just below to trigger the fallback
+
+      this.vectorClient = await this.sdkClient.newVectorClient(index.name);
+    } catch { // any lookup failure falls back to get-or-create under the requested name
+      await this.createVectorClientByDefaultConfig(indexName);
+    }
+  };
+
+  private createVectorClientByDefaultConfig = async (indexName?: string) => {
+    const index = await this.sdkClient.getOrCreateIndex({
+      ...DEFAULT_VECTOR_CONFIG,
+      name: indexName ?? DEFAULT_VECTOR_CONFIG.name,
+      region: this.preferredRegion ?? 
DEFAULT_VECTOR_CONFIG.region,
+    });
+
+    if (index?.name) { // leaves vectorClient unset when no index came back
+      const client = await this.sdkClient.newVectorClient(index.name);
+      this.vectorClient = client;
+    }
+  };
+}
diff --git a/src/error/internal.ts b/src/error/internal.ts
new file mode 100644
index 0000000..b115f3c
--- /dev/null
+++ b/src/error/internal.ts
@@ -0,0 +1,6 @@
+export class InternalUpstashError extends Error { // thrown when a required internal client is missing
+  constructor(message: string) {
+    super(message);
+    this.name = "InternalUpstashError"; // overrides the default "Error" name
+  }
+}
diff --git a/src/error/model.ts b/src/error/model.ts
new file mode 100644
index 0000000..179966f
--- /dev/null
+++ b/src/error/model.ts
@@ -0,0 +1,6 @@
+export class UpstashModelError extends Error { // thrown when no language model is configured
+  constructor(message: string) {
+    super(message);
+    this.name = "UpstashModelError"; // overrides the default "Error" name
+  }
+}
diff --git a/src/index.ts b/src/index.ts
new file mode 100644
index 0000000..8fb9838
--- /dev/null
+++ b/src/index.ts
@@ -0,0 +1,183 @@
+import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
+import type { PromptTemplate } from "@langchain/core/prompts";
+import type { Index, Ratelimit, Redis, UpstashConfig } from "@upstash/sdk";
+import type { Callbacks } from "@langchain/core/callbacks/manager";
+import type { BaseMessage } from "@langchain/core/messages";
+import { RunnableSequence, RunnableWithMessageHistory } from "@langchain/core/runnables";
+import { Upstash } from "@upstash/sdk";
+import { LangChainStream, StreamingTextResponse } from "ai";
+
+import { RatelimiterClientConstructor } from "./clients/ratelimiter";
+import { RedisClientConstructor } from "./clients/redis";
+import { VectorClientConstructor } from "./clients/vector";
+import { InternalUpstashError } from "./error/internal";
+import { UpstashModelError } from "./error/model";
+import { QA_TEMPLATE } from "./prompts";
+import { redisChatMessagesHistory } from "./redis-custom-history";
+import type { PreferredRegions } from "./types";
+import { formatChatHistory, formatFacts, sanitizeQuestion } from "./utils"; 
+
+const SIMILARITY_THRESHOLD = 0.5; // default minimum query score for a result to be used as context
+
+type CustomInputValues = { chat_history?: BaseMessage[]; question: string; context: string }; // input shape of the prompt chain
+
+type RAGChatConfigCommon = {
+  model: BaseLanguageModelInterface;
+  template?: PromptTemplate; // defaults to QA_TEMPLATE
+  umbrellaConfig: Omit<UpstashConfig, "email" | "token">; // email/token are passed to the RAGChat constructor separately
+  preferredRegion?: PreferredRegions;
+};
+
+export type RAGChatConfig = (
+  | {
+      vector?: Index;
+      redis?: string;
+      ratelimit?: Ratelimit;
+    }
+  | {
+      vector?: string;
+      redis?: Redis;
+      ratelimit?: Ratelimit;
+    }
+) &
+  RAGChatConfigCommon;
+
+export class RAGChat { // retrieval-augmented chat: vector lookup for context, Redis for per-session history
+  private sdkClient: Upstash;
+  private config?: RAGChatConfig;
+
+  //CLIENTS (created lazily by initializeClients)
+  private vectorClient?: Index;
+  private redisClient?: Redis;
+  private ratelimiterClient?: Ratelimit;
+
+  constructor(email: string, token: string, config?: RAGChatConfig) {
+    this.sdkClient = new Upstash({ email, token, ...config?.umbrellaConfig });
+    this.config = config;
+  }
+
+  private async getFactsFromVector( // returns the matching context items joined into one string
+    question: string,
+    similarityThreshold = SIMILARITY_THRESHOLD
+  ): Promise<string> {
+    if (!this.vectorClient)
+      throw new InternalUpstashError("vectorClient is missing in getFactsFromVector");
+
+    const index = this.vectorClient;
+    const result = await index.query<{ value: string }>({
+      data: question,
+      topK: 5,
+      includeMetadata: true,
+      includeVectors: false,
+    });
+
+    const allValuesUndefined = result.every((embedding) => embedding.metadata?.value === undefined);
+    if (result.length > 0 && allValuesUndefined) { // every() is vacuously true for []; an empty index is not a metadata error
+      throw new TypeError(`
+          Query to the vector store returned ${result.length} vectors but none had "value" field in their metadata.
+          Text of your vectors should be in the "value" field in the metadata for the RAG Chat.
+      `);
+    }
+
+    const facts = result
+      .filter((x) => x.score >= similarityThreshold)
+      .map((embedding, index) => `- Context Item ${index}: ${embedding.metadata?.value ?? ""}`);
+    return formatFacts(facts);
+  }
+
+  chat = async ( // entry point: returns a StreamingTextResponse when stream is true, else the chain's result
+    input: string,
+    chatOptions: { stream: boolean; sessionId: string; includeHistory?: number }
+  ) => {
+    await this.initializeClients();
+
+    const question = sanitizeQuestion(input);
+    const facts = await this.getFactsFromVector(question);
+
+    const { stream, sessionId, includeHistory } = chatOptions;
+
+    if (stream) {
+      return this.chainCallStreaming(question, facts, sessionId, includeHistory);
+    }
+
+    return this.chainCall({ sessionId, includeHistory }, question, facts);
+  };
+
+  private chainCallStreaming = (
+    question: string,
+    facts: string,
+    sessionId: string,
+    includeHistory?: number
+  ) => {
+    const { stream, handlers } = LangChainStream();
+    void this.chainCall({ sessionId, includeHistory }, question, facts, [handlers]); // NOTE(review): fire-and-forget; a rejection here is unhandled — consider .catch()
+    return new StreamingTextResponse(stream, {});
+  };
+
+  private chainCall( // builds prompt -> model chain with Redis-backed history and invokes it
+    chatOptions: { sessionId: string; includeHistory?: number },
+    question: string,
+    facts: string,
+    handlers?: Callbacks
+  ) {
+    if (!this.config?.model) throw new UpstashModelError("Model is missing!");
+
+    const formattedHistoryChain = RunnableSequence.from<CustomInputValues>([
+      {
+        chat_history: (input) => formatChatHistory(input.chat_history ?? []),
+        question: (input) => input.question,
+        context: (input) => input.context,
+      },
+      this.config.template ?? QA_TEMPLATE,
+      this.config.model,
+    ]);
+
+    if (!this.redisClient) throw new InternalUpstashError("redisClient is missing in chat");
+    const redis = this.redisClient;
+
+    const chainWithMessageHistory = new RunnableWithMessageHistory({
+      runnable: formattedHistoryChain,
+      getMessageHistory: (sessionId: string) =>
+        redisChatMessagesHistory({
+          sessionId,
+          redis,
+          length: chatOptions.includeHistory,
+        }),
+      inputMessagesKey: "question",
+      historyMessagesKey: "chat_history",
+    });
+
+    return chainWithMessageHistory.invoke(
+      {
+        question,
+        context: facts,
+      },
+      {
+        callbacks: handlers ?? 
undefined,
+        configurable: { sessionId: chatOptions.sessionId }, // selects which session's history is loaded
+      }
+    );
+  }
+
+  private async initializeClients() { // creates each missing client once; later calls reuse the cached ones
+    if (!this.vectorClient)
+      this.vectorClient = await new VectorClientConstructor({
+        sdkClient: this.sdkClient,
+        indexNameOrInstance: this.config?.vector,
+        preferredRegion: this.config?.preferredRegion,
+      }).getVectorClient();
+
+    if (!this.redisClient)
+      this.redisClient = await new RedisClientConstructor({
+        sdkClient: this.sdkClient,
+        redisDbNameOrInstance: this.config?.redis,
+        preferredRegion: this.config?.preferredRegion,
+      }).getRedisClient();
+
+    if (!this.ratelimiterClient) // NOTE(review): config.ratelimit is not consulted here — confirm intended
+      this.ratelimiterClient = await new RatelimiterClientConstructor(
+        this.sdkClient,
+        this.redisClient
+      ).getRatelimiterClient();
+  }
+}
diff --git a/src/prompts.ts b/src/prompts.ts
new file mode 100644
index 0000000..7e55632
--- /dev/null
+++ b/src/prompts.ts
@@ -0,0 +1,18 @@
+import { PromptTemplate } from "@langchain/core/prompts";
+
+export const QA_TEMPLATE =
+  PromptTemplate.fromTemplate(`You are a friendly AI assistant augmented with an Upstash Vector Store.
+To help you answer the questions, a context will be provided. This context is generated by querying the vector store with the user question.
+Answer the question at the end using only the information available in the context and chat history.
+If the answer is not available in the chat history or context, do not answer the question and politely let the user know that you can only answer if the answer is available in context or the chat history. 
+
+-------------
+Chat history:
+{chat_history}
+-------------
+Context:
+{context}
+-------------
+
+Question: {question}
+Helpful answer:`);
diff --git a/src/redis-custom-history.ts b/src/redis-custom-history.ts
new file mode 100644
index 0000000..956a880
--- /dev/null
+++ b/src/redis-custom-history.ts
@@ -0,0 +1,122 @@
+/* eslint-disable @typescript-eslint/no-magic-numbers */
+/* eslint-disable @typescript-eslint/no-unnecessary-condition */
+import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
+import type { BaseMessage, StoredMessage } from "@langchain/core/messages";
+import {
+  mapChatMessagesToStoredMessages,
+  mapStoredMessagesToChatMessages,
+} from "@langchain/core/messages";
+import { Redis, type RedisConfigNodejs } from "@upstash/redis";
+
+// Adapted from: https://github.com/langchain-ai/langchainjs/blob/main/libs/langchain-community/src/stores/message/upstash_redis.ts
+/**
+ * Type definition for the input parameters required to initialize an
+ * instance of the CustomUpstashRedisChatMessageHistory class.
+ */
+export type CustomUpstashRedisChatMessageHistoryInput = {
+  sessionId: string; // used as the Redis list key for the session
+  sessionTTL?: number; // when set, applied to the key via EXPIRE after each added message
+  config?: RedisConfigNodejs; // used to build a client when `client` is not given
+  client?: Redis; // takes precedence over `config`
+  topLevelChatHistoryLength?: number; // default LRANGE stop index used by getMessages
+};
+
+/**
+ * Class used to store chat message history in Redis. It provides methods
+ * to add, get, and clear messages. 
+ */
+export class CustomUpstashRedisChatMessageHistory extends BaseListChatMessageHistory {
+  lc_namespace = ["langchain", "stores", "message", "upstash_redis"];
+
+  get lc_secrets() {
+    return {
+      "config.url": "UPSTASH_REDIS_REST_URL",
+      "config.token": "UPSTASH_REDIS_REST_TOKEN",
+    };
+  }
+
+  public client: Redis;
+
+  private sessionId: string; // Redis list key holding this session's messages
+
+  private sessionTTL?: number;
+  private topLevelChatHistoryLength?: number; // default LRANGE stop index for getMessages
+
+  constructor(fields: CustomUpstashRedisChatMessageHistoryInput) {
+    super(fields);
+    const { sessionId, sessionTTL, config, client, topLevelChatHistoryLength } = fields;
+    if (client) { // a pre-built client wins over `config`
+      this.client = client;
+    } else if (config) {
+      this.client = new Redis(config);
+    } else {
+      throw new Error(
+        `Upstash Redis message stores require either a config object or a pre-configured client.`
+      );
+    }
+    this.sessionId = sessionId;
+    this.sessionTTL = sessionTTL;
+    this.topLevelChatHistoryLength = topLevelChatHistoryLength;
+  }
+
+  /**
+   * Retrieves the chat messages from the Redis database, oldest first.
+   * @returns An array of BaseMessage instances representing the chat history.
+   */
+  async getMessages(chatHistoryLength?: number): Promise<BaseMessage[]> {
+    const length = chatHistoryLength ?? this.topLevelChatHistoryLength ?? [0, -1]; // [0, -1] reads the whole list
+
+    const rawStoredMessages: StoredMessage[] = await this.client.lrange<StoredMessage>(
+      this.sessionId,
+      typeof length === "number" ? 0 : length[0],
+      typeof length === "number" ? length : length[1] // LRANGE stop is inclusive: N yields up to N + 1 entries
+    );
+
+    const orderedMessages = rawStoredMessages.reverse(); // addMessage uses LPUSH (newest first); flip to chronological
+    const previousMessages = orderedMessages.filter(
+      (x): x is StoredMessage => x.type !== undefined && x.data.content !== undefined
+    );
+    return mapStoredMessagesToChatMessages(previousMessages);
+  }
+
+  /**
+   * Adds a new message to the chat history in the Redis database.
+   * @param message The message to be added to the chat history.
+   * @returns Promise resolving to void. 
+   */
+  async addMessage(message: BaseMessage): Promise<void> {
+    const messageToAdd = mapChatMessagesToStoredMessages([message]);
+    await this.client.lpush(this.sessionId, JSON.stringify(messageToAdd[0])); // LPUSH keeps the newest message at index 0
+    if (this.sessionTTL) {
+      await this.client.expire(this.sessionId, this.sessionTTL); // re-applied on every write
+    }
+  }
+
+  /**
+   * Deletes all messages from the chat history in the Redis database.
+   * @returns Promise resolving to void.
+   */
+  async clear(): Promise<void> {
+    await this.client.del(this.sessionId);
+  }
+}
+
+const DAY_IN_SECONDS = 86_400; // session TTL handed to every history instance created below
+const TOP_6 = 5; // used as an inclusive LRANGE stop index, so 5 selects the 6 newest entries
+
+export const redisChatMessagesHistory = ({ // factory: history bound to `redis` with a one-day TTL
+  length = TOP_6,
+  sessionId,
+  redis,
+}: {
+  sessionId: string;
+  length?: number;
+  redis: Redis;
+}) => {
+  return new CustomUpstashRedisChatMessageHistory({
+    sessionId,
+    sessionTTL: DAY_IN_SECONDS,
+    topLevelChatHistoryLength: length,
+    client: redis,
+  });
+};
diff --git a/src/types.ts b/src/types.ts
new file mode 100644
index 0000000..c0e75f8
--- /dev/null
+++ b/src/types.ts
@@ -0,0 +1 @@
+export type PreferredRegions = "eu-west-1" | "us-east-1";
diff --git a/src/utils.ts b/src/utils.ts
new file mode 100644
index 0000000..3c5ac68
--- /dev/null
+++ b/src/utils.ts
@@ -0,0 +1,19 @@
+import type { BaseMessage } from "@langchain/core/messages";
+
+export const sanitizeQuestion = (question: string) => { // trims, then flattens newlines to spaces
+  return question.trim().replaceAll("\n", " ");
+};
+
+export const formatFacts = (facts: string[]): string => { // one item per line
+  return facts.join("\n");
+};
+
+export const formatChatHistory = (chatHistory: BaseMessage[]) => { // renders turns as "Human: …" / "Assistant: …" lines
+  const formattedDialogueTurns = chatHistory.map((dialogueTurn) =>
+    dialogueTurn._getType() === "human"
+      ? `Human: ${dialogueTurn.content}` // NOTE(review): assumes content is a string — confirm for non-text message content
+      : `Assistant: ${dialogueTurn.content}`
+  );
+
+  return formatFacts(formattedDialogueTurns);
+};