From 2d489b370e53c42c9d7bcf57c63b55472c325958 Mon Sep 17 00:00:00 2001 From: Umesh Madan Date: Tue, 25 Feb 2025 13:32:53 -0800 Subject: [PATCH] knowpro: Threads simplification, refactoring (#750) Ongoing evolution of **threads** * Simplified + Refactored Thread interfaces. * Reusable ConversationThreads class --- ts/examples/chat/src/memory/podcastMemory.ts | 4 +- ts/packages/knowPro/src/conversationThread.ts | 99 +++++++++++-------- ts/packages/knowPro/src/import.ts | 78 +++------------ ts/packages/knowPro/src/search.ts | 13 ++- ts/packages/knowPro/src/secondaryIndexes.ts | 2 +- 5 files changed, 78 insertions(+), 118 deletions(-) diff --git a/ts/examples/chat/src/memory/podcastMemory.ts b/ts/examples/chat/src/memory/podcastMemory.ts index a2315743..eb76fa6b 100644 --- a/ts/examples/chat/src/memory/podcastMemory.ts +++ b/ts/examples/chat/src/memory/podcastMemory.ts @@ -210,8 +210,8 @@ export function createPodcastCommands( podcastMessages, knowledgeResponses, ); - kpPodcast.threadIndex.threads.push(...podcastThreads); - await kpPodcast.threadIndex.buildIndex(); + kpPodcast.threads.threads.push(...podcastThreads); + await kpPodcast.buildIndex(); const podcastData = kpPodcast.serialize(); await ensureDir(path.dirname(namedArgs.filePath)); diff --git a/ts/packages/knowPro/src/conversationThread.ts b/ts/packages/knowPro/src/conversationThread.ts index 9e5c3b31..6b03cf7f 100644 --- a/ts/packages/knowPro/src/conversationThread.ts +++ b/ts/packages/knowPro/src/conversationThread.ts @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -import { NormalizedEmbedding } from "typeagent"; import { TextRange } from "./dataFormat.js"; import { + deserializeEmbedding, + serializeEmbedding, TextEmbeddingIndex, TextEmbeddingIndexSettings, } from "./fuzzyIndex.js"; @@ -24,8 +25,15 @@ export type ScoredThreadIndex = { }; export interface IConversationThreads { - threads: Thread[]; - threadDescriptionIndex: IThreadDescriptionIndex; + readonly threads: Thread[]; + + addThread(thread: Thread): Promise; + lookupThread( + threadDescription: string, + maxMatches?: number, + thresholdScore?: number, + ): Promise; + removeThread(threadIndex: ThreadIndex): void; serialize(): IConversationThreadData; deserialize(data: IConversationThreadData): void; @@ -40,20 +48,8 @@ export interface IThreadDataItem { embedding: number[]; } -export interface IThreadDescriptionIndex { - addDescription( - description: string, - threadIndex: ThreadIndex | ScoredThreadIndex, - ): Promise; - lookupThread( - text: string, - maxMatches?: number, - thresholdScore?: number, - ): Promise; -} - -export class ThreadDescriptionIndex implements IThreadDescriptionIndex { - public threads: ScoredThreadIndex[]; +export class ConversationThreads implements IConversationThreads { + public threads: Thread[]; public embeddingIndex: TextEmbeddingIndex; constructor(public settings: TextEmbeddingIndexSettings) { @@ -61,26 +57,9 @@ export class ThreadDescriptionIndex implements IThreadDescriptionIndex { this.embeddingIndex = new TextEmbeddingIndex(settings); } - public async addDescription( - description: string, - threadIndex: ThreadIndex | ScoredThreadIndex, - ): Promise { - if (typeof threadIndex === "number") { - threadIndex = { - threadIndex: threadIndex, - score: 1, - }; - } - await this.embeddingIndex.addText(description); - this.threads.push(threadIndex); - } - - public add(embedding: NormalizedEmbedding, threadIndex: ThreadIndex): void { - this.embeddingIndex.add(embedding); - this.threads.push({ - threadIndex: threadIndex, - score: 1, - }); + public async addThread(thread: Thread): Promise { + this.threads.push(thread); + await this.embeddingIndex.addText(thread.description); } public async lookupThread( @@ -99,12 +78,9 @@ export class ThreadDescriptionIndex implements IThreadDescriptionIndex { } public removeThread(threadIndex: ThreadIndex) { - const indexOf = this.threads.findIndex( - (t) => t.threadIndex === threadIndex, - ); - if (indexOf >= 0) { - this.threads.splice(indexOf, 1); - this.embeddingIndex.removeAt(indexOf); + if (threadIndex >= 0) { + this.threads.splice(threadIndex, 1); + this.embeddingIndex.removeAt(threadIndex); } } @@ -112,4 +88,41 @@ export class ThreadDescriptionIndex implements IThreadDescriptionIndex { this.threads = []; this.embeddingIndex.clear(); } + + public async buildIndex(): Promise { + this.embeddingIndex.clear(); + for (let i = 0; i < this.threads.length; ++i) { + const thread = this.threads[i]; + await this.embeddingIndex.addText(thread.description); + } + } + + public serialize(): IConversationThreadData { + const threadData: IThreadDataItem[] = []; + const embeddingIndex = this.embeddingIndex; + for (let i = 0; i < this.threads.length; ++i) { + const thread = this.threads[i]; + threadData.push({ + thread, + embedding: serializeEmbedding(embeddingIndex.get(i)), + }); + } + return { + threads: threadData, + }; + } + + public deserialize(data: IConversationThreadData): void { + if (data.threads) { + this.threads = []; + this.embeddingIndex.clear(); + for (let i = 0; i < data.threads.length; ++i) { + this.threads.push(data.threads[i].thread); + const embedding = deserializeEmbedding( + data.threads[i].embedding, + ); + this.embeddingIndex.add(embedding); + } + } + } } diff --git a/ts/packages/knowPro/src/import.ts b/ts/packages/knowPro/src/import.ts index 7c30011e..1c078e43 100644 --- a/ts/packages/knowPro/src/import.ts +++ b/ts/packages/knowPro/src/import.ts @@ -26,8 +26,6 @@ import { } from "./relatedTermsIndex.js"; import { createTextEmbeddingIndexSettings, - deserializeEmbedding, - serializeEmbedding, TextEmbeddingIndexSettings, } from "./fuzzyIndex.js"; import { TimestampToTextRangeIndex } from "./timestampIndex.js"; @@ -40,10 +38,7 @@ import { addPropertiesToIndex, PropertyIndex } from "./propertyIndex.js"; import { IConversationSecondaryIndexes } from "./secondaryIndexes.js"; import { IConversationThreadData, - IConversationThreads, - IThreadDataItem, - Thread, - ThreadDescriptionIndex, + ConversationThreads, } from "./conversationThread.js"; // metadata for podcast messages @@ -143,7 +138,7 @@ export class Podcast implements IConversation, IConversationSecondaryIndexes { public settings: PodcastSettings; - public threadIndex: PodcastThreads; + public threads: ConversationThreads; constructor( public nameTag: string, @@ -162,15 +157,15 @@ export class Podcast | undefined = undefined, ) { this.settings = createPodcastSettings(); - this.threadIndex = new PodcastThreads(this.settings.threadSettings); + this.threads = new ConversationThreads(this.settings.threadSettings); } public addMetadataToIndex() { for (let i = 0; i < this.messages.length; i++) { const msg = this.messages[i]; - const knowlegeResponse = msg.metadata.getKnowledge(); + const knowledgeResponse = msg.metadata.getKnowledge(); if (this.semanticRefIndex !== undefined) { - for (const entity of knowlegeResponse.entities) { + for (const entity of knowledgeResponse.entities) { addEntityToIndex( entity, this.semanticRefs, @@ -178,7 +173,7 @@ export class Podcast i, ); } - for (const action of knowlegeResponse.actions) { + for (const action of knowledgeResponse.actions) { addActionToIndex( action, this.semanticRefs, @@ -186,7 +181,7 @@ export class Podcast i, ); } - for (const topic of knowlegeResponse.topics) { + for (const topic of knowledgeResponse.topics) { addTopicToIndex( topic, this.semanticRefs, @@ -217,7 +212,7 @@ export class Podcast const result = await buildConversationIndex(this, progressCallback); this.addMetadataToIndex(); this.buildSecondaryIndexes(); - await this.threadIndex.buildIndex(); + await this.threads.buildIndex(); return result; } @@ -246,7 +241,7 @@ export class Podcast semanticRefs: this.semanticRefs, semanticIndexData: this.semanticRefIndex?.serialize(), relatedTermsIndexData: this.termToRelatedTermsIndex?.serialize(), - threadData: this.threadIndex.serialize(), + threadData: this.threads.serialize(), }; } @@ -265,8 +260,10 @@ export class Podcast ); } if (data.threadData) { - this.threadIndex = new PodcastThreads(this.settings.threadSettings); - this.threadIndex.deserialize(data.threadData); + this.threads = new ConversationThreads( + this.settings.threadSettings, + ); + this.threads.deserialize(data.threadData); } this.buildSecondaryIndexes(); } @@ -443,52 +440,3 @@ function randomDate(startHour = 14) { date.setDate(Math.floor(Math.random() * 28)); return date; } - -class PodcastThreads implements IConversationThreads { - public threads: Thread[]; - public threadDescriptionIndex: ThreadDescriptionIndex; - - constructor(settings: TextEmbeddingIndexSettings) { - this.threads = []; - this.threadDescriptionIndex = new ThreadDescriptionIndex(settings); - } - - public async buildIndex(): Promise { - for (let i = 0; i < this.threads.length; ++i) { - const thread = this.threads[i]; - await this.threadDescriptionIndex.addDescription( - thread.description, - i, - ); - } - } - - public serialize(): IConversationThreadData { - const threadData: IThreadDataItem[] = []; - const embeddingIndex = this.threadDescriptionIndex.embeddingIndex; - for (let i = 0; i < this.threads.length; ++i) { - const thread = this.threads[i]; - threadData.push({ - thread, - embedding: serializeEmbedding(embeddingIndex.get(i)), - }); - } - return { - threads: threadData, - }; - } - - public deserialize(data: IConversationThreadData): void { - if (data.threads) { - this.threads = []; - this.threadDescriptionIndex.clear(); - for (let i = 0; i < data.threads.length; ++i) { - this.threads.push(data.threads[i].thread); - const embedding = deserializeEmbedding( - data.threads[i].embedding, - ); - this.threadDescriptionIndex.add(embedding, i); - } - } - } -} diff --git a/ts/packages/knowPro/src/search.ts b/ts/packages/knowPro/src/search.ts index 1427c3c7..086df774 100644 --- a/ts/packages/knowPro/src/search.ts +++ b/ts/packages/knowPro/src/search.ts @@ -283,18 +283,17 @@ class SearchQueryBuilder { this.allScopeSearchTerms.push(...searchTermsUsed); } // If a thread index is available... - const threadIndex = this.secondaryIndexes?.threadIndex; - if (filter.threadDescription && threadIndex) { - const threadsInScope = - await threadIndex.threadDescriptionIndex.lookupThread( - filter.threadDescription, - ); + const threads = this.secondaryIndexes?.threads; + if (filter.threadDescription && threads) { + const threadsInScope = await threads.lookupThread( + filter.threadDescription, + ); if (threadsInScope) { scopeSelectors ??= []; scopeSelectors.push( new q.ThreadSelector( threadsInScope.map( - (t) => threadIndex.threads[t.threadIndex], + (t) => threads.threads[t.threadIndex], ), ), ); diff --git a/ts/packages/knowPro/src/secondaryIndexes.ts b/ts/packages/knowPro/src/secondaryIndexes.ts index 628b0f9e..4f655a8c 100644 --- a/ts/packages/knowPro/src/secondaryIndexes.ts +++ b/ts/packages/knowPro/src/secondaryIndexes.ts @@ -18,7 +18,7 @@ export interface IConversationSecondaryIndexes { termToRelatedTermsIndex?: ITermToRelatedTermsIndex | undefined; propertyToSemanticRefIndex: IPropertyToSemanticRefIndex | undefined; timestampIndex?: ITimestampToTextRangeIndex | undefined; - threadIndex?: IConversationThreads | undefined; + threads?: IConversationThreads | undefined; } /**