Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
curtisman committed Feb 11, 2025
1 parent ff32d75 commit 64b7ebe
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 48 deletions.
9 changes: 8 additions & 1 deletion ts/packages/actionSchema/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,21 @@ export type GenerateSchemaOptions = {
strict?: boolean; // default true
exact?: boolean; // default false
jsonSchema?: boolean; // default false
jsonSchemaWithTs?: boolean; // default false, applies only when jsonSchema is true.
};

export function generateSchemaTypeDefinition(
definition: SchemaTypeDefinition,
options?: GenerateSchemaOptions,
order?: Map<string, number>,
) {
): string {
// wrap the action schema when json schema is active.
const jsonSchema = options?.jsonSchema ?? false;
const includeTs = !jsonSchema || (options?.jsonSchemaWithTs ?? false);
if (!includeTs) {
return "";
}

const strict = options?.strict ?? true;
const exact = options?.exact ?? false;
const emitted = new Map<
Expand Down
48 changes: 40 additions & 8 deletions ts/packages/actionSchema/src/jsonSchemaGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import * as sc from "./creator.js";
import {
ActionSchemaEntryTypeDefinition,
ActionSchemaGroup,
SchemaObjectField,
SchemaType,
SchemaTypeDefinition,
} from "./type.js";
Expand All @@ -18,38 +19,46 @@ export function wrapTypeWithJsonSchema(

type JsonSchemaObject = {
type: "object";
description?: string;
properties: Record<string, JsonSchema>;
required: string[];
additionalProperties: false;
};
type JsonSchemaArray = {
type: "array";
description?: string;
items: JsonSchema;
};

type JsonSchemaString = {
type: "string";
description?: string;
enum?: string[];
};

type JsonSchemaNumber = {
type: "number";
description?: string;
};

type JsonSchemaBoolean = {
type: "boolean";
};

type JsonSchemaUnion = {
anyOf: JsonSchema[];
description?: string;
};

type JsonSchemaNull = {
type: "null";
description?: string;
};

type JsonSchemaUnion = {
anyOf: JsonSchema[];
description?: string;
};

type JsonSchemaReference = {
$ref: string;
description?: string;
};

type JsonSchemaRoot = {
Expand All @@ -69,6 +78,14 @@ type JsonSchema =
| JsonSchemaUnion
| JsonSchemaReference;

function fieldComments(field: SchemaObjectField): string | undefined {
const combined = [
...(field.comments ?? []),
...(field.trailingComments ?? []),
];
return combined.length > 0 ? combined.join("\n") : undefined;
}

function generateJsonSchemaType(
type: SchemaType,
pending: SchemaTypeDefinition[],
Expand All @@ -79,10 +96,18 @@ function generateJsonSchemaType(
return {
type: "object",
properties: Object.fromEntries(
Object.entries(type.fields).map(([key, field]) => [
key,
generateJsonSchemaType(field.type, pending, strict),
]),
Object.entries(type.fields).map(([key, field]) => {
const fieldType = generateJsonSchemaType(
field.type,
pending,
strict,
);
const comments = fieldComments(field);
if (comments) {
fieldType.description = comments;
}
return [key, fieldType];
}),
),
required: Object.keys(type.fields),
additionalProperties: false,
Expand Down Expand Up @@ -134,6 +159,9 @@ function generateJsonSchemaTypeDefinition(
strict: true,
schema: generateJsonSchemaType(def.type, pending, strict),
};
if (def.comments) {
schema.schema.description = def.comments.join("\n");
}

if (pending.length !== 0) {
const $defs: Record<string, JsonSchema> = {};
Expand All @@ -147,6 +175,10 @@ function generateJsonSchemaTypeDefinition(
pending,
strict,
);
if (definition.comments) {
$defs[definition.name].description =
definition.comments.join("\n");
}
} while (pending.length > 0);
schema.schema.$defs = $defs;
}
Expand Down
65 changes: 41 additions & 24 deletions ts/packages/cli/src/commands/test/translate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,18 @@ export default class TestTranslateCommand extends Command {
description: "Output test result file",
required: true,
}),
success: Flags.boolean({
char: "s",
succeeded: Flags.boolean({
description:
"Copy failed test data and rerun only successful tests from the test result file",
}),
fail: Flags.boolean({
char: "f",
failed: Flags.boolean({
description:
"Copy pass test data and rerun only failed tests from the test result file",
}),
skipped: Flags.boolean({
description:
"Copy skipped test data and rerun only skipped tests from the test result file",
}),
input: Flags.string({
char: "i",
description: "Input test result file to get requests from",
Expand Down Expand Up @@ -167,7 +169,7 @@ export default class TestTranslateCommand extends Command {
}

const output: TestResultFile = { pass: [], fail: [] };
let requests: string[];
let requests: string[] = [];
let repeat: number;
if (flags.input) {
if (argv.length !== 0) {
Expand Down Expand Up @@ -196,29 +198,34 @@ export default class TestTranslateCommand extends Command {
}
}

const includeSuccess = flags.success || !flags.fail;
const includeFail = flags.fail || !flags.success;
if (flags.repeat !== undefined && flags.repeat !== repeat) {
throw new Error("Specified repeat doesn't match result file");
}

if (includeFail) {
requests = input.fail.map((entry) => entry.request);
if (input.skipped) {
requests = requests.concat(input.skipped);
}
const includeAll =
!flags.succeeded && !flags.failed && !flags.skipped;

if (includeAll || flags.succeeded) {
requests = requests.concat(
input.pass.map((entry) => entry.request),
);
} else {
requests = [];
output.pass = input.pass;
}
if (includeSuccess) {
if (flags.failed) {
output.pass = input.pass;
} else {
requests = input.pass
.map((entry) => entry.request)
.concat(requests);
}
if (includeAll || flags.failed) {
requests = requests.concat(
input.fail.map((entry) => entry.request),
);
} else {
output.fail = input.fail;
}

if (flags.repeat !== undefined && flags.repeat !== repeat) {
throw new Error("Specified repeat doesn't match result file");
if (input.skipped !== undefined) {
if (includeAll || flags.skipped) {
requests = requests.concat(input.skipped);
} else {
output.skipped = input.skipped;
}
}
} else {
repeat = flags.repeat ?? defaultRepeat;
Expand Down Expand Up @@ -270,7 +277,7 @@ export default class TestTranslateCommand extends Command {
function print(msg: string) {
processed++;
console.log(
`[${processed.toString().padStart(totalStr.length)}/${totalStr}] ${chalk.yellow(`[Fail: ${failedTotal.toString().padStart(totalStr.length)} (${((failedTotal / processed) * 100).toFixed(2).padStart(5)}%)]`)} ${msg}`,
`${chalk.white(`[${processed.toString().padStart(totalStr.length)}/${totalStr}]`)} ${chalk.yellow(`[Fail: ${failedTotal.toString().padStart(totalStr.length)} (${((failedTotal / processed) * 100).toFixed(2).padStart(5)}%)]`)} ${msg}`,
);
}
const concurrency = flags.concurrency ?? 4;
Expand Down Expand Up @@ -319,11 +326,21 @@ export default class TestTranslateCommand extends Command {

failedTotal++;
failed = true;
print(
chalk.red(
`Failed to consistently generate actions`,
),
);
break;
}
if (actual === undefined) {
failedTotal++;
failed = true;
print(
chalk.red(
`Failed to consistently generate actions`,
),
);
break;
}
if (actual.length !== expected.length) {
Expand Down
21 changes: 15 additions & 6 deletions ts/packages/dispatcher/src/context/chatHistoryPrompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,21 @@ export function createTypeAgentRequestPrompt(
}
}

const prompts: string[] = [
`You are a service that translates user requests into JSON objects of type "${translator.validator.getTypeName()}" according to the following TypeScript definitions:`,
`\`\`\``,
translator.validator.getSchemaText(),
`\`\`\``,
];
const prompts: string[] = [];
if (translator.validator.getSchemaText() === "") {
// If the schema is empty, we are skipping the type script schema because of json schema.
prompts.push(
`You are a service that translates user requests into JSON objects`,
);
} else {
prompts.push(
`You are a service that translates user requests into JSON objects of type "${translator.validator.getTypeName()}" according to the following TypeScript definitions:`,
`\`\`\``,
translator.validator.getSchemaText(),
`\`\`\``,
);
}

if (context) {
if (history !== undefined) {
const promptSections: PromptSection[] = history.promptSections;
Expand Down
2 changes: 2 additions & 0 deletions ts/packages/dispatcher/src/context/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type DispatcherConfig = {
generation: {
enabled: boolean;
jsonSchema: boolean;
jsonSchemaWithTs: boolean; // only applies when jsonSchema is true
};
optimize: {
enabled: boolean;
Expand Down Expand Up @@ -157,6 +158,7 @@ const defaultSessionConfig: SessionConfig = {
generation: {
enabled: true,
jsonSchema: false,
jsonSchemaWithTs: true,
},
optimize: {
enabled: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function createActionSchemaJsonValidator<T extends TranslatedAction>(
};
}

export function createActionJsonTranslatorFromSchemaDef<
export function createJsonTranslatorFromActionSchema<
T extends TranslatedAction,
>(
typeName: string,
Expand Down
12 changes: 6 additions & 6 deletions ts/packages/dispatcher/src/translation/agentTranslators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ import { createTypeAgentRequestPrompt } from "../context/chatHistoryPrompt.js";
import {
composeActionSchema,
composeSelectedActionSchema,
createActionJsonTranslatorFromSchemaDef,
createJsonTranslatorFromActionSchema,
} from "./actionSchemaJsonTranslator.js";
import {
ActionSchemaTypeDefinition,
generateActionSchema,
generateSchemaTypeDefinition,
ActionSchemaObject,
ActionSchemaCreator as sc,
GenerateSchemaOptions,
} from "action-schema";
import { ActionConfig } from "./actionConfig.js";
import { ActionConfigProvider } from "./actionConfigProvider.js";
Expand Down Expand Up @@ -262,12 +263,11 @@ export function loadAgentJsonTranslator<
multipleActionOptions: MultipleActionOptions,
regenerateSchema: boolean = true,
model?: string,
exact: boolean = true,
jsonSchema: boolean = false,
generateOptions?: GenerateSchemaOptions,
): TypeAgentTranslator<T> {
const options = { model };
const translator = regenerateSchema
? createActionJsonTranslatorFromSchemaDef<T>(
? createJsonTranslatorFromActionSchema<T>(
"AllActions",
composeActionSchema(
translatorName,
Expand All @@ -277,7 +277,7 @@ export function loadAgentJsonTranslator<
multipleActionOptions,
),
options,
{ exact, jsonSchema },
generateOptions,
)
: createJsonTranslatorFromSchemaDef<T>(
"AllActions",
Expand Down Expand Up @@ -368,7 +368,7 @@ export function createTypeAgentTranslatorForSelectedActions<
model?: string,
) {
const options = { model };
const translator = createActionJsonTranslatorFromSchemaDef<T>(
const translator = createJsonTranslatorFromActionSchema<T>(
"AllActions",
composeSelectedActionSchema(
definitions,
Expand Down
7 changes: 5 additions & 2 deletions ts/packages/dispatcher/src/translation/translateRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ export function getTranslatorForSchema(
config.multiple,
config.schema.generation.enabled,
config.model,
!config.schema.optimize.enabled,
config.schema.generation.jsonSchema,
{
exact: !config.schema.optimize.enabled,
jsonSchema: config.schema.generation.jsonSchema,
jsonSchemaWithTs: config.schema.generation.jsonSchemaWithTs,
},
);
context.translatorCache.set(translatorName, newTranslator);
return newTranslator;
Expand Down

0 comments on commit 64b7ebe

Please sign in to comment.