diff --git a/packages/pglite-socket/src/index.ts b/packages/pglite-socket/src/index.ts index 572b9cfa7..9caed00b6 100644 --- a/packages/pglite-socket/src/index.ts +++ b/packages/pglite-socket/src/index.ts @@ -1,3 +1,4 @@ +import { Parser as ProtocolParser } from '@electric-sql/pg-protocol' import type { PGlite } from '@electric-sql/pglite' import { type Server, type Socket, createServer } from 'net' @@ -25,7 +26,7 @@ class QueryQueueManager { private processing = false private db: PGlite private debug: boolean - private lastHandlerId: null | number = null + private activeHandlerId: null | number = null constructor(db: PGlite, debug = false) { this.db = db @@ -65,6 +66,55 @@ class QueryQueueManager { }) } + private dequeueNextQuery(): null | QueuedQuery { + if (this.activeHandlerId === null) { + return this.queue.shift() ?? null + } + + const index = this.queue.findIndex( + (query) => query.handlerId === this.activeHandlerId, + ) + + if (index === -1) { + this.log( + `waiting for more protocol messages from handler #${this.activeHandlerId}`, + ) + return null + } + + return this.queue.splice(index, 1)[0] + } + + private updateActiveHandler( + handlerId: number, + readyForQueryStatus: null | string, + message: Uint8Array, + ): void { + if (readyForQueryStatus === 'I') { + this.log(`handler #${handlerId} released protocol ownership`) + this.activeHandlerId = null + return + } + + if (readyForQueryStatus === null && message[0] === 0x58) { + this.log(`handler #${handlerId} released protocol ownership on terminate`) + this.activeHandlerId = null + return + } + + this.activeHandlerId = handlerId + + if (readyForQueryStatus) { + this.log( + `handler #${handlerId} retained protocol ownership with ReadyForQuery status ${readyForQueryStatus}`, + ) + } else { + this.log( + `handler #${handlerId} retained protocol ownership until ReadyForQuery`, + ) + } + } + private async processQueue(): Promise { if (this.processing || this.queue.length === 0) { return @@ -72,59 +122,59 @@ class QueryQueueManager { this.processing = true - while (this.queue.length > 0) { - let query + try { + while (this.queue.length > 0) { + const query = this.dequeueNextQuery() + if (!query) break - if (this.db.isInTransaction() && this.lastHandlerId) { - const i = this.queue.findIndex( - (q) => q.handlerId === this.lastHandlerId, + const waitTime = Date.now() - query.timestamp + this.log( + `processing query from handler #${query.handlerId} (waited ${waitTime}ms)`, ) - if (i === -1) { - // we didn't find any other query from the same client! - this.log( - `transaction started, but no query from the same handler id found in queue`, - this.lastHandlerId, - ) - query = null - } else { - query = this.queue.splice(i, 1)[0] - } - } else { - query = this.queue.shift() - } - if (!query) break - const waitTime = Date.now() - query.timestamp - this.log( - `processing query from handler #${query.handlerId} (waited ${waitTime}ms)`, - ) + let result = 0 + let readyForQueryStatus: null | string = null + const parser = new ProtocolParser() - let result = 0 - try { - // Execute the query with exclusive access to PGlite - await this.db.runExclusive(async () => { - return await this.db.execProtocolRawStream(query.message, { - onRawData: (data) => { - result += data.length - query.onData(data) - }, + try { + // Keep one handler attached to the backend until it reaches a + // ReadyForQuery boundary so extended-protocol state can't interleave. + await this.db.runExclusive(async () => { + return await this.db.execProtocolRawStream(query.message, { + onRawData: (data) => { + result += data.length + parser.parse(data, (message) => { + if (message.name === 'readyForQuery') { + readyForQueryStatus = message.status + } + }) + query.onData(data) + }, + }) }) - }) - } catch (error) { - this.log(`query from handler #${query.handlerId} failed:`, error) - query.reject(error as Error) - return - } + } catch (error) { + this.log(`query from handler #${query.handlerId} failed:`, error) + if (this.activeHandlerId === query.handlerId) { + this.activeHandlerId = null + } + query.reject(error as Error) + continue + } - this.log( - `query from handler #${query.handlerId} completed, ${result} bytes`, - ) - this.lastHandlerId = query.handlerId - query.resolve(result) + this.log( + `query from handler #${query.handlerId} completed, ${result} bytes`, + ) + this.updateActiveHandler( + query.handlerId, + readyForQueryStatus, + query.message, + ) + query.resolve(result) + } + } finally { + this.processing = false + this.log(`queue processing complete, queue length is`, this.queue.length) } - - this.processing = false - this.log(`queue processing complete, queue length is`, this.queue.length) } getQueueLength(): number { @@ -147,11 +197,16 @@ class QueryQueueManager { } async clearTransactionIfNeeded(handlerId: number): Promise { - if (this.db.isInTransaction() && this.lastHandlerId === handlerId) { + if (this.activeHandlerId !== handlerId) { + return + } + + if (this.db.isInTransaction()) { await this.db.exec('ROLLBACK') - this.lastHandlerId = null - await this.processQueue() } + + this.activeHandlerId = null + await this.processQueue() } } diff --git a/packages/pglite-socket/tests/query-with-postgres-js-concurrency.test.ts b/packages/pglite-socket/tests/query-with-postgres-js-concurrency.test.ts new file mode 100644 index 000000000..bc4a800ae --- /dev/null +++ b/packages/pglite-socket/tests/query-with-postgres-js-concurrency.test.ts @@ -0,0 +1,55 @@ +import { afterAll, beforeAll, describe, expect, it } from 'vitest' +import postgres from 'postgres' +import { PGlite } from '@electric-sql/pglite' +import { PGLiteSocketServer } from '../src' + +const TEST_PORT = 5435 + +describe('PGLite Socket Server concurrency regression', () => { + let db: PGlite + let server: PGLiteSocketServer + let sql: ReturnType + + beforeAll(async () => { + db = await PGlite.create() + await db.waitReady + + server = new PGLiteSocketServer({ + db, + host: '127.0.0.1', + port: TEST_PORT, + maxConnections: 10, + }) + + await server.start() + + sql = postgres({ + host: '127.0.0.1', + port: TEST_PORT, + database: 'postgres', + username: 'postgres', + password: 'postgres', + idle_timeout: 5, + connect_timeout: 10, + max: 10, + }) + }) + + afterAll(async () => { + await sql?.end({ timeout: 1 }).catch(() => {}) + await server?.stop().catch(() => {}) + await db?.close().catch(() => {}) + }) + + it('keeps extended protocol state isolated across pooled connections', async () => { + for (let i = 0; i < 20; i++) { + const [valueResult, timezoneResult] = await Promise.all([ + sql.unsafe('select $1::int as value', [i]), + sql.unsafe("select current_setting('timezone') as timezone", []), + ]) + + expect(valueResult[0].value).toBe(i) + expect(typeof timezoneResult[0].timezone).toBe('string') + } + }) +})