Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

knowpro: relative term matching improvements #694

Merged
merged 9 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 96 additions & 29 deletions ts/packages/knowPro/src/collections.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import { isInTextRange } from "./query.js";
export interface Match<T = any> {
value: T;
score: number;
exactHitCount: number;
hitCount: number;
relatedScore: number;
relatedHitCount: number;
}

/**
Expand Down Expand Up @@ -66,13 +68,46 @@ export class MatchAccumulator<T = any> {
public add(value: T, score: number, isExactMatch: boolean) {
const existingMatch = this.getMatch(value);
if (existingMatch) {
this.updateExisting(existingMatch, score, isExactMatch);
//this.updateExisting(existingMatch, score, isExactMatch);
if (isExactMatch) {
existingMatch.hitCount++;
existingMatch.score += score;
} else {
existingMatch.relatedHitCount++;
existingMatch.relatedScore += score;
}
} else {
this.setMatch({
value,
exactHitCount: isExactMatch ? 1 : 0,
score,
});
if (isExactMatch) {
this.setMatch({
value,
hitCount: 1,
score,
relatedHitCount: 0,
relatedScore: 0,
});
} else {
this.setMatch({
value,
hitCount: 0,
score: 0,
relatedHitCount: 1,
relatedScore: score,
});
}
}
}

public calculateTotalScore(scoreScaler?: (match: Match<T>) => void) {
scoreScaler ??= (m) => {
if (m.relatedHitCount > 0) {
const avgScore = m.relatedScore / m.relatedHitCount;
const normalizedScore = Math.log(1 + avgScore);
m.score += normalizedScore;
//m.score += m.relatedScore;
}
};
for (const match of this.getMatches()) {
scoreScaler(match);
}
}

Expand All @@ -82,10 +117,10 @@ export class MatchAccumulator<T = any> {
isExactMatch: boolean,
): void {
if (isExactMatch) {
existingMatch.exactHitCount++;
existingMatch.hitCount++;
existingMatch.score += newScore;
} else if (existingMatch.score < newScore) {
existingMatch.score = newScore;
} else if (existingMatch.relatedScore < newScore) {
existingMatch.relatedScore = newScore;
}
}

Expand Down Expand Up @@ -149,7 +184,7 @@ export class MatchAccumulator<T = any> {
minHitCount: number | undefined,
): IterableIterator<Match<T>> {
return minHitCount !== undefined && minHitCount > 0
? this.getMatches((m) => m.exactHitCount >= minHitCount)
? this.getMatches((m) => m.hitCount >= minHitCount)
: this.matches.values();
}
}
Expand All @@ -168,14 +203,14 @@ export class SemanticRefAccumulator extends MatchAccumulator<SemanticRefIndex> {
| IterableIterator<ScoredSemanticRef>
| undefined,
isExactMatch: boolean,
scoreBoost?: number,
weight?: number,
) {
if (scoredRefs) {
scoreBoost ??= searchTerm.score ?? 0;
weight ??= searchTerm.weight ?? 1;
for (const scoredRef of scoredRefs) {
this.add(
scoredRef.semanticRefIndex,
scoredRef.score + scoreBoost,
scoredRef.score * weight,
isExactMatch,
);
}
Expand All @@ -190,16 +225,16 @@ export class SemanticRefAccumulator extends MatchAccumulator<SemanticRefIndex> {
| IterableIterator<ScoredSemanticRef>
| undefined,
isExactMatch: boolean,
scoreBoost?: number,
weight?: number,
) {
if (scoredRefs) {
scoreBoost ??= searchTerm.score ?? 0;
weight ??= searchTerm.weight ?? 1;
for (const scoredRef of scoredRefs) {
const existingMatch = this.getMatch(scoredRef.semanticRefIndex);
if (existingMatch) {
this.updateExisting(
existingMatch,
scoredRef.score + scoreBoost,
scoredRef.score * weight,
isExactMatch,
);
} else {
Expand Down Expand Up @@ -373,25 +408,46 @@ export class TextRangeCollection {
}

export class TermSet {
constructor(private terms: Map<string, Term> = new Map()) {}

public add(term: Term) {
const existingTerm = this.terms.get(term.text);
if (!existingTerm) {
this.terms.set(term.text, term);
private terms: Map<string, Term> = new Map();
constructor(terms?: Term[]) {
if (terms) {
this.addOrUnion(terms);
}
}

public addOrUnion(term: Term) {
public get size() {
return this.terms.size;
}

public add(term: Term): boolean {
const existingTerm = this.terms.get(term.text);
if (existingTerm) {
const existingScore = existingTerm.score ?? 0;
const newScore = term.score ?? 0;
if (existingScore < newScore) {
existingTerm.score = newScore;
return false;
}
this.terms.set(term.text, term);
return true;
}

public addOrUnion(terms: Term | Term[] | undefined) {
if (terms === undefined) {
return;
}
if (Array.isArray(terms)) {
for (const term of terms) {
this.addOrUnion(term);
}
} else {
this.terms.set(term.text, term);
const term = terms;
const existingTerm = this.terms.get(term.text);
if (existingTerm) {
const existingScore = existingTerm.weight ?? 0;
const newScore = term.weight ?? 0;
if (existingScore < newScore) {
existingTerm.weight = newScore;
}
} else {
this.terms.set(term.text, term);
}
}
}

Expand All @@ -401,13 +457,24 @@ export class TermSet {
: this.terms.get(term.text);
}

public getWeight(term: Term): number | undefined {
return this.terms.get(term.text)?.weight;
}

public has(term: Term): boolean {
return this.terms.has(term.text);
}

public remove(term: Term) {
this.terms.delete(term.text);
}
public clear(): void {
this.terms.clear();
}

public values() {
return this.terms.values();
}
}

export class PropertyTermSet {
Expand Down
23 changes: 17 additions & 6 deletions ts/packages/knowPro/src/dataFormat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import { conversation } from "knowledge-processor";
import { NormalizedEmbedding } from "typeagent";

// an object that can provide a KnowledgeResponse structure
export interface IKnowledgeSource {
Expand Down Expand Up @@ -124,14 +125,14 @@ export interface IConversationData<TMessage> {
export type Term = {
text: string;
/**
* Optional additional score to use when this term matches
* Optional weighting for these matches
*/
score?: number | undefined;
weight?: number | undefined;
};

