diff --git a/src/connection.ts b/src/connection.ts index 7f374ab8a..9159a9188 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -1035,6 +1035,14 @@ class Connection extends EventEmitter { */ declare requestTimer: undefined | NodeJS.Timeout; + /** + * Controller used to abort the connection establishment process + * when the connection is closed before it was fully established. + * + * @private + */ + declare closeController: undefined | AbortController; + /** * Whether an attention message was sent to the server to cancel the * currently active request. @@ -1832,6 +1840,8 @@ class Connection extends EventEmitter { this.once('error', onError); } + this.closeController = new AbortController(); + this.transitionTo(this.STATE.CONNECTING); this.initialiseConnection().then(() => { process.nextTick(() => { @@ -1839,14 +1849,12 @@ class Connection extends EventEmitter { }); }, (err) => { this.transitionTo(this.STATE.FINAL); - this.closed = true; process.nextTick(() => { this.emit('connect', err); }); - process.nextTick(() => { - this.emit('end'); - }); + + this.cleanupConnection(); }); } @@ -1991,6 +1999,8 @@ class Connection extends EventEmitter { * The [[Event_end]] will be emitted once the connection has been closed. */ close() { + this.closeController?.abort(new ConnectionError('Connection closed before the connection was established.', 'ECLOSE')); + this.transitionTo(this.STATE.FINAL); this.cleanupConnection(); } @@ -2016,7 +2026,7 @@ class Connection extends EventEmitter { }, this.config.options.connectTimeout); try { - let signal = timeoutController.signal; + let signal = AbortSignal.any([timeoutController.signal, this.closeController!.signal]); let port = this.config.options.port; @@ -2067,6 +2077,11 @@ class Connection extends EventEmitter { try { signal = AbortSignal.any([signal, controller.signal]); + // The connection may have been closed while we were waiting for the + // socket to connect. Adding an abort listener to an already aborted + // signal will not call the listener, so we need to check here. + signal.throwIfAborted(); + socket.setKeepAlive(true, KEEP_ALIVE_INITIAL_DELAY); this.messageIo = new MessageIO(socket, this.config.options.packetSize, this.debug); @@ -2074,7 +2089,6 @@ class Connection extends EventEmitter { this.socket = socket; - this.closed = false; this.debug.log('connected to ' + this.config.server + ':' + this.config.options.port); this.sendPreLogin(); @@ -3424,9 +3438,21 @@ class Connection extends EventEmitter { const port = this.routingData ? this.routingData.port : this.config.options.port; this.debug.log('Retry after transient failure connecting to ' + server + ':' + port); - const { promise, resolve } = withResolvers(); - setTimeout(resolve, this.config.options.connectionRetryInterval); - await promise; + const closeSignal = this.closeController!.signal; + closeSignal.throwIfAborted(); + + const { promise, resolve, reject } = withResolvers(); + + const onAbort = () => { reject(closeSignal.reason); }; + closeSignal.addEventListener('abort', onAbort, { once: true }); + + const retryTimer = setTimeout(resolve, this.config.options.connectionRetryInterval); + try { + await promise; + } finally { + clearTimeout(retryTimer); + closeSignal.removeEventListener('abort', onAbort); + } this.emit('retry'); this.transitionTo(this.STATE.CONNECTING); diff --git a/test/unit/connection-close-test.ts b/test/unit/connection-close-test.ts new file mode 100644 index 000000000..32a6e47b5 --- /dev/null +++ b/test/unit/connection-close-test.ts @@ -0,0 +1,190 @@ +import { assert } from 'chai'; +import * as net from 'net'; +import { Connection, ConnectionError } from '../../src/tedious'; +import IncomingMessageStream from '../../src/incoming-message-stream'; +import OutgoingMessageStream from '../../src/outgoing-message-stream'; +import Debug from '../../src/debug'; +import PreloginPayload from '../../src/prelogin-payload'; +import Message from '../../src/message'; + +function buildLoginAckToken(): Buffer { + const progname = 'Tedious SQL Server'; + + const buffer = Buffer.from([ + 0xAD, // Type + 0x00, 0x00, // Length + 0x00, // interface number - SQL + 0x74, 0x00, 0x00, 0x04, // TDS version number + Buffer.byteLength(progname, 'ucs2') / 2, ...Buffer.from(progname, 'ucs2'), // Progname + 0x00, // major + 0x00, // minor + 0x00, 0x00, // buildNum + ]); + + buffer.writeUInt16LE(buffer.length - 3, 1); + + return buffer; +} + +describe('Closing a connection while connecting', function() { + let server: net.Server; + let _connections: net.Socket[]; + + beforeEach(function(done) { + _connections = []; + server = net.createServer((connection) => { + _connections.push(connection); + }); + server.listen(0, '127.0.0.1', done); + }); + + afterEach(function(done) { + _connections.forEach((connection) => { + connection.destroy(); + }); + + server.close(done); + }); + + it('should abort the connection process when `close` is called immediately after `connect`', function(done) { + // A server that responds to the full login sequence. Without aborting + // the connection process, the connection will happily continue to log + // in and end up in the `LoggedIn` state, despite being closed. + server.on('connection', async (connection) => { + const debug = new Debug(); + const incomingMessageStream = new IncomingMessageStream(debug); + const outgoingMessageStream = new OutgoingMessageStream(debug, { packetSize: 4 * 1024 }); + + connection.pipe(incomingMessageStream); + outgoingMessageStream.pipe(connection); + + try { + const messageIterator = incomingMessageStream[Symbol.asyncIterator](); + + // PRELOGIN + { + const { value: message, done } = await messageIterator.next(); + if (done) { + return; + } + assert.strictEqual(message.type, 0x12); + + const chunks: Buffer[] = []; + for await (const data of message) { + chunks.push(data); + } + + const responsePayload = new PreloginPayload({ encrypt: false, version: { major: 1, minor: 2, build: 3, subbuild: 0 } }); + const responseMessage = new Message({ type: 0x12 }); + responseMessage.end(responsePayload.data); + outgoingMessageStream.write(responseMessage); + } + + // LOGIN7 + { + const { value: message, done } = await messageIterator.next(); + if (done) { + return; + } + assert.strictEqual(message.type, 0x10); + + const chunks: Buffer[] = []; + for await (const data of message) { + chunks.push(data); + } + + const responseMessage = new Message({ type: 0x04 }); + responseMessage.end(buildLoginAckToken()); + outgoingMessageStream.write(responseMessage); + } + + // SQL Batch (Initial SQL) + { + const { value: message, done } = await messageIterator.next(); + if (done) { + return; + } + assert.strictEqual(message.type, 0x01); + + const chunks: Buffer[] = []; + for await (const data of message) { + chunks.push(data); + } + + const responseMessage = new Message({ type: 0x04 }); + responseMessage.end(); + outgoingMessageStream.write(responseMessage); + } + } catch (err) { + console.log(err); + } + }); + + const connection = new Connection({ + server: (server.address() as net.AddressInfo).address, + options: { + port: (server.address() as net.AddressInfo).port, + encrypt: false + } + }); + + let endCount = 0; + connection.on('end', () => { + endCount += 1; + }); + + connection.connect((err) => { + assert.instanceOf(err, ConnectionError); + assert.strictEqual(err.code, 'ECLOSE'); + assert.strictEqual('Connection closed before the connection was established.', err.message); + + assert.strictEqual(endCount, 1); + + // Ensure no additional `end` event is emitted afterwards. + setImmediate(() => { + assert.strictEqual(endCount, 1); + + done(); + }); + }); + + connection.close(); + }); + + it('should abort the connection process when `close` is called while the connection is being established', function(done) { + // A server that accepts connections but never responds. + server.on('connection', () => { + setImmediate(() => { + connection.close(); + }); + }); + + const connection = new Connection({ + server: (server.address() as net.AddressInfo).address, + options: { + port: (server.address() as net.AddressInfo).port, + encrypt: false, + connectTimeout: 1000 + } + }); + + let endCount = 0; + connection.on('end', () => { + endCount += 1; + }); + + connection.connect((err) => { + assert.instanceOf(err, ConnectionError); + assert.strictEqual(err.code, 'ECLOSE'); + + assert.strictEqual(endCount, 1); + + // Ensure no additional `end` event is emitted afterwards. + setImmediate(() => { + assert.strictEqual(endCount, 1); + + done(); + }); + }); + }); +});