Skip to content

Commit

Permalink
knowpro: Threads simplification, refactoring (#750)
Browse files Browse the repository at this point in the history
Ongoing evolution of **threads**
* Simplified + Refactored Thread interfaces.
* Reusable ConversationThreads class
  • Loading branch information
umeshma authored Feb 25, 2025
1 parent e27f6d4 commit 2d489b3
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 118 deletions.
4 changes: 2 additions & 2 deletions ts/examples/chat/src/memory/podcastMemory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ export function createPodcastCommands(
podcastMessages,
knowledgeResponses,
);
kpPodcast.threadIndex.threads.push(...podcastThreads);
await kpPodcast.threadIndex.buildIndex();
kpPodcast.threads.threads.push(...podcastThreads);
await kpPodcast.buildIndex();

const podcastData = kpPodcast.serialize();
await ensureDir(path.dirname(namedArgs.filePath));
Expand Down
99 changes: 56 additions & 43 deletions ts/packages/knowPro/src/conversationThread.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import { NormalizedEmbedding } from "typeagent";
import { TextRange } from "./dataFormat.js";
import {
deserializeEmbedding,
serializeEmbedding,
TextEmbeddingIndex,
TextEmbeddingIndexSettings,
} from "./fuzzyIndex.js";
Expand All @@ -24,8 +25,15 @@ export type ScoredThreadIndex = {
};

export interface IConversationThreads {
threads: Thread[];
threadDescriptionIndex: IThreadDescriptionIndex;
readonly threads: Thread[];

addThread(thread: Thread): Promise<void>;
lookupThread(
threadDescription: string,
maxMatches?: number,
thresholdScore?: number,
): Promise<ScoredThreadIndex[] | undefined>;
removeThread(threadIndex: ThreadIndex): void;

serialize(): IConversationThreadData;
deserialize(data: IConversationThreadData): void;
Expand All @@ -40,47 +48,18 @@ export interface IThreadDataItem {
embedding: number[];
}

/**
 * Contract for an index that associates thread descriptions with thread
 * indexes and supports looking threads up by text.
 */
export interface IThreadDescriptionIndex {
/**
 * Registers a description for a thread.
 * @param description Text describing the thread.
 * @param threadIndex Either a bare thread index or a thread index that
 * already carries a score.
 */
addDescription(
description: string,
threadIndex: ThreadIndex | ScoredThreadIndex,
): Promise<void>;
/**
 * Finds threads for the given text.
 * @param text Text to match against the stored descriptions.
 * @param maxMatches Optional cap on the number of matches.
 * @param thresholdScore Optional minimum score for a match.
 * @returns Scored thread indexes, or undefined.
 * NOTE(review): when undefined is returned is not visible here; confirm
 * against the implementation.
 */
lookupThread(
text: string,
maxMatches?: number,
thresholdScore?: number,
): Promise<ScoredThreadIndex[] | undefined>;
}

export class ThreadDescriptionIndex implements IThreadDescriptionIndex {
public threads: ScoredThreadIndex[];
export class ConversationThreads implements IConversationThreads {
public threads: Thread[];
public embeddingIndex: TextEmbeddingIndex;

/**
 * Starts with no threads and a fresh embedding index.
 * @param settings Passed to the TextEmbeddingIndex constructor and kept on
 * the instance as `settings`.
 */
constructor(public settings: TextEmbeddingIndexSettings) {
    this.embeddingIndex = new TextEmbeddingIndex(settings);
    this.threads = [];
}

/**
 * Embeds a description and records which thread it belongs to.
 * @param description Text to add to the embedding index.
 * @param threadIndex A bare thread index (stored with score 1) or an
 * already scored thread index (stored as given).
 */
public async addDescription(
    description: string,
    threadIndex: ThreadIndex | ScoredThreadIndex,
): Promise<void> {
    // Normalize the bare-number form into a scored entry up front.
    const scoredEntry: ScoredThreadIndex =
        typeof threadIndex === "number"
            ? { threadIndex: threadIndex, score: 1 }
            : threadIndex;
    await this.embeddingIndex.addText(description);
    this.threads.push(scoredEntry);
}

public add(embedding: NormalizedEmbedding, threadIndex: ThreadIndex): void {
this.embeddingIndex.add(embedding);
this.threads.push({
threadIndex: threadIndex,
score: 1,
});
/**
 * Adds a thread and indexes its description.
 *
 * The description is embedded first and the thread is stored only after
 * that succeeds. serialize() pairs `threads[i]` with the embedding at
 * position `i`, so a rejected embedding call must not leave a thread behind
 * without a matching embedding.
 *
 * @param thread Thread to add; its `description` is the text that is indexed.
 * @throws Propagates any rejection from the embedding index; in that case
 * `threads` is left unchanged.
 */
public async addThread(thread: Thread): Promise<void> {
    await this.embeddingIndex.addText(thread.description);
    this.threads.push(thread);
}

public async lookupThread(
Expand All @@ -99,17 +78,51 @@ export class ThreadDescriptionIndex implements IThreadDescriptionIndex {
}

public removeThread(threadIndex: ThreadIndex) {
const indexOf = this.threads.findIndex(
(t) => t.threadIndex === threadIndex,
);
if (indexOf >= 0) {
this.threads.splice(indexOf, 1);
this.embeddingIndex.removeAt(indexOf);
if (threadIndex >= 0) {
this.threads.splice(threadIndex, 1);
this.embeddingIndex.removeAt(threadIndex);
}
}

/** Empties the embedding index and drops all stored threads. */
public clear(): void {
    this.embeddingIndex.clear();
    this.threads = [];
}

/**
 * Rebuilds the embedding index from the current threads.
 *
 * The index is cleared first, then each thread description is added one at
 * a time, in thread order, so embedding positions line up with thread
 * positions.
 */
public async buildIndex(): Promise<void> {
    this.embeddingIndex.clear();
    let position = 0;
    // `this.threads` is re-read on every pass, exactly as an index loop would.
    while (position < this.threads.length) {
        const { description } = this.threads[position];
        await this.embeddingIndex.addText(description);
        position += 1;
    }
}

/**
 * Produces a serializable snapshot of the threads.
 * @returns An object whose `threads` array holds, for each thread, the
 * thread itself and the serialized embedding stored at the same position in
 * the embedding index.
 */
public serialize(): IConversationThreadData {
    const index = this.embeddingIndex;
    const items: IThreadDataItem[] = [];
    for (const [position, thread] of this.threads.entries()) {
        const embedding = serializeEmbedding(index.get(position));
        items.push({ thread, embedding });
    }
    return { threads: items };
}

/**
 * Replaces the current threads and embeddings with the serialized data.
 * Does nothing when `data.threads` is missing.
 * @param data Snapshot previously produced by serialize().
 */
public deserialize(data: IConversationThreadData): void {
    if (!data.threads) {
        return;
    }
    this.threads = [];
    this.embeddingIndex.clear();
    for (const item of data.threads) {
        this.threads.push(item.thread);
        this.embeddingIndex.add(deserializeEmbedding(item.embedding));
    }
}
}
78 changes: 13 additions & 65 deletions ts/packages/knowPro/src/import.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ import {
} from "./relatedTermsIndex.js";
import {
createTextEmbeddingIndexSettings,
deserializeEmbedding,
serializeEmbedding,
TextEmbeddingIndexSettings,
} from "./fuzzyIndex.js";
import { TimestampToTextRangeIndex } from "./timestampIndex.js";
Expand All @@ -40,10 +38,7 @@ import { addPropertiesToIndex, PropertyIndex } from "./propertyIndex.js";
import { IConversationSecondaryIndexes } from "./secondaryIndexes.js";
import {
IConversationThreadData,
IConversationThreads,
IThreadDataItem,
Thread,
ThreadDescriptionIndex,
ConversationThreads,
} from "./conversationThread.js";

// metadata for podcast messages
Expand Down Expand Up @@ -143,7 +138,7 @@ export class Podcast
implements IConversation<PodcastMessageMeta>, IConversationSecondaryIndexes
{
public settings: PodcastSettings;
public threadIndex: PodcastThreads;
public threads: ConversationThreads;

constructor(
public nameTag: string,
Expand All @@ -162,31 +157,31 @@ export class Podcast
| undefined = undefined,
) {
this.settings = createPodcastSettings();
this.threadIndex = new PodcastThreads(this.settings.threadSettings);
this.threads = new ConversationThreads(this.settings.threadSettings);
}

public addMetadataToIndex() {
for (let i = 0; i < this.messages.length; i++) {
const msg = this.messages[i];
const knowlegeResponse = msg.metadata.getKnowledge();
const knowledgeResponse = msg.metadata.getKnowledge();
if (this.semanticRefIndex !== undefined) {
for (const entity of knowlegeResponse.entities) {
for (const entity of knowledgeResponse.entities) {
addEntityToIndex(
entity,
this.semanticRefs,
this.semanticRefIndex,
i,
);
}
for (const action of knowlegeResponse.actions) {
for (const action of knowledgeResponse.actions) {
addActionToIndex(
action,
this.semanticRefs,
this.semanticRefIndex,
i,
);
}
for (const topic of knowlegeResponse.topics) {
for (const topic of knowledgeResponse.topics) {
addTopicToIndex(
topic,
this.semanticRefs,
Expand Down Expand Up @@ -217,7 +212,7 @@ export class Podcast
const result = await buildConversationIndex(this, progressCallback);
this.addMetadataToIndex();
this.buildSecondaryIndexes();
await this.threadIndex.buildIndex();
await this.threads.buildIndex();
return result;
}

Expand Down Expand Up @@ -246,7 +241,7 @@ export class Podcast
semanticRefs: this.semanticRefs,
semanticIndexData: this.semanticRefIndex?.serialize(),
relatedTermsIndexData: this.termToRelatedTermsIndex?.serialize(),
threadData: this.threadIndex.serialize(),
threadData: this.threads.serialize(),
};
}

Expand All @@ -265,8 +260,10 @@ export class Podcast
);
}
if (data.threadData) {
this.threadIndex = new PodcastThreads(this.settings.threadSettings);
this.threadIndex.deserialize(data.threadData);
this.threads = new ConversationThreads(
this.settings.threadSettings,
);
this.threads.deserialize(data.threadData);
}
this.buildSecondaryIndexes();
}
Expand Down Expand Up @@ -443,52 +440,3 @@ function randomDate(startHour = 14) {
date.setDate(Math.floor(Math.random() * 28));
return date;
}

/**
 * Threads of a podcast together with an index over their descriptions.
 * Entry `i` of the description index corresponds to `threads[i]`.
 */
class PodcastThreads implements IConversationThreads {
    public threads: Thread[];
    public threadDescriptionIndex: ThreadDescriptionIndex;

    /**
     * @param settings Settings used to create the description index.
     */
    constructor(settings: TextEmbeddingIndexSettings) {
        this.threads = [];
        this.threadDescriptionIndex = new ThreadDescriptionIndex(settings);
    }

    /**
     * Rebuilds the description index from `threads`.
     *
     * The index is cleared first, so calling this more than once (or after
     * deserialize) does not add every description a second time.
     */
    public async buildIndex(): Promise<void> {
        this.threadDescriptionIndex.clear();
        for (let i = 0; i < this.threads.length; ++i) {
            const thread = this.threads[i];
            await this.threadDescriptionIndex.addDescription(
                thread.description,
                i,
            );
        }
    }

    /**
     * Produces a serializable snapshot.
     * @returns Each thread paired with the serialized embedding stored at
     * the same position in the description index.
     */
    public serialize(): IConversationThreadData {
        const threadData: IThreadDataItem[] = [];
        const embeddingIndex = this.threadDescriptionIndex.embeddingIndex;
        for (let i = 0; i < this.threads.length; ++i) {
            const thread = this.threads[i];
            threadData.push({
                thread,
                embedding: serializeEmbedding(embeddingIndex.get(i)),
            });
        }
        return {
            threads: threadData,
        };
    }

    /**
     * Replaces threads and index contents with the serialized data.
     * Does nothing when `data.threads` is missing.
     * @param data Snapshot previously produced by serialize().
     */
    public deserialize(data: IConversationThreadData): void {
        if (data.threads) {
            this.threads = [];
            this.threadDescriptionIndex.clear();
            for (let i = 0; i < data.threads.length; ++i) {
                this.threads.push(data.threads[i].thread);
                const embedding = deserializeEmbedding(
                    data.threads[i].embedding,
                );
                this.threadDescriptionIndex.add(embedding, i);
            }
        }
    }
}
13 changes: 6 additions & 7 deletions ts/packages/knowPro/src/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,18 +283,17 @@ class SearchQueryBuilder {
this.allScopeSearchTerms.push(...searchTermsUsed);
}
// If conversation threads are available, restrict the search scope to the threads matching the description.
const threadIndex = this.secondaryIndexes?.threadIndex;
if (filter.threadDescription && threadIndex) {
const threadsInScope =
await threadIndex.threadDescriptionIndex.lookupThread(
filter.threadDescription,
);
const threads = this.secondaryIndexes?.threads;
if (filter.threadDescription && threads) {
const threadsInScope = await threads.lookupThread(
filter.threadDescription,
);
if (threadsInScope) {
scopeSelectors ??= [];
scopeSelectors.push(
new q.ThreadSelector(
threadsInScope.map(
(t) => threadIndex.threads[t.threadIndex],
(t) => threads.threads[t.threadIndex],
),
),
);
Expand Down
2 changes: 1 addition & 1 deletion ts/packages/knowPro/src/secondaryIndexes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export interface IConversationSecondaryIndexes {
termToRelatedTermsIndex?: ITermToRelatedTermsIndex | undefined;
propertyToSemanticRefIndex: IPropertyToSemanticRefIndex | undefined;
timestampIndex?: ITimestampToTextRangeIndex | undefined;
threadIndex?: IConversationThreads | undefined;
threads?: IConversationThreads | undefined;
}

/**
Expand Down

0 comments on commit 2d489b3

Please sign in to comment.