Skip to content

Commit

Permalink
Merge branch 'dani/rag' of https://github.com/Clarifai/clarifai-nodejs
Browse files Browse the repository at this point in the history
…into dani/logging
  • Loading branch information
DaniAkash committed May 6, 2024
2 parents 5865011 + 3948e4e commit d5ea236
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 59 deletions.
12 changes: 1 addition & 11 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
"@parcel/packager-ts": "^2.11.0",
"@parcel/transformer-typescript-types": "^2.11.0",
"@types/async": "^3.2.24",
"@types/cli-progress": "^3.11.5",
"@types/google-protobuf": "^3.15.12",
"@types/js-yaml": "^4.0.9",
"@types/lodash": "^4.17.0",
Expand Down Expand Up @@ -69,7 +68,6 @@
"async": "^3.2.5",
"chalk": "^5.3.0",
"clarifai-nodejs-grpc": "^10.3.2",
"cli-progress": "^3.12.0",
"csv-parse": "^5.5.5",
"from-protobuf-object": "^1.0.2",
"google-protobuf": "^3.21.2",
Expand Down
8 changes: 7 additions & 1 deletion src/client/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { APIError, UserError } from "../errors";
import { ClarifaiUrl, ClarifaiUrlHelper } from "../urls/helper";
import { AuthConfig } from "../utils/types";
import { Lister } from "./lister";
import { Input } from "./input";
import { Input, InputBulkUpload } from "./input";
import {
DeleteDatasetVersionsRequest,
ListDatasetVersionsRequest,
Expand Down Expand Up @@ -144,11 +144,13 @@ export class Dataset extends Lister {
inputType,
labels = false,
batchSize = this.batchSize,
uploadProgressEmitter,
}: {
folderPath: string;
inputType: "image" | "text";
labels: boolean;
batchSize?: number;
uploadProgressEmitter?: InputBulkUpload;
}): Promise<void> {
if (["image", "text"].indexOf(inputType) === -1) {
throw new UserError("Invalid input type");
Expand All @@ -171,6 +173,7 @@ export class Dataset extends Lister {
await this.input.bulkUpload({
inputs: inputProtos,
batchSize: batchSize,
uploadProgressEmitter,
});
}

Expand All @@ -180,12 +183,14 @@ export class Dataset extends Lister {
csvType,
labels = true,
batchSize = 128,
uploadProgressEmitter,
}: {
csvPath: string;
inputType?: "image" | "text" | "video" | "audio";
csvType: "raw" | "url" | "file";
labels?: boolean;
batchSize?: number;
uploadProgressEmitter?: InputBulkUpload;
}): Promise<void> {
if (!["image", "text", "video", "audio"].includes(inputType)) {
throw new UserError(
Expand Down Expand Up @@ -214,6 +219,7 @@ export class Dataset extends Lister {
await this.input.bulkUpload({
inputs: inputProtos,
batchSize: batchSize,
uploadProgressEmitter,
});
}
}
51 changes: 39 additions & 12 deletions src/client/input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
Text,
Video,
} from "clarifai-nodejs-grpc/proto/clarifai/api/resources_pb";
import { AuthConfig } from "../utils/types";
import { AuthConfig, Polygon as PolygonType } from "../utils/types";
import { Lister } from "./lister";
import { Buffer } from "buffer";
import fs from "fs";
Expand All @@ -42,11 +42,11 @@ import { StatusCode } from "clarifai-nodejs-grpc/proto/clarifai/api/status/statu
import os from "os";
import chunk from "lodash/chunk";
import { Status } from "clarifai-nodejs-grpc/proto/clarifai/api/status/status_pb";
import cliProgress from "cli-progress";
import async from "async";
import { MAX_RETRIES } from "../constants/dataset";
import { APIError, UserError } from "../errors";
import { logger } from "../utils/logging";
import { EventEmitter } from "events";

interface CSVRecord {
inputid: string;
Expand All @@ -56,6 +56,29 @@ interface CSVRecord {
geopoints: string;
}

/**
 * Maps each bulk-upload event name to the payload type that accompanies it.
 * Supplied as the type argument of `BulkUploadEventEmitter` to form the
 * exported `InputBulkUpload` emitter type.
 */
interface UploadEvents {
// Emitted by bulkUpload before batch processing begins.
start: ProgressEvent;
// Emitted by bulkUpload each time a batch has finished its retry pass.
progress: ProgressEvent;
// Emitted by bulkUpload when batch processing reports an error.
error: ErrorEvent;
// Emitted by bulkUpload from its final callback, with current equal to total.
end: ProgressEvent;
}

/**
 * Payload of the "start", "progress" and "end" upload events.
 * bulkUpload counts in batches (chunks of inputs), not individual inputs.
 */
interface ProgressEvent {
// Number of batches completed so far (0 on "start").
current: number;
// Total number of batches the upload was split into.
total: number;
}

/**
 * Payload declared for the "error" upload event.
 * NOTE(review): bulkUpload currently calls emit("error") without a payload,
 * so listeners may receive undefined instead of this shape — confirm and align.
 */
interface ErrorEvent {
// The error that caused the upload to fail.
error: Error;
}

/**
 * An EventEmitter augmented with `emit`/`on` signatures keyed by an event map
 * `T` (event name -> payload type).
 *
 * NOTE(review): this is an intersection with EventEmitter, not a replacement
 * of its members, so EventEmitter's own untyped emit/on signatures presumably
 * remain callable alongside the typed ones — confirm whether stricter
 * checking is intended.
 */
type BulkUploadEventEmitter<T> = EventEmitter & {
// Emit `event` with the payload type declared for it in T.
emit<K extends keyof T>(event: K, payload: T[K]): boolean;
// Register a listener that receives the payload type declared for `event` in T.
on<K extends keyof T>(event: K, listener: (payload: T[K]) => void): void;
};

/**
 * Emitter type accepted as the optional `uploadProgressEmitter` argument of
 * the upload methods; carries the events described by `UploadEvents`.
 */
export type InputBulkUpload = BulkUploadEventEmitter<UploadEvents>;

/**
* Inputs is a class that provides access to Clarifai API endpoints related to Input information.
* @noInheritDoc
Expand Down Expand Up @@ -743,9 +766,9 @@ export class Input extends Lister {
}: {
inputId: string;
label: string;
polygons: number[][][];
polygons: PolygonType[];
}): Annotation {
const polygonsSchema = z.array(z.array(z.array(z.number())));
const polygonsSchema = z.array(z.array(z.tuple([z.number(), z.number()])));
try {
polygonsSchema.parse(polygons);
} catch {
Expand Down Expand Up @@ -1013,19 +1036,18 @@ export class Input extends Lister {
bulkUpload({
inputs,
batchSize: providedBatchSize = 128,
uploadProgressEmitter,
}: {
inputs: GrpcInput[];
batchSize?: number;
uploadProgressEmitter?: InputBulkUpload;
}): Promise<void> {
const batchSize = Math.min(128, providedBatchSize);
const chunkedInputs = chunk(inputs, batchSize);

const progressBar = new cliProgress.SingleBar(
{},
cliProgress.Presets.shades_classic,
);

progressBar.start(chunkedInputs.length, 0);
let currentProgress = 0;
const total = chunkedInputs.length;
uploadProgressEmitter?.emit("start", { current: currentProgress, total });

return new Promise<void>((resolve, reject) => {
async.mapLimit(
Expand All @@ -1037,7 +1059,11 @@ export class Input extends Lister {
this.retryUploads({
failedInputs,
}).finally(() => {
progressBar.increment();
currentProgress++;
uploadProgressEmitter?.emit("progress", {
current: currentProgress,
total,
});
callback(null, failedInputs);
});
})
Expand All @@ -1048,9 +1074,10 @@ export class Input extends Lister {
(err) => {
if (err) {
logger.error(`Error processing batches ${err.message}`);
uploadProgressEmitter?.emit("error");
reject(err);
}
progressBar.stop();
uploadProgressEmitter?.emit("end", { current: total, total });
logger.info("All inputs processed");
resolve();
},
Expand Down
33 changes: 4 additions & 29 deletions src/client/rag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ type workflowSchema = ReturnType<typeof validateWorkflow>;
export class RAG {
private authConfig: AuthConfig;

private chatStateId: string = "";

public promptWorkflow: Workflow;

public app: App;
Expand Down Expand Up @@ -402,40 +400,17 @@ export class RAG {
);
}

if (clientManageState) {
const singlePrompt = convertMessagesToStr(messages);
const inputProto = Input.getTextInput({
inputId: uuidv4(),
rawText: singlePrompt,
});
const response = await this.promptWorkflow.predict({
inputs: [inputProto],
});
const outputsList = response.resultsList?.[0]?.outputsList;
const output = outputsList[outputsList.length - 1];
messages.push(formatAssistantMessage(output?.data?.text?.raw ?? ""));
return messages;
}

// Server side chat state management
const message = messages[messages.length - 1].content;
if (!message.length) {
throw new UserError("Empty message supplied.");
}

const chatStateId = this.chatStateId !== null ? this.chatStateId : "init";
const singlePrompt = convertMessagesToStr(messages);
const inputProto = Input.getTextInput({
inputId: uuidv4(),
rawText: message,
rawText: singlePrompt,
});
const response = await this.promptWorkflow.predict({
inputs: [inputProto],
workflowStateId: chatStateId,
});

this.chatStateId = response.workflowState?.id ?? "";
const outputsList = response.resultsList?.[0]?.outputsList;
const output = outputsList[outputsList.length - 1];
return [formatAssistantMessage(output?.data?.text?.raw ?? "")];
messages.push(formatAssistantMessage(output?.data?.text?.raw ?? ""));
return messages;
}
}
3 changes: 2 additions & 1 deletion src/datasets/upload/features.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { JavaScriptValue } from "google-protobuf/google/protobuf/struct_pb";
import { Polygon } from "../../utils/types";

export interface TextFeatures {
imagePath?: undefined;
Expand Down Expand Up @@ -33,7 +34,7 @@ export interface VisualDetectionFeatures {
export interface VisualSegmentationFeatures {
imagePath: string;
labels: Array<string | number>;
polygons: Array<Array<Array<number>>>;
polygons: Polygon[];
geoInfo?: [number, number];
id?: number;
metadata?: Record<string, JavaScriptValue>;
Expand Down
2 changes: 1 addition & 1 deletion src/utils/misc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export function mergeObjects(obj1: AuthConfig, obj2: AuthConfig): AuthConfig {
export class BackoffIterator {
private count: number;

constructor({ count = 0 }: { count?: number } = { count: 0 }) {
constructor({ count } = { count: 0 }) {
this.count = count;
}

Expand Down
3 changes: 3 additions & 0 deletions src/utils/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ export type GrpcWithCallback<TRequest, TResponse> = (
export type PaginationRequestParams<T extends Record<string, unknown>> =
| Omit<Partial<T>, "userAppId" | "pageNo" | "perPage">
| Record<string, never>;

// A pair of numbers; presumably a 2D coordinate — TODO confirm axis order and units.
export type Point = [number, number];
// A polygon expressed as a list of [number, number] points.
export type Polygon = Point[];
2 changes: 1 addition & 1 deletion tests/client/rag.integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ describe("Rag", async () => {
const messages = [{ role: "human", content: "What is 1 + 1?" }];
const newMessages = await rag.chat({ messages, clientManageState: true });
expect(newMessages.length).toBe(2);
});
}, 10000);

// TODO: Server side state management is not supported yet - work in progress
it.skip("should predict & manage state on the server", async () => {
Expand Down
40 changes: 39 additions & 1 deletion tests/client/search.integration.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import path from "path";
import { getSchema } from "../../src/schema/search";
import { z } from "zod";
import { afterAll, beforeAll, describe, expect, it } from "vitest";
import { afterAll, beforeAll, describe, expect, it, vi } from "vitest";
import { App, Dataset, Input, Search, User } from "../../src/index";
import { Hit } from "clarifai-nodejs-grpc/proto/clarifai/api/resources_pb";
import EventEmitter from "events";

const NOW = Date.now().toString() + "-search";
const CREATE_APP_USER_ID = import.meta.env.VITE_CLARIFAI_USER_ID;
Expand Down Expand Up @@ -197,11 +198,48 @@ describe("Search", () => {
},
datasetId: datasetObj.id,
});
const eventEmitter = new EventEmitter();
const eventHandler = {
start: (...args: unknown[]) => console.log("start", args),
progress: (...args: unknown[]) => console.log("progress", args),
end: (...args: unknown[]) => console.log("end", args),
error: (...args: unknown[]) => console.log("error", args),
};
const startSpy = vi.spyOn(eventHandler, "start");
const progressSpy = vi.spyOn(eventHandler, "progress");
const endSpy = vi.spyOn(eventHandler, "end");
const errorSpy = vi.spyOn(eventHandler, "error");
eventEmitter.on("start", (start) => {
eventHandler.start(start);
});
eventEmitter.on("progress", (progress) => {
eventHandler.progress(progress);
});
eventEmitter.on("end", (progress) => {
eventHandler.end(progress);
});
eventEmitter.on("error", (error) => {
eventHandler.error(error);
});
await dataset.uploadFromFolder({
folderPath: DATASET_IMAGES_DIR,
inputType: "image",
labels: false,
uploadProgressEmitter: eventEmitter,
});
expect(startSpy).toHaveBeenNthCalledWith(
1,
expect.objectContaining({ current: 0, total: 1 }),
);
expect(progressSpy).toHaveBeenNthCalledWith(
1,
expect.objectContaining({ current: 1, total: 1 }),
);
expect(endSpy).toHaveBeenNthCalledWith(
1,
expect.objectContaining({ current: 1, total: 1 }),
);
expect(errorSpy).not.toHaveBeenCalled();
}, 50000);

it("should get expected hits for filters", async () => {
Expand Down
1 change: 1 addition & 0 deletions vitest.config.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export default defineConfig({
test: {
coverage: {
reporter: ["text", "json", "html", "clover", "json-summary"],
include: ["src/**/*"],
},
},
});

0 comments on commit d5ea236

Please sign in to comment.