Skip to content

Commit

Permalink
knowpro improvements (#677)
Browse files Browse the repository at this point in the history
* Improved:
  * Property and facet matching
  * Scoring: consistent scoring for search terms, property terms, related terms
  * Query compilation improvements; cleaner query expressions.
  * Test app
* Code refactoring
  • Loading branch information
umeshma authored Feb 6, 2025
1 parent 8cbb9c2 commit 8341dfa
Show file tree
Hide file tree
Showing 7 changed files with 467 additions and 351 deletions.
24 changes: 22 additions & 2 deletions ts/examples/chat/src/memory/knowproMemory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ export async function createKnowproCommands(
maxToDisplay: argNum("Maximum matches to display", 25),
startMinute: argNum("Starting at minute."),
endMinute: argNum("Ending minute."),
exact: argBool("Only display exact matches", false),
},
};
if (kType === undefined) {
Expand Down Expand Up @@ -251,8 +252,10 @@ export async function createKnowproCommands(
const matches = await kp.searchConversation(
conversation,
terms,
keyValuesFromNamedArgs(namedArgs, commandDef),
propertyTermsFromNamedArgs(namedArgs, commandDef),
filterFromNamedArgs(namedArgs),
undefined,
namedArgs.exact ? 1 : undefined,
);
if (matches === undefined || matches.size === 0) {
context.printer.writeLine("No matches");
Expand All @@ -270,6 +273,23 @@ export async function createKnowproCommands(
}
}

/**
 * Builds property search terms from the key/value pairs carried by a
 * command's named arguments.
 *
 * @param namedArgs Parsed named arguments of the command invocation.
 * @param commandDef Command metadata, forwarded to keyValuesFromNamedArgs.
 * @returns One property search term per key returned by
 * keyValuesFromNamedArgs, in the order Object.keys enumerates them.
 */
function propertyTermsFromNamedArgs(
    namedArgs: NamedArgs,
    commandDef: CommandMetadata,
): kp.PropertySearchTerm[] {
    const pairs = keyValuesFromNamedArgs(namedArgs, commandDef);
    return Object.keys(pairs).map((name) =>
        kp.propertySearchTermFromKeyValue(name, pairs[name]),
    );
}

function filterFromNamedArgs(namedArgs: NamedArgs) {
let filter: kp.SearchFilter = {
type: namedArgs.ktype,
Expand Down Expand Up @@ -425,7 +445,7 @@ export function parseQueryTerms(args: string[]): kp.SearchTerm[] {
const queryTerm: kp.SearchTerm = {
term: { text: allTermStrings[0] },
};
if (allTermStrings.length > 0) {
if (allTermStrings.length > 1) {
queryTerm.relatedTerms = [];
for (let i = 1; i < allTermStrings.length; ++i) {
queryTerm.relatedTerms.push({ text: allTermStrings[i] });
Expand Down
2 changes: 1 addition & 1 deletion ts/examples/chat/src/memory/knowproPrinter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ export class KnowProPrinter extends ChatPrinter {

this.writeInColor(
chalk.green,
`#${i + 1}: ${semanticRef.knowledgeType} [${match.score}]`,
`#${i + 1}: <${match.semanticRefIndex}> ${semanticRef.knowledgeType} [${match.score}]`,
);
this.writeSemanticRef(semanticRef);
this.writeLine();
Expand Down
203 changes: 109 additions & 94 deletions ts/packages/knowPro/src/collections.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ import { isInTextRange } from "./query.js";
export interface Match<T = any> {
value: T;
score: number;
hitCount: number;
exactMatch: boolean;
exactHitCount: number;
}

/**
Expand All @@ -31,21 +30,15 @@ export function sortMatchesByRelevance(matches: Match[]) {

export class MatchAccumulator<T = any> {
private matches: Map<T, Match<T>>;
private maxHitCount: number;

constructor() {
this.matches = new Map<T, Match<T>>();
this.maxHitCount = 0;
}

public get size(): number {
return this.matches.size;
}

public get maxHits(): number {
return this.maxHitCount;
}

public has(value: T): boolean {
return this.matches.has(value);
}
Expand All @@ -56,9 +49,6 @@ export class MatchAccumulator<T = any> {

public setMatch(match: Match<T>): void {
this.matches.set(match.value, match);
if (match.hitCount > this.maxHitCount) {
this.maxHitCount = match.hitCount;
}
}

public setMatches(
Expand All @@ -73,46 +63,29 @@ export class MatchAccumulator<T = any> {
}
}

public add(value: T, score: number, isNewMatchExact: boolean): void {
let match = this.matches.get(value);
if (match) {
// Increment the existing match
if (isNewMatchExact) {
match.hitCount += 1;
match.exactMatch = true;
}
match.score += score;
public add(value: T, score: number, isExactMatch: boolean) {
const existingMatch = this.getMatch(value);
if (existingMatch) {
this.updateExisting(existingMatch, score, isExactMatch);
} else {
// New match
match = {
this.setMatch({
value,
exactHitCount: isExactMatch ? 1 : 0,
score,
hitCount: 1,
exactMatch: isNewMatchExact,
};
this.matches.set(value, match);
}
if (match.hitCount > this.maxHitCount) {
this.maxHitCount = match.hitCount;
});
}
}

public addUnion(other: MatchAccumulator<T>): void {
for (const otherMatch of other.matches.values()) {
const existingMatch = this.matches.get(otherMatch.value);
if (existingMatch) {
if (otherMatch.exactMatch) {
existingMatch.hitCount += otherMatch.hitCount;
} else if (existingMatch.hitCount < otherMatch.hitCount) {
existingMatch.hitCount = otherMatch.hitCount;
}
existingMatch.score += otherMatch.score;
if (existingMatch.hitCount > this.maxHitCount) {
this.maxHitCount = existingMatch.hitCount;
}
} else {
this.setMatch(otherMatch);
}
protected updateExisting(
existingMatch: Match,
newScore: number,
isExactMatch: boolean,
): void {
if (isExactMatch) {
existingMatch.exactHitCount++;
existingMatch.score += newScore;
} else if (existingMatch.score < newScore) {
existingMatch.score = newScore;
}
}

Expand Down Expand Up @@ -161,25 +134,22 @@ export class MatchAccumulator<T = any> {

public clearMatches(): void {
this.matches.clear();
this.maxHitCount = 0;
}

public selectTopNScoring(
maxMatches?: number,
minHitCount?: number,
): number {
const topN = this.getTopNScoring(maxMatches, minHitCount);
if (topN.length > 0) {
this.setMatches(topN, true);
}
this.setMatches(topN, true);
return topN.length;
}

private matchesWithMinHitCount(
minHitCount: number | undefined,
): IterableIterator<Match<T>> {
return minHitCount !== undefined && minHitCount > 0
? this.getMatches((m) => m.hitCount >= minHitCount)
? this.getMatches((m) => m.exactHitCount >= minHitCount)
: this.matches.values();
}
}
Expand All @@ -191,80 +161,67 @@ export class SemanticRefAccumulator extends MatchAccumulator<SemanticRefIndex> {
super();
}

public addSearchTermMatch(
public addTermMatches(
searchTerm: Term,
semanticRefs:
scoredRefs:
| ScoredSemanticRef[]
| IterableIterator<ScoredSemanticRef>
| undefined,
isExactMatch: boolean,
scoreBoost?: number,
) {
if (semanticRefs) {
if (scoredRefs) {
scoreBoost ??= searchTerm.score ?? 0;
for (const match of semanticRefs) {
for (const scoredRef of scoredRefs) {
this.add(
match.semanticRefIndex,
match.score + scoreBoost,
true,
scoredRef.semanticRefIndex,
scoredRef.score + scoreBoost,
isExactMatch,
);
}
this.searchTermMatches.add(searchTerm.text);
}
}

public addRelatedTermMatch(
public updateTermMatches(
searchTerm: Term,
relatedTerm: Term,
semanticRefs:
scoredRefs:
| ScoredSemanticRef[]
| IterableIterator<ScoredSemanticRef>
| undefined,
isExactMatch: boolean,
scoreBoost?: number,
) {
if (semanticRefs) {
// Related term matches count as matches for the queryTerm...
// BUT are scored with the score of the related term
scoreBoost ??= relatedTerm.score ?? 0;
for (const semanticRef of semanticRefs) {
let score = semanticRef.score + scoreBoost;
let match = this.getMatch(semanticRef.semanticRefIndex);
if (match !== undefined) {
if (match.score < score) {
match.score = score;
}
if (scoredRefs) {
scoreBoost ??= searchTerm.score ?? 0;
for (const scoredRef of scoredRefs) {
const existingMatch = this.getMatch(scoredRef.semanticRefIndex);
if (existingMatch) {
this.updateExisting(
existingMatch,
scoredRef.score + scoreBoost,
isExactMatch,
);
} else {
match = {
value: semanticRef.semanticRefIndex,
score,
hitCount: 1,
exactMatch: false,
};
this.setMatch(match);
throw new Error(
`No existing match for ${searchTerm.text} Id: ${scoredRef.semanticRefIndex}`,
);
}
}
this.searchTermMatches.add(searchTerm.text);
}
}

public addUnion(other: SemanticRefAccumulator): void {
super.addUnion(other);
unionInPlace(this.searchTermMatches, other.searchTermMatches);
}

public override getSortedByScore(
minHitCount?: number,
): Match<SemanticRefIndex>[] {
return super.getSortedByScore(this.getMinHitCount(minHitCount));
return super.getSortedByScore(minHitCount);
}

public override getTopNScoring(
maxMatches?: number,
minHitCount?: number,
): Match<SemanticRefIndex>[] {
return super.getTopNScoring(
maxMatches,
this.getMinHitCount(minHitCount),
);
return super.getTopNScoring(maxMatches, minHitCount);
}

public *getSemanticRefs(
Expand Down Expand Up @@ -351,11 +308,6 @@ export class SemanticRefAccumulator extends MatchAccumulator<SemanticRefIndex> {
};
}, 0);
}

private getMinHitCount(minHitCount?: number): number {
return minHitCount !== undefined ? minHitCount : this.maxHits;
//: this.queryTermMatches.termMatches.size;
}
}

/**
 * Match accumulator whose match values are messages (IMessage).
 * Declares no members of its own; all behavior comes from MatchAccumulator.
 */
export class MessageAccumulator extends MatchAccumulator<IMessage> {}
Expand Down Expand Up @@ -420,6 +372,69 @@ export class TextRangeCollection {
}
}

/**
 * A set of terms keyed by their text. At most one Term object is stored
 * per distinct text value.
 */
export class TermSet {
    /**
     * @param terms Backing map from term text to term. Defaults to a new,
     * empty map; a supplied map is used as-is (not copied).
     */
    constructor(private terms: Map<string, Term> = new Map()) {}

    /**
     * Adds the term if no term with the same text is present.
     * An already stored term is left untouched, including its score.
     */
    public add(term: Term): void {
        const existingTerm = this.terms.get(term.text);
        if (!existingTerm) {
            this.terms.set(term.text, term);
        }
    }

    /**
     * Adds the term, or, when a term with the same text is already stored,
     * raises the stored term's score to the incoming score if that is higher.
     * A missing score is treated as 0. The stored Term object is mutated in
     * place; the incoming object is not stored in that case.
     */
    public addOrUnion(term: Term): void {
        const existingTerm = this.terms.get(term.text);
        if (existingTerm) {
            const existingScore = existingTerm.score ?? 0;
            const newScore = term.score ?? 0;
            if (existingScore < newScore) {
                existingTerm.score = newScore;
            }
        } else {
            this.terms.set(term.text, term);
        }
    }

    /**
     * Looks up the stored term by text.
     *
     * @param term Either the term text or a Term whose text is used as key.
     * @returns The stored term, or undefined when none has that text.
     */
    public get(term: string | Term): Term | undefined {
        return this.terms.get(TermSet.textOf(term));
    }

    /**
     * Tests whether a term with the given text is stored.
     *
     * @param term Either the term text or a Term whose text is used as key.
     * Accepting plain text keeps this method symmetrical with get().
     */
    public has(term: string | Term): boolean {
        return this.terms.has(TermSet.textOf(term));
    }

    /** Removes every stored term. */
    public clear(): void {
        this.terms.clear();
    }

    /** Resolves the map key for a lookup argument. */
    private static textOf(term: string | Term): string {
        return typeof term === "string" ? term : term.text;
    }
}

/**
 * A set of (property name, property value) pairs. Each pair is stored under
 * a string key made of the property name, a colon, and the value's text.
 */
export class PropertyTermSet {
    /**
     * @param terms Backing map from composite key to value term. Defaults to
     * a new, empty map.
     */
    constructor(private terms: Map<string, Term> = new Map()) {}

    /**
     * Records the pair unless an entry with the same composite key exists;
     * an existing entry is never replaced.
     */
    public add(propertyName: string, propertyValue: Term) {
        const lookupKey = this.makeKey(propertyName, propertyValue);
        if (this.terms.get(lookupKey)) {
            return;
        }
        this.terms.set(lookupKey, propertyValue);
    }

    /** Tests whether the pair has been recorded. */
    public has(propertyName: string, propertyValue: Term): boolean {
        return this.terms.has(this.makeKey(propertyName, propertyValue));
    }

    /** Removes every recorded pair. */
    public clear(): void {
        this.terms.clear();
    }

    /**
     * Composite key: "<propertyName>:<value text>".
     * NOTE(review): a colon inside either part makes distinct pairs share a
     * key (e.g. "a:b" + "c" and "a" + "b:c") - confirm inputs cannot contain one.
     */
    private makeKey(propertyName: string, propertyValue: Term): string {
        return `${propertyName}:${propertyValue.text}`;
    }
}

/**
* Return a new set that is the union of two sets
* @param x
Expand Down
Loading

0 comments on commit 8341dfa

Please sign in to comment.