diff --git a/src/common.ts b/src/common.ts index 70ba648..0f1a594 100644 --- a/src/common.ts +++ b/src/common.ts @@ -1,4 +1,4 @@ -import type { MessagePort, TransferListItem } from 'node:worker_threads' +import { type MessagePort } from 'node:worker_threads' /** Channel for communicating between main thread and workers */ export interface TinypoolChannel { @@ -9,6 +9,9 @@ export interface TinypoolChannel { postMessage(message: any): void } +// TODO: Narrow down with generic +type Listener = (...args: any[]) => void + export interface TinypoolWorker { runtime: string initialize(options: { @@ -19,12 +22,31 @@ export interface TinypoolWorker { workerData: TinypoolData trackUnmanagedFds?: boolean }): void + + /** Terminates the worker */ terminate(): Promise - postMessage(message: any, transferListItem?: TransferListItem[]): void + + /** Initialize the worker */ + initializeWorker(message: StartupMessage): void + + /** Run given task on worker */ + runTask(message: RequestMessage): void + + /** Listen on task finish messages */ + onTaskFinished(message: Listener): void + + /** Listen on ready messages */ + onReady(listener: Listener): void + + /** Listen on errors */ + onError(listener: Listener): void + + /** Listen on exit. Called only **once**. */ + onExit(listener: Listener): void + + /** Set's channel for 'main <-> worker' communication */ setChannel?: (channel: TinypoolChannel) => void - on(event: string, listener: (...args: any[]) => void): void - once(event: string, listener: (...args: any[]) => void): void - emit(event: string, ...data: any[]): void + ref?: () => void unref?: () => void threadId: number @@ -45,7 +67,7 @@ export interface TinypoolWorkerMessage< export interface StartupMessage { filename: string | null name: string - port: MessagePort + port?: MessagePort sharedBuffer: Int32Array useAtomics: boolean } @@ -55,6 +77,7 @@ export interface RequestMessage { task: any filename: string name: string + transferList?: any } export interface ReadyMessage { diff --git a/src/entry/worker.ts b/src/entry/worker.ts index 6e41f8e..f955aca 100644 --- a/src/entry/worker.ts +++ b/src/entry/worker.ts @@ -41,6 +41,10 @@ parentPort!.on('message', (message: StartupMessage) => { const { port, sharedBuffer, filename, name } = message + if (!port) { + throw new Error(`Missing port ${JSON.stringify(message)}`) + } + ;(async function () { if (filename !== null) { await getHandler(filename, name) diff --git a/src/index.ts b/src/index.ts index 2309ed8..703f9bf 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,9 +1,5 @@ -import { - MessageChannel, - type MessagePort, - receiveMessageOnPort, -} from 'node:worker_threads' -import { once, EventEmitterAsyncResource } from 'node:events' +import { type MessagePort } from 'node:worker_threads' +import { EventEmitterAsyncResource } from 'node:events' import { AsyncResource } from 'node:async_hooks' import { fileURLToPath, URL } from 'node:url' import { join } from 'node:path' @@ -13,7 +9,6 @@ import { performance } from 'node:perf_hooks' import { readFileSync } from 'node:fs' import { amount as physicalCpuCount } from './physicalCpuCount' import { - type ReadyMessage, type RequestMessage, type ResponseMessage, type StartupMessage, @@ -448,7 +443,6 @@ class WorkerInfo extends AsynchronouslyCreatedResource { freeWorkerId: () => void taskInfos: Map idleTimeout: NodeJS.Timeout | null = null - port: MessagePort sharedBuffer: Int32Array lastSeenResponseCount: number = 0 usedMemory?: number @@ -457,7 +451,6 @@ class WorkerInfo extends AsynchronouslyCreatedResource { constructor( worker: TinypoolWorker, - port: MessagePort, workerId: number, freeWorkerId: () => void, onMessage: ResponseCallback @@ -466,8 +459,7 @@ class WorkerInfo extends AsynchronouslyCreatedResource { this.worker = worker this.workerId = workerId this.freeWorkerId = freeWorkerId - this.port = port - this.port.on('message', (message: ResponseMessage) => + this.worker.onTaskFinished((message: ResponseMessage) => this._handleResponse(message) ) this.onMessage = onMessage @@ -498,7 +490,6 @@ class WorkerInfo extends AsynchronouslyCreatedResource { clearTimeout(timer) } - this.port.close() this.clearIdleTimeout() for (const taskInfo of this.taskInfos.values()) { taskInfo.done(Errors.ThreadTermination()) @@ -518,18 +509,6 @@ class WorkerInfo extends AsynchronouslyCreatedResource { } } - ref(): WorkerInfo { - this.port.ref() - return this - } - - unref(): WorkerInfo { - // Note: Do not call ref()/unref() on the Worker itself since that may cause - // a hard crash, see https://github.com/nodejs/node/pull/33394. - this.port.unref() - return this - } - _handleResponse(message: ResponseMessage): void { this.usedMemory = message.usedMemory this.onMessage(message) @@ -537,7 +516,7 @@ class WorkerInfo extends AsynchronouslyCreatedResource { if (this.taskInfos.size === 0) { // No more tasks running on this Worker means it should not keep the // process running. - this.unref() + //this.unref() } } @@ -548,13 +527,14 @@ class WorkerInfo extends AsynchronouslyCreatedResource { taskId: taskInfo.taskId, filename: taskInfo.filename, name: taskInfo.name, + transferList: taskInfo.transferList, } try { if (taskInfo.channel) { this.worker.setChannel?.(taskInfo.channel) } - this.port.postMessage(message, taskInfo.transferList) + this.worker.runTask(message) } catch (err) { // This would mostly happen if e.g. message contains unserializable data // or transferList is invalid. @@ -564,7 +544,6 @@ class WorkerInfo extends AsynchronouslyCreatedResource { taskInfo.workerInfo = this this.taskInfos.set(taskInfo.taskId, taskInfo) - this.ref() this.clearIdleTimeout() // Inform the worker that there are new messages posted, and wake it up @@ -586,10 +565,10 @@ class WorkerInfo extends AsynchronouslyCreatedResource { if (actualResponseCount !== this.lastSeenResponseCount) { this.lastSeenResponseCount = actualResponseCount - let entry - while ((entry = receiveMessageOnPort(this.port)) !== undefined) { - this._handleResponse(entry.message) - } + // TODO let entry + // while ((entry = receiveMessageOnPort(this.port)) !== undefined) { + // this._handleResponse(entry.message) + // } } } @@ -694,7 +673,7 @@ class ThreadPool { }) const tinypoolPrivateData = { workerId: workerId! } - const worker = + const worker: TinypoolWorker = this.options.runtime === 'child_process' ? new ProcessWorker() : new ThreadWorker() @@ -737,10 +716,8 @@ class ThreadPool { this._processPendingMessages() } - const { port1, port2 } = new MessageChannel() const workerInfo = new WorkerInfo( worker, - port1, workerId!, () => workerIds.set(workerId, true), onMessage @@ -754,32 +731,23 @@ class ThreadPool { const message: StartupMessage = { filename: this.options.filename, name: this.options.name, - port: port2, sharedBuffer: workerInfo.sharedBuffer, useAtomics: this.options.useAtomics, } - worker.postMessage(message, [port2]) - - worker.on('message', (message: ReadyMessage) => { - if (message.ready === true) { - if (workerInfo.currentUsage() === 0) { - workerInfo.unref() - } + worker.initializeWorker(message) - if (!workerInfo.isReady()) { - workerInfo.markAsReady() - } - return + worker.onReady(() => { + if (workerInfo.currentUsage() === 0) { + worker.unref?.() } - worker.emit( - 'error', - new Error(`Unexpected message on Worker: ${inspect(message)}`) - ) + if (!workerInfo.isReady()) { + workerInfo.markAsReady() + } }) - worker.on('error', (err: Error) => { + worker.onError((err: Error) => { // Work around the bug in https://github.com/nodejs/node/pull/33394 worker.ref = () => {} @@ -809,13 +777,7 @@ class ThreadPool { } }) - worker.unref() - port1.on('close', () => { - // The port is only closed if the Worker stops for some reason, but we - // always .unref() the Worker itself. We want to receive e.g. 'error' - // events on it, so we ref it once we know it's going to exit anyway. - worker.ref() - }) + worker.unref?.() this.workers.add(workerInfo) } @@ -1056,13 +1018,14 @@ class ThreadPool { taskInfo.done(new Error('Terminating worker thread')) } - const exitEvents: Promise[] = [] + const exitEvents: Promise[] = [] while (this.workers.size > 0) { const [workerInfo] = this.workers - // @ts-expect-error -- TODO Fix - exitEvents.push(once(workerInfo.worker, 'exit')) - // @ts-expect-error -- TODO Fix - void this._removeWorker(workerInfo) + + if (workerInfo) { + exitEvents.push(new Promise((r) => workerInfo.worker.onExit(r))) + void this._removeWorker(workerInfo) + } } await Promise.all(exitEvents) @@ -1087,8 +1050,7 @@ class ThreadPool { Array.from(this.workers).filter((workerInfo) => { // Remove idle workers if (workerInfo.currentUsage() === 0) { - // @ts-expect-error -- TODO Fix - exitEvents.push(once(workerInfo.worker, 'exit')) + exitEvents.push(new Promise((r) => workerInfo.worker.onExit(r))) void this._removeWorker(workerInfo) } // Mark on-going workers for recycling. diff --git a/src/runtime/process-worker.ts b/src/runtime/process-worker.ts index eadc2cb..f4ed14e 100644 --- a/src/runtime/process-worker.ts +++ b/src/runtime/process-worker.ts @@ -1,10 +1,13 @@ import { type ChildProcess, fork } from 'node:child_process' -import { MessagePort, type TransferListItem } from 'node:worker_threads' import { fileURLToPath } from 'node:url' import { + type RequestMessage, + type ReadyMessage, + type StartupMessage, type TinypoolChannel, type TinypoolWorker, type TinypoolWorkerMessage, + type ResponseMessage, } from '../common' const __tinypool_worker_message__ = true @@ -15,7 +18,6 @@ export default class ProcessWorker implements TinypoolWorker { runtime = 'child_process' process!: ChildProcess threadId!: number - port?: MessagePort channel?: TinypoolChannel waitForExit!: Promise isTerminating = false @@ -36,6 +38,12 @@ export default class ProcessWorker implements TinypoolWorker { this.process.on('exit', this.onUnexpectedExit) this.waitForExit = new Promise((r) => this.process.on('exit', r)) + + this.process.on('message', (data: TinypoolWorkerMessage) => { + if (!data || !data.__tinypool_worker_message__) { + return this.channel?.postMessage(data) + } + }) } onUnexpectedExit = () => { @@ -54,7 +62,6 @@ export default class ProcessWorker implements TinypoolWorker { this.process.kill() await this.waitForExit - this.port?.close() clearTimeout(sigkillTimeout) } @@ -73,24 +80,7 @@ export default class ProcessWorker implements TinypoolWorker { } } - postMessage(message: any, transferListItem?: Readonly) { - transferListItem?.forEach((item) => { - if (item instanceof MessagePort) { - this.port = item - } - }) - - // Mirror port's messages to process - if (this.port) { - this.port.on('message', (message) => - this.send(>{ - ...message, - source: 'port', - __tinypool_worker_message__, - }) - ) - } - + initializeWorker(message: StartupMessage) { return this.send(>{ ...message, source: 'pool', @@ -98,31 +88,56 @@ export default class ProcessWorker implements TinypoolWorker { }) } - on(event: string, callback: (...args: any[]) => void) { - return this.process.on(event, (data: TinypoolWorkerMessage) => { - // All errors should be forwarded to the pool - if (event === 'error') { - return callback(data) - } + runTask(message: RequestMessage): void { + return this.send(>{ + ...message, + source: 'port', + __tinypool_worker_message__, + }) + } - if (!data || !data.__tinypool_worker_message__) { - return this.channel?.postMessage(data) + onReady(callback: (...args: any[]) => void) { + return this.process.on( + 'message', + (data: TinypoolWorkerMessage & ReadyMessage) => { + if ( + data.__tinypool_worker_message__ === true && + data.source === 'pool' && + data.ready === true + ) { + callback() + } } + ) + } - if (data.source === 'pool') { - callback(data) - } else if (data.source === 'port') { - this.port!.postMessage(data) + onTaskFinished(callback: (...args: any[]) => void) { + return this.process.on( + 'message', + (data: TinypoolWorkerMessage & ResponseMessage) => { + if ( + data.__tinypool_worker_message__ === true && + data.source === 'port' + ) { + callback(data) + } } - }) + ) } - once(event: string, callback: (...args: any[]) => void) { - return this.process.once(event, callback) + onError(callback: (...args: any[]) => void) { + return this.process.on('error', (data) => { + // All errors should be forwarded to the pool + return callback(data) + }) } - emit(event: string, ...data: any[]) { - return this.process.emit(event, ...data) + onExit(callback: (...args: any[]) => void) { + if (this.isTerminating) { + return callback() + } + + return this.process.once('exit', callback) } ref() { @@ -130,8 +145,6 @@ export default class ProcessWorker implements TinypoolWorker { } unref() { - this.port?.unref() - // The forked child_process adds event listener on `process.on('message)`. // This requires manual unreffing of its channel. this.process.channel?.unref() diff --git a/src/runtime/thread-worker.ts b/src/runtime/thread-worker.ts index cd7f9a2..f00b948 100644 --- a/src/runtime/thread-worker.ts +++ b/src/runtime/thread-worker.ts @@ -1,12 +1,20 @@ import { fileURLToPath } from 'node:url' -import { type TransferListItem, Worker } from 'node:worker_threads' -import { type TinypoolWorker } from '../common' +import { inspect } from 'node:util' +import { MessageChannel, type MessagePort, Worker } from 'node:worker_threads' +import { + type RequestMessage, + type ReadyMessage, + type StartupMessage, + type TinypoolWorker, +} from '../common' export default class ThreadWorker implements TinypoolWorker { name = 'ThreadWorker' runtime = 'worker_threads' thread!: Worker threadId!: number + port!: MessagePort + workerPort!: MessagePort initialize(options: Parameters[0]) { this.thread = new Worker( @@ -14,26 +22,63 @@ export default class ThreadWorker implements TinypoolWorker { options ) this.threadId = this.thread.threadId + + const { port1, port2 } = new MessageChannel() + this.port = port1 + this.workerPort = port2 + + port1.on('close', () => { + // The port is only closed if the Worker stops for some reason, but we + // always .unref() the Worker itself. We want to receive e.g. 'error' + // events on it, so we ref it once we know it's going to exit anyway. + this.ref?.() + }) } async terminate() { + this.port.close() return this.thread.terminate() } - postMessage(message: any, transferListItem?: Readonly) { - return this.thread.postMessage(message, transferListItem) + initializeWorker(message: StartupMessage) { + return this.thread.postMessage( + { + ...message, + port: this.workerPort, + }, + [this.workerPort] + ) + } + + runTask(message: RequestMessage): void { + this.port.ref() + + return this.port.postMessage(message, message.transferList) + } + + onReady(callback: (...args: any[]) => void) { + return this.thread.on('message', (message: ReadyMessage) => { + if (message.ready === true) { + return callback() + } + + this.thread.emit( + 'error', + new Error(`Unexpected message on Worker: ${inspect(message)}`) + ) + }) } - on(event: string, callback: (...args: any[]) => void) { - return this.thread.on(event, callback) + onTaskFinished(listener: (...args: any[]) => void): void { + this.port.on('message', listener) } - once(event: string, callback: (...args: any[]) => void) { - return this.thread.once(event, callback) + onError(callback: (...args: any[]) => void) { + return this.thread.on('error', callback) } - emit(event: string, ...data: any[]) { - return this.thread.emit(event, ...data) + onExit(callback: (...args: any[]) => void) { + return this.thread.once('exit', callback) } ref() { diff --git a/test/termination.test.ts b/test/termination.test.ts index 3cc60ed..77c7d2a 100644 --- a/test/termination.test.ts +++ b/test/termination.test.ts @@ -57,20 +57,24 @@ test('writing to terminating worker does not crash', async () => { await destroyed }) -test('recycling workers while closing pool does not crash', async () => { - const pool = new Tinypool({ - runtime: 'child_process', - filename: resolve(__dirname, 'fixtures/nested-pool.mjs'), - isolateWorkers: true, - minThreads: 1, - maxThreads: 1, - }) +test( + 'recycling workers while closing pool does not crash', + { timeout: 10_000 }, + async () => { + const pool = new Tinypool({ + runtime: 'child_process', + filename: resolve(__dirname, 'fixtures/nested-pool.mjs'), + isolateWorkers: true, + minThreads: 1, + maxThreads: 1, + }) - await Promise.all( - (Array(10) as (() => Promise)[]) - .fill(() => pool.run({})) - .map((fn) => fn()) - ) + await Promise.all( + (Array(10) as (() => Promise)[]) + .fill(() => pool.run({})) + .map((fn) => fn()) + ) - await pool.destroy() -}) + await pool.destroy() + } +) diff --git a/test/uncaught-exception-from-handler.test.ts b/test/uncaught-exception-from-handler.test.ts index 8a6f408..7b82560 100644 --- a/test/uncaught-exception-from-handler.test.ts +++ b/test/uncaught-exception-from-handler.test.ts @@ -46,7 +46,7 @@ test('uncaught exception in immediate after task yields error event', async () = // Hack a bit to make sure we get the 'exit'/'error' events. expect(pool.threads.length).toBe(1) - pool.threads[0]!.ref?.() + pool.threads[0]!.ref!() // This is the main aassertion here. expect((await errorEvent)[0]!.message).toEqual('not_caught')