Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip the Typescript schema by default when JsonSchema is enabled. #699

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ts/packages/actionSchema/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,21 @@ export type GenerateSchemaOptions = {
strict?: boolean; // default true
exact?: boolean; // default false
jsonSchema?: boolean; // default false
jsonSchemaWithTs?: boolean; // default false, applies only when jsonSchema is true.
};

export function generateSchemaTypeDefinition(
definition: SchemaTypeDefinition,
options?: GenerateSchemaOptions,
order?: Map<string, number>,
) {
): string {
// wrap the action schema when json schema is active.
const jsonSchema = options?.jsonSchema ?? false;
const includeTs = !jsonSchema || (options?.jsonSchemaWithTs ?? false);
if (!includeTs) {
return "";
}

const strict = options?.strict ?? true;
const exact = options?.exact ?? false;
const emitted = new Map<
Expand Down
48 changes: 40 additions & 8 deletions ts/packages/actionSchema/src/jsonSchemaGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import * as sc from "./creator.js";
import {
ActionSchemaEntryTypeDefinition,
ActionSchemaGroup,
SchemaObjectField,
SchemaType,
SchemaTypeDefinition,
} from "./type.js";
Expand All @@ -18,38 +19,46 @@ export function wrapTypeWithJsonSchema(

type JsonSchemaObject = {
type: "object";
description?: string;
properties: Record<string, JsonSchema>;
required: string[];
additionalProperties: false;
};
type JsonSchemaArray = {
type: "array";
description?: string;
items: JsonSchema;
};

type JsonSchemaString = {
type: "string";
description?: string;
enum?: string[];
};

type JsonSchemaNumber = {
type: "number";
description?: string;
};

type JsonSchemaBoolean = {
type: "boolean";
};

type JsonSchemaUnion = {
anyOf: JsonSchema[];
description?: string;
};

type JsonSchemaNull = {
type: "null";
description?: string;
};

type JsonSchemaUnion = {
anyOf: JsonSchema[];
description?: string;
};

type JsonSchemaReference = {
$ref: string;
description?: string;
};

type JsonSchemaRoot = {
Expand All @@ -69,6 +78,14 @@ type JsonSchema =
| JsonSchemaUnion
| JsonSchemaReference;

function fieldComments(field: SchemaObjectField): string | undefined {
const combined = [
...(field.comments ?? []),
...(field.trailingComments ?? []),
];
return combined.length > 0 ? combined.join("\n") : undefined;
}

function generateJsonSchemaType(
type: SchemaType,
pending: SchemaTypeDefinition[],
Expand All @@ -79,10 +96,18 @@ function generateJsonSchemaType(
return {
type: "object",
properties: Object.fromEntries(
Object.entries(type.fields).map(([key, field]) => [
key,
generateJsonSchemaType(field.type, pending, strict),
]),
Object.entries(type.fields).map(([key, field]) => {
const fieldType = generateJsonSchemaType(
field.type,
pending,
strict,
);
const comments = fieldComments(field);
if (comments) {
fieldType.description = comments;
}
return [key, fieldType];
}),
),
required: Object.keys(type.fields),
additionalProperties: false,
Expand Down Expand Up @@ -134,6 +159,9 @@ function generateJsonSchemaTypeDefinition(
strict: true,
schema: generateJsonSchemaType(def.type, pending, strict),
};
if (def.comments) {
schema.schema.description = def.comments.join("\n");
}

if (pending.length !== 0) {
const $defs: Record<string, JsonSchema> = {};
Expand All @@ -147,6 +175,10 @@ function generateJsonSchemaTypeDefinition(
pending,
strict,
);
if (definition.comments) {
$defs[definition.name].description =
definition.comments.join("\n");
}
} while (pending.length > 0);
schema.schema.$defs = $defs;
}
Expand Down
65 changes: 41 additions & 24 deletions ts/packages/cli/src/commands/test/translate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,18 @@ export default class TestTranslateCommand extends Command {
description: "Output test result file",
required: true,
}),
success: Flags.boolean({
char: "s",
succeeded: Flags.boolean({
description:
"Copy failed test data and rerun only successful tests from the test result file",
}),
fail: Flags.boolean({
char: "f",
failed: Flags.boolean({
description:
"Copy pass test data and rerun only failed tests from the test result file",
}),
skipped: Flags.boolean({
description:
"Copy skipped test data and rerun only skipped tests from the test result file",
}),
input: Flags.string({
char: "i",
description: "Input test result file to get requests from",
Expand Down Expand Up @@ -167,7 +169,7 @@ export default class TestTranslateCommand extends Command {
}

const output: TestResultFile = { pass: [], fail: [] };
let requests: string[];
let requests: string[] = [];
let repeat: number;
if (flags.input) {
if (argv.length !== 0) {
Expand Down Expand Up @@ -196,29 +198,34 @@ export default class TestTranslateCommand extends Command {
}
}

const includeSuccess = flags.success || !flags.fail;
const includeFail = flags.fail || !flags.success;
if (flags.repeat !== undefined && flags.repeat !== repeat) {
throw new Error("Specified repeat doesn't match result file");
}

if (includeFail) {
requests = input.fail.map((entry) => entry.request);
if (input.skipped) {
requests = requests.concat(input.skipped);
}
const includeAll =
!flags.succeeded && !flags.failed && !flags.skipped;

if (includeAll || flags.succeeded) {
requests = requests.concat(
input.pass.map((entry) => entry.request),
);
} else {
requests = [];
output.pass = input.pass;
}
if (includeSuccess) {
if (flags.failed) {
output.pass = input.pass;
} else {
requests = input.pass
.map((entry) => entry.request)
.concat(requests);
}
if (includeAll || flags.failed) {
requests = requests.concat(
input.fail.map((entry) => entry.request),
);
} else {
output.fail = input.fail;
}

if (flags.repeat !== undefined && flags.repeat !== repeat) {
throw new Error("Specified repeat doesn't match result file");
if (input.skipped !== undefined) {
if (includeAll || flags.skipped) {
requests = requests.concat(input.skipped);
} else {
output.skipped = input.skipped;
}
}
} else {
repeat = flags.repeat ?? defaultRepeat;
Expand Down Expand Up @@ -270,7 +277,7 @@ export default class TestTranslateCommand extends Command {
function print(msg: string) {
processed++;
console.log(
`[${processed.toString().padStart(totalStr.length)}/${totalStr}] ${chalk.yellow(`[Fail: ${failedTotal.toString().padStart(totalStr.length)} (${((failedTotal / processed) * 100).toFixed(2).padStart(5)}%)]`)} ${msg}`,
`${chalk.white(`[${processed.toString().padStart(totalStr.length)}/${totalStr}]`)} ${chalk.yellow(`[Fail: ${failedTotal.toString().padStart(totalStr.length)} (${((failedTotal / processed) * 100).toFixed(2).padStart(5)}%)]`)} ${msg}`,
);
}
const concurrency = flags.concurrency ?? 4;
Expand Down Expand Up @@ -319,11 +326,21 @@ export default class TestTranslateCommand extends Command {

failedTotal++;
failed = true;
print(
chalk.red(
`Failed to consistently generate actions`,
),
);
break;
}
if (actual === undefined) {
failedTotal++;
failed = true;
print(
chalk.red(
`Failed to consistently generate actions`,
),
);
break;
}
if (actual.length !== expected.length) {
Expand Down
21 changes: 15 additions & 6 deletions ts/packages/dispatcher/src/context/chatHistoryPrompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,21 @@ export function createTypeAgentRequestPrompt(
}
}

const prompts: string[] = [
`You are a service that translates user requests into JSON objects of type "${translator.validator.getTypeName()}" according to the following TypeScript definitions:`,
`\`\`\``,
translator.validator.getSchemaText(),
`\`\`\``,
];
const prompts: string[] = [];
if (translator.validator.getSchemaText() === "") {
// If the schema is empty, we are skipping the type script schema because of json schema.
prompts.push(
`You are a service that translates user requests into JSON objects`,
);
} else {
prompts.push(
`You are a service that translates user requests into JSON objects of type "${translator.validator.getTypeName()}" according to the following TypeScript definitions:`,
`\`\`\``,
translator.validator.getSchemaText(),
`\`\`\``,
);
}

if (context) {
if (history !== undefined) {
const promptSections: PromptSection[] = history.promptSections;
Expand Down
2 changes: 2 additions & 0 deletions ts/packages/dispatcher/src/context/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type DispatcherConfig = {
generation: {
enabled: boolean;
jsonSchema: boolean;
jsonSchemaWithTs: boolean; // only applies when jsonSchema is true
};
optimize: {
enabled: boolean;
Expand Down Expand Up @@ -157,6 +158,7 @@ const defaultSessionConfig: SessionConfig = {
generation: {
enabled: true,
jsonSchema: false,
jsonSchemaWithTs: true,
},
optimize: {
enabled: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function createActionSchemaJsonValidator<T extends TranslatedAction>(
};
}

export function createActionJsonTranslatorFromSchemaDef<
export function createJsonTranslatorFromActionSchema<
T extends TranslatedAction,
>(
typeName: string,
Expand Down
12 changes: 6 additions & 6 deletions ts/packages/dispatcher/src/translation/agentTranslators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ import { createTypeAgentRequestPrompt } from "../context/chatHistoryPrompt.js";
import {
composeActionSchema,
composeSelectedActionSchema,
createActionJsonTranslatorFromSchemaDef,
createJsonTranslatorFromActionSchema,
} from "./actionSchemaJsonTranslator.js";
import {
ActionSchemaTypeDefinition,
generateActionSchema,
generateSchemaTypeDefinition,
ActionSchemaObject,
ActionSchemaCreator as sc,
GenerateSchemaOptions,
} from "action-schema";
import { ActionConfig } from "./actionConfig.js";
import { ActionConfigProvider } from "./actionConfigProvider.js";
Expand Down Expand Up @@ -262,12 +263,11 @@ export function loadAgentJsonTranslator<
multipleActionOptions: MultipleActionOptions,
regenerateSchema: boolean = true,
model?: string,
exact: boolean = true,
jsonSchema: boolean = false,
generateOptions?: GenerateSchemaOptions,
): TypeAgentTranslator<T> {
const options = { model };
const translator = regenerateSchema
? createActionJsonTranslatorFromSchemaDef<T>(
? createJsonTranslatorFromActionSchema<T>(
"AllActions",
composeActionSchema(
translatorName,
Expand All @@ -277,7 +277,7 @@ export function loadAgentJsonTranslator<
multipleActionOptions,
),
options,
{ exact, jsonSchema },
generateOptions,
)
: createJsonTranslatorFromSchemaDef<T>(
"AllActions",
Expand Down Expand Up @@ -368,7 +368,7 @@ export function createTypeAgentTranslatorForSelectedActions<
model?: string,
) {
const options = { model };
const translator = createActionJsonTranslatorFromSchemaDef<T>(
const translator = createJsonTranslatorFromActionSchema<T>(
"AllActions",
composeSelectedActionSchema(
definitions,
Expand Down
7 changes: 5 additions & 2 deletions ts/packages/dispatcher/src/translation/translateRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ export function getTranslatorForSchema(
config.multiple,
config.schema.generation.enabled,
config.model,
!config.schema.optimize.enabled,
config.schema.generation.jsonSchema,
{
exact: !config.schema.optimize.enabled,
jsonSchema: config.schema.generation.jsonSchema,
jsonSchemaWithTs: config.schema.generation.jsonSchemaWithTs,
},
);
context.translatorCache.set(translatorName, newTranslator);
return newTranslator;
Expand Down
Loading