diff --git a/ts/packages/actionSchema/src/generator.ts b/ts/packages/actionSchema/src/generator.ts index da0bcccc..b95e3447 100644 --- a/ts/packages/actionSchema/src/generator.ts +++ b/ts/packages/actionSchema/src/generator.ts @@ -124,14 +124,21 @@ export type GenerateSchemaOptions = { strict?: boolean; // default true exact?: boolean; // default false jsonSchema?: boolean; // default false + jsonSchemaWithTs?: boolean; // default false, applies only when jsonSchema is true. }; export function generateSchemaTypeDefinition( definition: SchemaTypeDefinition, options?: GenerateSchemaOptions, order?: Map, -) { +): string { + // wrap the action schema when json schema is active. const jsonSchema = options?.jsonSchema ?? false; + const includeTs = !jsonSchema || (options?.jsonSchemaWithTs ?? false); + if (!includeTs) { + return ""; + } + const strict = options?.strict ?? true; const exact = options?.exact ?? false; const emitted = new Map< diff --git a/ts/packages/actionSchema/src/jsonSchemaGenerator.ts b/ts/packages/actionSchema/src/jsonSchemaGenerator.ts index 57d1bdd1..d2146b40 100644 --- a/ts/packages/actionSchema/src/jsonSchemaGenerator.ts +++ b/ts/packages/actionSchema/src/jsonSchemaGenerator.ts @@ -5,6 +5,7 @@ import * as sc from "./creator.js"; import { ActionSchemaEntryTypeDefinition, ActionSchemaGroup, + SchemaObjectField, SchemaType, SchemaTypeDefinition, } from "./type.js"; @@ -18,38 +19,46 @@ export function wrapTypeWithJsonSchema( type JsonSchemaObject = { type: "object"; + description?: string; properties: Record; required: string[]; additionalProperties: false; }; type JsonSchemaArray = { type: "array"; + description?: string; items: JsonSchema; }; type JsonSchemaString = { type: "string"; + description?: string; enum?: string[]; }; type JsonSchemaNumber = { type: "number"; + description?: string; }; type JsonSchemaBoolean = { type: "boolean"; -}; - -type JsonSchemaUnion = { - anyOf: JsonSchema[]; + description?: string; }; type JsonSchemaNull = { type: "null"; + description?: string; +}; + +type JsonSchemaUnion = { + anyOf: JsonSchema[]; + description?: string; }; type JsonSchemaReference = { $ref: string; + description?: string; }; type JsonSchemaRoot = { @@ -69,6 +78,14 @@ type JsonSchema = | JsonSchemaUnion | JsonSchemaReference; +function fieldComments(field: SchemaObjectField): string | undefined { + const combined = [ + ...(field.comments ?? []), + ...(field.trailingComments ?? []), + ]; + return combined.length > 0 ? combined.join("\n") : undefined; +} + function generateJsonSchemaType( type: SchemaType, pending: SchemaTypeDefinition[], @@ -79,10 +96,18 @@ function generateJsonSchemaType( return { type: "object", properties: Object.fromEntries( - Object.entries(type.fields).map(([key, field]) => [ - key, - generateJsonSchemaType(field.type, pending, strict), - ]), + Object.entries(type.fields).map(([key, field]) => { + const fieldType = generateJsonSchemaType( + field.type, + pending, + strict, + ); + const comments = fieldComments(field); + if (comments) { + fieldType.description = comments; + } + return [key, fieldType]; + }), ), required: Object.keys(type.fields), additionalProperties: false, @@ -134,6 +159,9 @@ function generateJsonSchemaTypeDefinition( strict: true, schema: generateJsonSchemaType(def.type, pending, strict), }; + if (def.comments) { + schema.schema.description = def.comments.join("\n"); + } if (pending.length !== 0) { const $defs: Record = {}; @@ -147,6 +175,10 @@ function generateJsonSchemaTypeDefinition( pending, strict, ); + if (definition.comments) { + $defs[definition.name].description = + definition.comments.join("\n"); + } } while (pending.length > 0); schema.schema.$defs = $defs; } diff --git a/ts/packages/cli/src/commands/test/translate.ts b/ts/packages/cli/src/commands/test/translate.ts index 62f8b715..2b7c25fd 100644 --- a/ts/packages/cli/src/commands/test/translate.ts +++ b/ts/packages/cli/src/commands/test/translate.ts @@ -128,16 +128,18 @@ export default class TestTranslateCommand extends Command { description: "Output test result file", required: true, }), - success: Flags.boolean({ - char: "s", + succeeded: Flags.boolean({ description: "Copy failed test data and rerun only successful tests from the test result file", }), - fail: Flags.boolean({ - char: "f", + failed: Flags.boolean({ description: "Copy pass test data and rerun only failed tests from the test result file", }), + skipped: Flags.boolean({ + description: + "Copy skipped test data and rerun only skipped tests from the test result file", + }), input: Flags.string({ char: "i", description: "Input test result file to get requests from", @@ -167,7 +169,7 @@ export default class TestTranslateCommand extends Command { } const output: TestResultFile = { pass: [], fail: [] }; - let requests: string[]; + let requests: string[] = []; let repeat: number; if (flags.input) { if (argv.length !== 0) { @@ -196,29 +198,34 @@ export default class TestTranslateCommand extends Command { } } - const includeSuccess = flags.success || !flags.fail; - const includeFail = flags.fail || !flags.success; + if (flags.repeat !== undefined && flags.repeat !== repeat) { + throw new Error("Specified repeat doesn't match result file"); + } - if (includeFail) { - requests = input.fail.map((entry) => entry.request); - if (input.skipped) { - requests = requests.concat(input.skipped); - } + const includeAll = + !flags.succeeded && !flags.failed && !flags.skipped; + + if (includeAll || flags.succeeded) { + requests = requests.concat( + input.pass.map((entry) => entry.request), + ); } else { - requests = []; + output.pass = input.pass; } - if (includeSuccess) { - if (flags.failed) { - output.pass = input.pass; - } else { - requests = input.pass - .map((entry) => entry.request) - .concat(requests); - } + if (includeAll || flags.failed) { + requests = requests.concat( + input.fail.map((entry) => entry.request), + ); + } else { + output.fail = input.fail; } - if (flags.repeat !== undefined && flags.repeat !== repeat) { - throw new Error("Specified repeat doesn't match result file"); + if (input.skipped !== undefined) { + if (includeAll || flags.skipped) { + requests = requests.concat(input.skipped); + } else { + output.skipped = input.skipped; + } } } else { repeat = flags.repeat ?? defaultRepeat; @@ -270,7 +277,7 @@ export default class TestTranslateCommand extends Command { function print(msg: string) { processed++; console.log( - `[${processed.toString().padStart(totalStr.length)}/${totalStr}] ${chalk.yellow(`[Fail: ${failedTotal.toString().padStart(totalStr.length)} (${((failedTotal / processed) * 100).toFixed(2).padStart(5)}%)]`)} ${msg}`, + `${chalk.white(`[${processed.toString().padStart(totalStr.length)}/${totalStr}]`)} ${chalk.yellow(`[Fail: ${failedTotal.toString().padStart(totalStr.length)} (${((failedTotal / processed) * 100).toFixed(2).padStart(5)}%)]`)} ${msg}`, ); } const concurrency = flags.concurrency ?? 4; @@ -319,11 +326,21 @@ export default class TestTranslateCommand extends Command { failedTotal++; failed = true; + print( + chalk.red( + `Failed to consistently generate actions`, + ), + ); break; } if (actual === undefined) { failedTotal++; failed = true; + print( + chalk.red( + `Failed to consistently generate actions`, + ), + ); break; } if (actual.length !== expected.length) { diff --git a/ts/packages/dispatcher/src/context/chatHistoryPrompt.ts b/ts/packages/dispatcher/src/context/chatHistoryPrompt.ts index 23057e35..57efff95 100644 --- a/ts/packages/dispatcher/src/context/chatHistoryPrompt.ts +++ b/ts/packages/dispatcher/src/context/chatHistoryPrompt.ts @@ -22,12 +22,21 @@ export function createTypeAgentRequestPrompt( } } - const prompts: string[] = [ - `You are a service that translates user requests into JSON objects of type "${translator.validator.getTypeName()}" according to the following TypeScript definitions:`, - `\`\`\``, - translator.validator.getSchemaText(), - `\`\`\``, - ]; + const prompts: string[] = []; + if (translator.validator.getSchemaText() === "") { + // If the schema is empty, we are skipping the type script schema because of json schema. + prompts.push( + `You are a service that translates user requests into JSON objects`, + ); + } else { + prompts.push( + `You are a service that translates user requests into JSON objects of type "${translator.validator.getTypeName()}" according to the following TypeScript definitions:`, + `\`\`\``, + translator.validator.getSchemaText(), + `\`\`\``, + ); + } + if (context) { if (history !== undefined) { const promptSections: PromptSection[] = history.promptSections; diff --git a/ts/packages/dispatcher/src/context/session.ts b/ts/packages/dispatcher/src/context/session.ts index 9b25868b..30e51838 100644 --- a/ts/packages/dispatcher/src/context/session.ts +++ b/ts/packages/dispatcher/src/context/session.ts @@ -90,6 +90,7 @@ type DispatcherConfig = { generation: { enabled: boolean; jsonSchema: boolean; + jsonSchemaWithTs: boolean; // only applies when jsonSchema is true }; optimize: { enabled: boolean; @@ -157,6 +158,7 @@ const defaultSessionConfig: SessionConfig = { generation: { enabled: true, jsonSchema: false, + jsonSchemaWithTs: true, }, optimize: { enabled: false, diff --git a/ts/packages/dispatcher/src/translation/actionSchemaJsonTranslator.ts b/ts/packages/dispatcher/src/translation/actionSchemaJsonTranslator.ts index 0daadb3f..5b567934 100644 --- a/ts/packages/dispatcher/src/translation/actionSchemaJsonTranslator.ts +++ b/ts/packages/dispatcher/src/translation/actionSchemaJsonTranslator.ts @@ -70,7 +70,7 @@ function createActionSchemaJsonValidator( }; } -export function createActionJsonTranslatorFromSchemaDef< +export function createJsonTranslatorFromActionSchema< T extends TranslatedAction, >( typeName: string, diff --git a/ts/packages/dispatcher/src/translation/agentTranslators.ts b/ts/packages/dispatcher/src/translation/agentTranslators.ts index bf8b9a16..55269498 100644 --- a/ts/packages/dispatcher/src/translation/agentTranslators.ts +++ b/ts/packages/dispatcher/src/translation/agentTranslators.ts @@ -26,7 +26,7 @@ import { createTypeAgentRequestPrompt } from "../context/chatHistoryPrompt.js"; import { composeActionSchema, composeSelectedActionSchema, - createActionJsonTranslatorFromSchemaDef, + createJsonTranslatorFromActionSchema, } from "./actionSchemaJsonTranslator.js"; import { ActionSchemaTypeDefinition, @@ -34,6 +34,7 @@ import { generateSchemaTypeDefinition, ActionSchemaObject, ActionSchemaCreator as sc, + GenerateSchemaOptions, } from "action-schema"; import { ActionConfig } from "./actionConfig.js"; import { ActionConfigProvider } from "./actionConfigProvider.js"; @@ -262,12 +263,11 @@ export function loadAgentJsonTranslator< multipleActionOptions: MultipleActionOptions, regenerateSchema: boolean = true, model?: string, - exact: boolean = true, - jsonSchema: boolean = false, + generateOptions?: GenerateSchemaOptions, ): TypeAgentTranslator { const options = { model }; const translator = regenerateSchema - ? createActionJsonTranslatorFromSchemaDef( + ? createJsonTranslatorFromActionSchema( "AllActions", composeActionSchema( translatorName, @@ -277,7 +277,7 @@ export function loadAgentJsonTranslator< multipleActionOptions, ), options, - { exact, jsonSchema }, + generateOptions, ) : createJsonTranslatorFromSchemaDef( "AllActions", @@ -368,7 +368,7 @@ export function createTypeAgentTranslatorForSelectedActions< model?: string, ) { const options = { model }; - const translator = createActionJsonTranslatorFromSchemaDef( + const translator = createJsonTranslatorFromActionSchema( "AllActions", composeSelectedActionSchema( definitions, diff --git a/ts/packages/dispatcher/src/translation/translateRequest.ts b/ts/packages/dispatcher/src/translation/translateRequest.ts index 898e0195..19fd105e 100644 --- a/ts/packages/dispatcher/src/translation/translateRequest.ts +++ b/ts/packages/dispatcher/src/translation/translateRequest.ts @@ -73,8 +73,11 @@ export function getTranslatorForSchema( config.multiple, config.schema.generation.enabled, config.model, - !config.schema.optimize.enabled, - config.schema.generation.jsonSchema, + { + exact: !config.schema.optimize.enabled, + jsonSchema: config.schema.generation.jsonSchema, + jsonSchemaWithTs: config.schema.generation.jsonSchemaWithTs, + }, ); context.translatorCache.set(translatorName, newTranslator); return newTranslator;