export interface ITermToRelatedTermsIndex {
lookupTerm(termText: string): Term[] | undefined;
lookupTermFuzzy(termText: string): Promise<Term[] | undefined>;
get termEmbeddings(): ITermEmbeddingIndex | undefined;
serialize(): ITermsToRelatedTermsIndexData;
deserialize(data?: ITermsToRelatedTermsIndexData): void;
}
Expand All @@ -151,11 +152,21 @@ export interface ITermsToRelatedTermsDataItem {
}

export interface ITermEmbeddingIndex {
lookupTermsFuzzy(
term: string,
lookupTerm(
text: string,
maxMatches?: number,
minScore?: number,
): Promise<Term[]>;
lookupTerms(
texts: string[],
maxMatches?: number,
minScore?: number,
): Promise<Term[][]>;
lookupEmbeddings(
text: string,
maxMatches?: number,
minScore?: number,
): Promise<Term[] | undefined>;
): Promise<[string, NormalizedEmbedding][] | undefined>;
serialize(): ITextEmbeddingIndexData;
deserialize(data: ITextEmbeddingIndexData): void;
}
Expand Down
31 changes: 7 additions & 24 deletions ts/packages/knowPro/src/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ export class MatchAllTermsExpr extends QueryOpExpr<SemanticRefAccumulator> {
for (const matchExpr of this.searchTermExpressions) {
matchExpr.accumulateMatches(context, allMatches);
}
allMatches.calculateTotalScore();
return allMatches;
}
}
Expand Down Expand Up @@ -374,27 +375,18 @@ export class MatchSearchTermExpr extends MatchTermExpr {
) {
if (relatedTerm === undefined) {
const semanticRefs = this.lookupTerm(context, term);
if (context.matchedTerms.has(term)) {
matches.updateTermMatches(term, semanticRefs, true);
} else {
if (!context.matchedTerms.has(term)) {
matches.addTermMatches(term, semanticRefs, true);
context.matchedTerms.add(term);
}
} else {
const semanticRefs = this.lookupTerm(context, relatedTerm);
if (context.matchedTerms.has(relatedTerm)) {
matches.updateTermMatches(
term,
semanticRefs,
false,
relatedTerm.score,
);
} else {
if (!context.matchedTerms.has(relatedTerm)) {
matches.addTermMatches(
term,
semanticRefs,
false,
relatedTerm.score,
relatedTerm.weight,
);
context.matchedTerms.add(relatedTerm);
}
Expand Down Expand Up @@ -508,9 +500,7 @@ export class MatchPropertyTermExpr extends MatchTermExpr {
propName,
propVal.text,
);
if (context.matchedPropertyTerms.has(propName, propVal)) {
matches.updateTermMatches(propVal, semanticRefs, true);
} else {
if (!context.matchedPropertyTerms.has(propName, propVal)) {
matches.addTermMatches(propVal, semanticRefs, true);
context.matchedPropertyTerms.add(propName, propVal);
}
Expand All @@ -519,19 +509,12 @@ export class MatchPropertyTermExpr extends MatchTermExpr {
propName,
relatedPropVal.text,
);
if (context.matchedPropertyTerms.has(propName, relatedPropVal)) {
matches.updateTermMatches(
propVal,
semanticRefs,
false,
relatedPropVal.score,
);
} else {
if (!context.matchedPropertyTerms.has(propName, relatedPropVal)) {
matches.addTermMatches(
propVal,
semanticRefs,
false,
relatedPropVal.score,
relatedPropVal.weight,
);
context.matchedPropertyTerms.add(propName, relatedPropVal);
}
Expand Down
Loading
Loading