diff --git a/lib/dispatcher/agent.js b/lib/dispatcher/agent.js index a1cc7fd6817..858a5f248f7 100644 --- a/lib/dispatcher/agent.js +++ b/lib/dispatcher/agent.js @@ -1,7 +1,7 @@ 'use strict' const { InvalidArgumentError, MaxOriginsReachedError } = require('../core/errors') -const { kClients, kRunning, kClose, kDestroy, kDispatch, kUrl } = require('../core/symbols') +const { kBusy, kClients, kConnected, kRunning, kClose, kDestroy, kDispatch, kUrl } = require('../core/symbols') const DispatcherBase = require('./dispatcher-base') const Pool = require('./pool') const Client = require('./client') @@ -65,7 +65,7 @@ class Agent extends DispatcherBase { get [kRunning] () { let ret = 0 - for (const { dispatcher } of this[kClients].values()) { + for (const dispatcher of this[kClients].values()) { ret += dispatcher[kRunning] } return ret @@ -86,54 +86,52 @@ class Agent extends DispatcherBase { throw new MaxOriginsReachedError() } - const result = this[kClients].get(key) - let dispatcher = result && result.dispatcher + let dispatcher = this[kClients].get(key) if (!dispatcher) { - const closeClientIfUnused = (connected) => { - const result = this[kClients].get(key) - if (result) { - if (connected) result.count -= 1 - if (result.count <= 0) { - this[kClients].delete(key) - if (!result.dispatcher.destroyed) { - result.dispatcher.close() - } - } + dispatcher = this[kFactory](opts.origin, allowH2 === false + ? { ...this[kOptions], allowH2: false } + : this[kOptions]) - let hasOrigin = false - for (const entry of this[kClients].values()) { - if (entry.origin === origin) { - hasOrigin = true - break - } - } + const closeClientIfUnused = () => { + if (this[kClients].get(key) !== dispatcher) { + return + } + + if (dispatcher[kConnected] > 0 || dispatcher[kBusy]) { + return + } + + this[kClients].delete(key) + if (!dispatcher.destroyed) { + dispatcher.close() + } - if (!hasOrigin) { - this[kOrigins].delete(origin) + let hasOrigin = false + for (const client of this[kClients].values()) { + if (client[kUrl].origin === dispatcher[kUrl].origin) { + hasOrigin = true + break } } + + if (!hasOrigin) { + this[kOrigins].delete(dispatcher[kUrl].origin) + } } - dispatcher = this[kFactory](opts.origin, allowH2 === false - ? { ...this[kOptions], allowH2: false } - : this[kOptions]) + + dispatcher .on('drain', this[kOnDrain]) - .on('connect', (origin, targets) => { - const result = this[kClients].get(key) - if (result) { - result.count += 1 - } - this[kOnConnect](origin, targets) - }) + .on('connect', this[kOnConnect]) .on('disconnect', (origin, targets, err) => { - closeClientIfUnused(true) + closeClientIfUnused() this[kOnDisconnect](origin, targets, err) }) .on('connectionError', (origin, targets, err) => { - closeClientIfUnused(false) + closeClientIfUnused() this[kOnConnectionError](origin, targets, err) }) - this[kClients].set(key, { count: 0, dispatcher, origin }) + this[kClients].set(key, dispatcher) this[kOrigins].add(origin) } @@ -142,7 +140,7 @@ class Agent extends DispatcherBase { [kClose] () { const closePromises = [] - for (const { dispatcher } of this[kClients].values()) { + for (const dispatcher of this[kClients].values()) { closePromises.push(dispatcher.close()) } this[kClients].clear() @@ -152,7 +150,7 @@ class Agent extends DispatcherBase { [kDestroy] (err) { const destroyPromises = [] - for (const { dispatcher } of this[kClients].values()) { + for (const dispatcher of this[kClients].values()) { destroyPromises.push(dispatcher.destroy(err)) } this[kClients].clear() @@ -162,7 +160,7 @@ class Agent extends DispatcherBase { get stats () { const allClientStats = {} - for (const { dispatcher } of this[kClients].values()) { + for (const dispatcher of this[kClients].values()) { if (dispatcher.stats) { allClientStats[dispatcher[kUrl].origin] = dispatcher.stats } diff --git a/lib/dispatcher/client-h1.js b/lib/dispatcher/client-h1.js index ce8b2e0f627..4d21f504706 100644 --- a/lib/dispatcher/client-h1.js +++ b/lib/dispatcher/client-h1.js @@ -1106,7 +1106,7 @@ function writeH1 (client, request) { socket[kReset] = reset } - if (client[kMaxRequests] && socket[kCounter]++ >= client[kMaxRequests]) { + if (client[kMaxRequests] && ++socket[kCounter] >= client[kMaxRequests]) { socket[kReset] = true } diff --git a/lib/mock/mock-agent.js b/lib/mock/mock-agent.js index 61449e077ea..17a7b717c21 100644 --- a/lib/mock/mock-agent.js +++ b/lib/mock/mock-agent.js @@ -167,7 +167,7 @@ class MockAgent extends Dispatcher { } [kMockAgentSet] (origin, dispatcher) { - this[kClients].set(origin, { count: 0, dispatcher }) + this[kClients].set(origin, dispatcher) } [kFactory] (origin) { @@ -179,9 +179,9 @@ class MockAgent extends Dispatcher { [kMockAgentGet] (origin) { // First check if we can immediately find it - const result = this[kClients].get(origin) - if (result?.dispatcher) { - return result.dispatcher + const dispatcher = this[kClients].get(origin) + if (dispatcher) { + return dispatcher } // If the origin is not a string create a dummy parent pool and return to user @@ -192,11 +192,11 @@ class MockAgent extends Dispatcher { } // If we match, create a pool and assign the same dispatches - for (const [keyMatcher, result] of Array.from(this[kClients])) { - if (result && typeof keyMatcher !== 'string' && matchValue(keyMatcher, origin)) { + for (const [keyMatcher, nonExplicitDispatcher] of Array.from(this[kClients])) { + if (nonExplicitDispatcher && typeof keyMatcher !== 'string' && matchValue(keyMatcher, origin)) { const dispatcher = this[kFactory](origin) this[kMockAgentSet](origin, dispatcher) - dispatcher[kDispatches] = result.dispatcher[kDispatches] + dispatcher[kDispatches] = nonExplicitDispatcher[kDispatches] return dispatcher } } @@ -210,7 +210,7 @@ class MockAgent extends Dispatcher { const mockAgentClients = this[kClients] return Array.from(mockAgentClients.entries()) - .flatMap(([origin, result]) => result.dispatcher[kDispatches].map(dispatch => ({ ...dispatch, origin }))) + .flatMap(([origin, dispatcher]) => dispatcher[kDispatches].map(dispatch => ({ ...dispatch, origin }))) .filter(({ pending }) => pending) } diff --git a/test/agent-connection-management.js b/test/agent-connection-management.js new file mode 100644 index 00000000000..cd520a304de --- /dev/null +++ b/test/agent-connection-management.js @@ -0,0 +1,156 @@ +const { test, describe } = require('node:test') +const assert = require('node:assert') +const { createServer } = require('node:http') +const { request, Agent, Pool } = require('..') + +// https://github.com/nodejs/undici/issues/4424 +describe('Agent should close inactive clients', () => { + test('without active connections', async (t) => { + const server = createServer({ keepAliveTimeout: 0 }, async (_req, res) => { + res.setHeader('connection', 'close') + res.writeHead(200) + res.end('ok') + }).listen(0) + + t.after(() => { + server.closeAllConnections?.() + server.close() + }) + + /** @type {Promise} */ + let p + const agent = new Agent({ + factory: (origin, opts) => { + const pool = new Pool(origin, opts) + const { promise, resolve, reject } = Promise.withResolvers() + p = promise + pool.on('disconnect', () => { + setImmediate(() => pool.destroyed ? resolve() : reject(new Error('client not destroyed'))) + }) + return pool + } + }) + const { statusCode } = await request(`http://localhost:${server.address().port}`, { dispatcher: agent }) + assert.equal(statusCode, 200) + + await p + }) + + test('in case of connection error', async (t) => { + /** @type {Promise} */ + let p + const agent = new Agent({ + factory: (origin, opts) => { + const pool = new Pool(origin, opts) + const { promise, resolve, reject } = Promise.withResolvers() + p = promise + pool.on('connectionError', () => { + setImmediate(() => pool.destroyed ? resolve() : reject(new Error('client not destroyed'))) + }) + return pool + } + }) + try { + await request('http://localhost:0', { dispatcher: agent }) + } catch (_) { + // ignore + } + + await p + }) +}) + +// https://github.com/nodejs/undici/issues/5022 +describe('Agent should not close active clients', () => { + test('should reuse replacement keep-alive connection after server closes the previous one', async (t) => { + let nextSocketId = 0 + const socketIds = new Map() + const requestsPerSocket = new Map() + + const server = createServer((req, res) => { + const socket = req.socket + if (!socketIds.has(socket)) { + socketIds.set(socket, ++nextSocketId) + } + + const count = (requestsPerSocket.get(socket) || 0) + 1 + requestsPerSocket.set(socket, count) + + const remaining = 3 - count + res.setHeader('x-socket-id', String(socketIds.get(socket))) + + if (remaining > 0) { + res.setHeader('connection', 'Keep-Alive') + res.setHeader('keep-alive', `timeout=30, max=${remaining}`) + } else { + res.setHeader('connection', 'close') + } + + res.writeHead(200) + res.end('ok') + }).listen(0) + + t.after(() => { + server.closeAllConnections?.() + server.close() + }) + + const agent = new Agent({ connections: 1 }) + t.after(() => agent.close()) + + const socketSequence = [] + for (let i = 0; i < 5; i++) { + const { statusCode, headers, body } = await request(`http://localhost:${server.address().port}`, { + dispatcher: agent + }) + + assert.equal(statusCode, 200) + await body.dump() + socketSequence.push(headers['x-socket-id']) + } + + assert.deepEqual(socketSequence.slice(0, 3), ['1', '1', '1']) + assert.deepEqual(socketSequence.slice(3), ['2', '2']) + }) + + test('should reuse replacement connection after keep-alive max closes the previous one', async (t) => { + let nextSocketId = 0 + const socketIds = new Map() + + const server = createServer((req, res) => { + const socket = req.socket + if (!socketIds.has(socket)) { + socketIds.set(socket, ++nextSocketId) + } + + res.setHeader('x-socket-id', String(socketIds.get(socket))) + res.setHeader('connection', 'Keep-Alive') + res.setHeader('keep-alive', 'timeout=30') + + res.writeHead(200) + res.end('ok') + }).listen(0) + + t.after(() => { + server.closeAllConnections?.() + server.close() + }) + + const agent = new Agent({ connections: 1, maxRequestsPerClient: 3 }) + t.after(() => agent.close()) + + const socketSequence = [] + for (let i = 0; i < 5; i++) { + const { statusCode, headers, body } = await request(`http://localhost:${server.address().port}`, { + dispatcher: agent + }) + + assert.equal(statusCode, 200) + await body.dump() + socketSequence.push(headers['x-socket-id']) + } + + assert.deepEqual(socketSequence.slice(0, 3), ['1', '1', '1']) + assert.deepEqual(socketSequence.slice(3), ['2', '2']) + }) +}) diff --git a/test/close-and-destroy.js b/test/close-and-destroy.js index e52c0072553..582bd8cbaba 100644 --- a/test/close-and-destroy.js +++ b/test/close-and-destroy.js @@ -265,17 +265,20 @@ test('close after and destroy should error', async (t) => { test('close socket and reconnect after maxRequestsPerClient reached', async (t) => { t = tspl(t, { plan: 1 }) + let nextConnectionId = 0 + const socketToIdMap = new Map() + const connectionUsedForRequest = [] const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + connectionUsedForRequest.push(socketToIdMap.get(req.socket)) res.end(req.url) }) after(() => server.close()) server.listen(0, async () => { - let connections = 0 - server.on('connection', () => { - connections++ + server.on('connection', (sock) => { + socketToIdMap.set(sock, nextConnectionId++) }) const client = new Client( `http://localhost:${server.address().port}`, @@ -287,7 +290,7 @@ test('close socket and reconnect after maxRequestsPerClient reached', async (t) await makeRequest() await makeRequest() await makeRequest() - t.strictEqual(connections, 2) + t.deepEqual(connectionUsedForRequest, [0, 0, 1, 1]) function makeRequest () { return client.request({ path: '/', method: 'GET' }) @@ -299,17 +302,20 @@ test('close socket and reconnect after maxRequestsPerClient reached', async (t) test('close socket and reconnect after maxRequestsPerClient reached (async)', async (t) => { t = tspl(t, { plan: 1 }) + let nextConnectionId = 0 + const socketToIdMap = new Map() + const connectionUsedForRequest = [] const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { + connectionUsedForRequest.push(socketToIdMap.get(req.socket)) res.end(req.url) }) after(() => server.close()) server.listen(0, async () => { - let connections = 0 - server.on('connection', () => { - connections++ + server.on('connection', (sock) => { + socketToIdMap.set(sock, nextConnectionId++) }) const client = new Client( `http://localhost:${server.address().port}`, @@ -323,7 +329,7 @@ test('close socket and reconnect after maxRequestsPerClient reached (async)', as makeRequest(), makeRequest() ]) - t.strictEqual(connections, 2) + t.deepEqual(connectionUsedForRequest, [0, 0, 1, 1]) function makeRequest () { return client.request({ path: '/', method: 'GET' }) diff --git a/test/issue-4244.js b/test/issue-4244.js deleted file mode 100644 index 489557fe495..00000000000 --- a/test/issue-4244.js +++ /dev/null @@ -1,65 +0,0 @@ -const { test, describe } = require('node:test') -const assert = require('node:assert') -const { createServer } = require('node:http') -const { request, Agent, Pool } = require('..') - -// https://github.com/nodejs/undici/issues/4424 -describe('Agent should close inactive clients', () => { - test('without active connections', async (t) => { - const server = createServer({ keepAliveTimeout: 0 }, async (_req, res) => { - res.setHeader('connection', 'close') - res.writeHead(200) - res.end('ok') - }).listen(0) - - t.after(() => { - server.closeAllConnections?.() - server.close() - }) - - let p - const agent = new Agent({ - factory: (origin, opts) => { - const pool = new Pool(origin, opts) - let _resolve, _reject - p = new Promise((resolve, reject) => { - _resolve = resolve - _reject = reject - }) - pool.on('disconnect', () => { - setImmediate(() => pool.destroyed ? _resolve() : _reject(new Error('client not destroyed'))) - }) - return pool - } - }) - const { statusCode } = await request(`http://localhost:${server.address().port}`, { dispatcher: agent }) - assert.equal(statusCode, 200) - - await p - }) - - test('in case of connection error', async (t) => { - let p - const agent = new Agent({ - factory: (origin, opts) => { - const pool = new Pool(origin, opts) - let _resolve, _reject - p = new Promise((resolve, reject) => { - _resolve = resolve - _reject = reject - }) - pool.on('connectionError', () => { - setImmediate(() => pool.destroyed ? _resolve() : _reject(new Error('client not destroyed'))) - }) - return pool - } - }) - try { - await request('http://localhost:0', { dispatcher: agent }) - } catch (_) { - // ignore - } - - await p - }) -})