feat (provider-utils): handle differing zod input/output schemas in json parsing/validation
shaper committed Jan 18, 2025
1 parent e7a9ec9 commit de4b452
Showing 14 changed files with 592 additions and 56 deletions.
7 changes: 7 additions & 0 deletions .changeset/pretty-poems-taste.md
@@ -0,0 +1,7 @@
---
'ai': patch
'@ai-sdk/provider-utils': patch
'@ai-sdk/ui-utils': patch
---

feat (provider-utils): correctly handle differing zod input/output schemas in json parsing/validation
29 changes: 28 additions & 1 deletion content/cookbook/01-next/30-generate-object.mdx
@@ -8,7 +8,19 @@ tags: ['next', 'structured data']

Earlier functions like `generateText` and `streamText` gave us the ability to generate unstructured text. However, if you want to generate structured data like JSON, you can provide a schema that describes the structure of your desired object to the `generateObject` function.

The function requires you to provide a schema using [zod](https://zod.dev), a library for defining schemas for JavaScript objects. By using zod, you can also use it to validate the generated object and ensure that it conforms to the specified structure.
The function requires you to provide a schema using either [zod](https://zod.dev) or a custom validator that implements the `Validator` interface. When using zod, you can define schemas for JavaScript objects and use them for both validation and type inference. Custom validators allow you to implement your own validation logic while maintaining type safety.

## Type Transformation

The validation system supports transforming input types to different output types. This is useful when you want to parse and transform the raw input data into a different structure. For example:

```typescript
const schema = z.object({
  date: z.string().transform(str => new Date(str)),
});
```

The schema above will accept a string input but transform it into a Date object during validation.
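
To make the split between the accepted input type and the transformed output type concrete, here is a dependency-free sketch. The names `EventInput`, `EventOutput`, and `parseEvent` are hypothetical and are not part of zod or the AI SDK; with zod itself, the two types would typically be derived from the schema with `z.input<typeof schema>` and `z.output<typeof schema>`.

```typescript
// Hypothetical types mirroring what a schema such as
// z.object({ date: z.string().transform(str => new Date(str)) })
// accepts (input) and returns (output).
type EventInput = { date: string };
type EventOutput = { date: Date };

// Hand-written stand-in for the schema: checks the input shape, then transforms it.
function parseEvent(raw: unknown): EventOutput {
  if (
    typeof raw !== 'object' ||
    raw === null ||
    typeof (raw as { date?: unknown }).date !== 'string'
  ) {
    throw new Error('Expected an object with a string "date" field');
  }
  const input = raw as EventInput;
  const date = new Date(input.date);
  if (Number.isNaN(date.getTime())) {
    throw new Error(`Not a parsable date: ${input.date}`);
  }
  return { date };
}

const parsedEvent = parseEvent(
  JSON.parse('{"date": "2024-01-01T00:00:00.000Z"}'),
);
console.log(parsedEvent.date instanceof Date); // true
console.log(parsedEvent.date.toISOString()); // 2024-01-01T00:00:00.000Z
```

The sketch rejects unparsable date strings explicitly; a bare `new Date(str)` transform would instead produce an `Invalid Date` object without failing validation.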

<Browser>
<ObjectGeneration
@@ -108,6 +120,21 @@ export async function POST(req: Request) {
}
```

## Error Handling

When validation fails, the system throws a `TypeValidationError` that includes both the invalid value and the cause of the validation failure. This makes it easier to debug validation issues:

```typescript
try {
  const result = validateTypes({ value, schema });
} catch (error) {
  if (error instanceof TypeValidationError) {
    console.log('Invalid value:', error.value);
    console.log('Cause:', error.cause);
  }
}
```
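
When throwing is inconvenient, the same information can be carried in a result object instead, which is the pattern the `safeValidateTypes` tests in this commit exercise. The following is a minimal, self-contained sketch of that idea; `SafeResult`, `ValidationFailure`, `safeValidate`, and `checkNumber` are hypothetical names used for illustration and are not the library's API.

```typescript
// Hypothetical result shape, loosely modelled on the safeValidateTypes tests in this commit.
type ValidationFailure = { value: unknown; cause: unknown };
type SafeResult<T> =
  | { success: true; value: T }
  | { success: false; error: ValidationFailure };

// Hypothetical helper: runs a throwing check and converts a throw into a failure result.
function safeValidate<T>(
  value: unknown,
  check: (v: unknown) => T,
): SafeResult<T> {
  try {
    return { success: true, value: check(value) };
  } catch (cause) {
    return { success: false, error: { value, cause } };
  }
}

// Example check: accepts only finite numbers.
function checkNumber(v: unknown): number {
  if (typeof v !== 'number' || !Number.isFinite(v)) {
    throw new Error('Expected a finite number');
  }
  return v;
}

const okResult = safeValidate(42, checkNumber);
const badResult = safeValidate('42', checkNumber);

// Narrowing on `success` gives typed access to either the value or the error.
if (!badResult.success) {
  console.log('Invalid value:', badResult.error.value);
  console.log('Cause:', badResult.error.cause);
}
```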

---

<GithubLink link="https://github.com/vercel/ai/blob/main/examples/next-openai-pages/pages/basics/generate-object/index.tsx" />
2 changes: 1 addition & 1 deletion content/docs/03-ai-sdk-core/05-generating-text.mdx
@@ -13,7 +13,7 @@ The AI SDK Core provides two functions to generate text and stream it from LLMs:
- [`generateText`](#generatetext): Generates text for a given prompt and model.
- [`streamText`](#streamtext): Streams text from a given prompt and model.

Advanced LLM features such as [tool calling](./tools-and-tool-calling) and [structured data generation](./generating-structured-data) are built on top of text generation.
Advanced LLM features such as [tool calling](./tools-and-tool-calling) and [structured data generation](./generating-structured-data) are built on top of text generation. When generating structured data, the system validates both input and output types to ensure type safety throughout the generation process.

## `generateText`

30 changes: 26 additions & 4 deletions packages/ai/core/generate-object/output-strategy.ts
@@ -55,7 +55,17 @@ const noSchemaOutputStrategy: OutputStrategy<JSONValue, JSONValue, never> = {
jsonSchema: undefined,

validatePartialResult({ value, textDelta }) {
return { success: true, value: { partial: value, textDelta } };
return {
success: true,
value: {
partial: value,
textDelta,
},
rawValue: {
partial: value,
textDelta,
},
};
},

validateFinalResult(
@@ -76,7 +86,7 @@ const noSchemaOutputStrategy: OutputStrategy<JSONValue, JSONValue, never> = {
usage: context.usage,
}),
}
: { success: true, value };
: { success: true, value, rawValue: value };
},

createElementStream() {
@@ -100,6 +110,10 @@ const objectOutputStrategy = <OBJECT>(
partial: value as DeepPartial<OBJECT>,
textDelta,
},
rawValue: {
partial: value as DeepPartial<OBJECT>,
textDelta,
},
};
},

@@ -198,6 +212,10 @@ const arrayOutputStrategy = <ELEMENT>(
partial: resultArray,
textDelta,
},
rawValue: {
partial: resultArray,
textDelta,
},
};
},

@@ -225,7 +243,11 @@ const arrayOutputStrategy = <ELEMENT>(
}
}

return { success: true, value: inputArray as Array<ELEMENT> };
return {
success: true,
value: inputArray as Array<ELEMENT>,
rawValue: inputArray as Array<ELEMENT>,
};
},

createElementStream(
@@ -311,7 +333,7 @@ const enumOutputStrategy = <ENUM extends string>(
const result = value.result as string;

return enumValues.includes(result as ENUM)
? { success: true, value: result as ENUM }
? { success: true, value: result as ENUM, rawValue: result as ENUM }
: {
success: false,
error: new TypeValidationError({
67 changes: 66 additions & 1 deletion packages/provider-utils/src/parse-json.test.ts
@@ -2,6 +2,7 @@ import { describe, it, expect } from 'vitest';
import { parseJSON, safeParseJSON, isParsableJson } from './parse-json';
import { z } from 'zod';
import { JSONParseError, TypeValidationError } from '@ai-sdk/provider';
import { validator } from './validator';

describe('parseJSON', () => {
it('should parse basic JSON without schema', () => {
@@ -84,7 +85,7 @@ describe('safeParseJSON', () => {

const result = safeParseJSON({
text: '{"user": {"id": "123", "name": "John"}}',
schema: schema as any,
schema,
});

expect(result).toEqual({
@@ -171,6 +172,70 @@ describe('safeParseJSON', () => {
rawValue: { value: 123 },
});
});

  describe('input/output type transformations', () => {
    it('should handle zod schema transformations', () => {
      const schema = z.object({
        id: z.string().transform(val => parseInt(val, 10)),
        tags: z
          .array(z.string())
          .transform(tags => tags.map(t => t.toUpperCase())),
      });

      const result = safeParseJSON({
        text: '{"id": "123", "tags": ["draft", "review"]}',
        schema,
      });

      expect(result).toEqual({
        success: true,
        value: { id: 123, tags: ['DRAFT', 'REVIEW'] },
        rawValue: { id: '123', tags: ['draft', 'review'] },
      });
    });

    it('should handle custom validator transformations', () => {
      type Input = { timestamp: string; status: string };
      type Output = { date: Date; isActive: boolean };

      const customValidator = validator<Output, Input>(value => {
        if (
          typeof value === 'object' &&
          value !== null &&
          'timestamp' in value &&
          'status' in value
        ) {
          const input = value as Input;
          return {
            success: true,
            value: {
              date: new Date(input.timestamp),
              isActive: input.status === 'active',
            },
            rawValue: input,
          };
        }
        return { success: false, error: new Error('Invalid input') };
      });

      const result = safeParseJSON({
        text: '{"timestamp": "2024-01-01", "status": "active"}',
        schema: customValidator,
      });

      expect(result).toEqual({
        success: true,
        value: {
          date: new Date('2024-01-01'),
          isActive: true,
        },
        rawValue: {
          timestamp: '2024-01-01',
          status: 'active',
        },
      });
    });
  });
});

describe('isParsableJson', () => {
16 changes: 9 additions & 7 deletions packages/provider-utils/src/parse-json.ts
@@ -57,8 +57,8 @@ export function parseJSON<T>({
}
}

export type ParseResult<T> =
| { success: true; value: T; rawValue: unknown }
export type ParseResult<OUTPUT, INPUT = OUTPUT> =
| { success: true; value: OUTPUT; rawValue: INPUT }
| { success: false; error: JSONParseError | TypeValidationError };

/**
@@ -74,15 +74,17 @@ export function safeParseJSON(options: {
/**
* Safely parses a JSON string into a strongly-typed object, using a provided schema to validate the object.
*
* @template T - The type of the object to parse the JSON into.
* @template OUTPUT - The output type after schema transformation
* @template INPUT - The input type before schema transformation
* @param {string} text - The JSON string to parse.
* @param {Validator<T>} schema - The schema to use for parsing the JSON.
* @param {ZodSchema<OUTPUT, any, INPUT> | Validator<OUTPUT, INPUT>} schema - The schema to use for parsing the JSON.
* Can be either a Zod schema (supporting input/output transformations) or a Validator instance.
* @returns An object with either a `success` flag and the parsed and typed data, or a `success` flag and an error object.
*/
export function safeParseJSON<T>(options: {
export function safeParseJSON<OUTPUT, INPUT = OUTPUT>(options: {
text: string;
schema: ZodSchema<T> | Validator<T>;
}): ParseResult<T>;
schema: ZodSchema<OUTPUT, any, INPUT> | Validator<OUTPUT, INPUT>;
}): ParseResult<OUTPUT, INPUT>;
export function safeParseJSON<T>({
text,
schema,
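
The `ParseResult<OUTPUT, INPUT>` change in `parse-json.ts` above can be illustrated with a small, self-contained sketch: parse the text, hand the raw JSON to a validator, and return both the transformed `value` and the untransformed `rawValue`. This is not the library's `safeParseJSON`; `SketchValidator`, `sketchSafeParse`, and `idValidator` are hypothetical names, and the error type is simplified to a plain `Error`.

```typescript
// Local copy of the result shape from the diff above; the error type is simplified to Error.
type ParseResult<OUTPUT, INPUT = OUTPUT> =
  | { success: true; value: OUTPUT; rawValue: INPUT }
  | { success: false; error: Error };

// Hypothetical validator signature: checks unknown JSON against INPUT and maps it to OUTPUT.
type SketchValidator<OUTPUT, INPUT> = (
  raw: unknown,
) => ParseResult<OUTPUT, INPUT>;

// Hypothetical helper (not the library's safeParseJSON): parse, then validate, never throw.
function sketchSafeParse<OUTPUT, INPUT>(
  text: string,
  validate: SketchValidator<OUTPUT, INPUT>,
): ParseResult<OUTPUT, INPUT> {
  let raw: unknown;
  try {
    raw = JSON.parse(text);
  } catch (error) {
    return {
      success: false,
      error: error instanceof Error ? error : new Error(String(error)),
    };
  }
  return validate(raw);
}

// Validator that accepts { id: string } and produces { id: number }.
const idValidator: SketchValidator<{ id: number }, { id: string }> = raw => {
  if (typeof raw === 'object' && raw !== null && 'id' in raw) {
    const id = (raw as { id: unknown }).id;
    if (typeof id === 'string' && /^\d+$/.test(id)) {
      return { success: true, value: { id: Number(id) }, rawValue: { id } };
    }
  }
  return { success: false, error: new Error('Invalid input') };
};

const parsed = sketchSafeParse('{"id": "123"}', idValidator);
const broken = sketchSafeParse('{"id": ', idValidator);

if (parsed.success) {
  console.log(parsed.value.id, typeof parsed.rawValue.id); // 123 string
}
console.log(broken.success); // false
```

Keeping `INPUT` as a separate type parameter that defaults to `OUTPUT` means callers without transformations do not need to change, while callers with transformations get a correctly typed `rawValue`.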
75 changes: 73 additions & 2 deletions packages/provider-utils/src/validate-types.test.ts
@@ -11,7 +11,11 @@ const customValidator = validator<{ name: string; age: number }>(value =>
typeof value.name === 'string' &&
'age' in value &&
typeof value.age === 'number'
? { success: true, value: value as { name: string; age: number } }
? {
success: true,
value: value as { name: string; age: number },
rawValue: value as { name: string; age: number },
}
: { success: false, error: new Error('Invalid input') },
);

@@ -60,7 +64,7 @@ describe('safeValidateTypes', () => {
it('should return validated object for valid input', () => {
const input = { name: 'John', age: 30 };
const result = safeValidateTypes({ value: input, schema });
expect(result).toEqual({ success: true, value: input });
expect(result).toEqual({ success: true, value: input, rawValue: input });
});

it('should return error object for invalid input', () => {
@@ -80,3 +84,70 @@ describe('safeValidateTypes', () => {
});
});
});

describe('type transformations', () => {
  const transformSchema = z.object({
    id: z.string().transform(val => {
      const num = parseInt(val, 10);
      if (isNaN(num)) throw new Error('Invalid number');
      return num;
    }),
    name: z.string(),
  });

  const transformValidator = validator<
    { id: number; name: string },
    { id: string; name: string }
  >(value => {
    if (
      typeof value === 'object' &&
      value !== null &&
      'id' in value &&
      typeof value.id === 'string' &&
      'name' in value &&
      typeof value.name === 'string'
    ) {
      const num = parseInt(value.id, 10);
      if (isNaN(num)) {
        return { success: false, error: new Error('Invalid number') };
      }
      return {
        success: true,
        value: { id: num, name: value.name },
        rawValue: value as { id: string; name: string },
      };
    }
    return { success: false, error: new Error('Invalid input') };
  });

  describe.each([
    ['Zod schema', transformSchema],
    ['Custom validator', transformValidator],
  ])('using %s', (_, schema) => {
    const validInput = { id: '123', name: 'John' };
    const expectedOutput = { id: 123, name: 'John' };

    it('should transform types in validateTypes', () => {
      const result = validateTypes({ value: validInput, schema });
      expect(result).toEqual(expectedOutput);
    });

    it('should transform types in safeValidateTypes', () => {
      const result = safeValidateTypes({ value: validInput, schema });
      expect(result).toEqual({
        success: true,
        value: expectedOutput,
        rawValue: validInput,
      });
    });

    it('should handle invalid transformations', () => {
      const invalidInput = { id: 'not-a-number', name: 'John' };
      const result = safeValidateTypes({ value: invalidInput, schema });
      expect(result).toEqual({
        success: false,
        error: expect.any(TypeValidationError),
      });
    });
  });
});
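
One caveat about the `parseInt` plus `isNaN` check used in the transformation tests above: `parseInt` accepts a leading numeric prefix, so a string such as `'12abc'` parses to `12` rather than being rejected. If stricter behavior is wanted, a whole-string check can be added first. The helper below is a hypothetical sketch for illustration and is not part of this commit.

```typescript
// Hypothetical helper, not part of the commit: accepts only strings that are
// entirely an (optionally negative) base-10 integer.
function strictParseInt(text: string): number {
  if (!/^-?\d+$/.test(text)) {
    throw new Error(`Invalid number: ${text}`);
  }
  return parseInt(text, 10);
}

console.log(parseInt('12abc', 10)); // 12, the trailing "abc" is ignored
console.log(strictParseInt('123')); // 123

try {
  strictParseInt('12abc');
} catch (error) {
  console.log((error as Error).message); // Invalid number: 12abc
}
```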