diff --git a/src/lib/client.ts b/src/lib/client.ts index 6eb6dbcb1..b9b98c109 100644 --- a/src/lib/client.ts +++ b/src/lib/client.ts @@ -235,7 +235,7 @@ export interface IClientOptions extends ISecureClientOptions { url: string, options: IClientOptions, client: MqttClient, - ) => string + ) => string | Promise /** when defined this function will be called to create the Websocket instance, used to add custom protocols or websocket implementations */ createWebsocket?: ( diff --git a/src/lib/connect/ali.ts b/src/lib/connect/ali.ts index 1c275cf46..c2fd695d1 100644 --- a/src/lib/connect/ali.ts +++ b/src/lib/connect/ali.ts @@ -54,7 +54,7 @@ function buildUrl(opts: IClientOptions, client: MqttClient) { url = `${protocol}://${opts.hostname}:${opts.port}${opts.path}` } if (typeof opts.transformWsUrl === 'function') { - url = opts.transformWsUrl(url, opts, client) + return opts.transformWsUrl(url, opts, client) } return url } @@ -111,18 +111,26 @@ const buildStream: StreamBuilder = (client, opts): IStream => { setDefaultOpts(opts) - const url = buildUrl(opts, client) my = opts.my - // https://miniprogram.alipay.com/docs/miniprogram/mpdev/api_network_connectsocket - my.connectSocket({ - url, - protocols: websocketSubProtocol, - }) - proxy = buildProxy() stream = new BufferedDuplex(opts, proxy, my) - bindEventHandler() + const urlOrPromise = buildUrl(opts, client) + + const connectSocket = (url: string) => { + // https://miniprogram.alipay.com/docs/miniprogram/mpdev/api_network_connectsocket + my.connectSocket({ + url, + protocols: websocketSubProtocol, + }) + bindEventHandler() + } + + if (urlOrPromise instanceof Promise) { + urlOrPromise.then(connectSocket).catch((err) => stream.destroy(err)) + } else { + connectSocket(urlOrPromise) + } return stream } diff --git a/src/lib/connect/ws.ts b/src/lib/connect/ws.ts index f5af9830a..f9ed9b39c 100644 --- a/src/lib/connect/ws.ts +++ b/src/lib/connect/ws.ts @@ -1,7 +1,7 @@ import { Buffer } from 'buffer' import Ws, { type ClientOptions } from 'ws' import _debug from 'debug' -import { Transform } from 'readable-stream' +import { Duplex, Transform } from 'readable-stream' import { type IStream, type StreamBuilder } from '../shared' import isBrowser from '../is-browser' import { type IClientOptions } from '../client' @@ -20,9 +20,10 @@ const WSS_OPTIONS = [ ] function buildUrl(opts: IClientOptions, client: MqttClient) { - let url = `${opts.protocol}://${opts.hostname}:${opts.port}${opts.path}` + const url = `${opts.protocol}://${opts.hostname}:${opts.port}${opts.path}` if (typeof opts.transformWsUrl === 'function') { - url = opts.transformWsUrl(url, opts, client) + const result = opts.transformWsUrl(url, opts, client) + return result } return url } @@ -127,7 +128,20 @@ function createBrowserWebSocket(client: MqttClient, opts: IClientOptions) { ? 'mqttv3.1' : 'mqtt' - const url = buildUrl(opts, client) + const urlOrPromise = buildUrl(opts, client) + if (urlOrPromise instanceof Promise) { + return urlOrPromise.then((url) => { + let socket: WebSocket + if (opts.createWebsocket) { + socket = opts.createWebsocket(url, [websocketSubProtocol], opts) + } else { + socket = new WebSocket(url, [websocketSubProtocol]) + } + socket.binaryType = 'arraybuffer' + return socket + }) + } + const url = urlOrPromise let socket: WebSocket if (opts.createWebsocket) { socket = opts.createWebsocket(url, [websocketSubProtocol], opts) @@ -144,16 +158,87 @@ const streamBuilder: StreamBuilder = (client, opts): IStream => { options.hostname = options.hostname || options.host || 'localhost' - const url = buildUrl(options, client) - const socket = createWebSocket(client, url, options) - // @ts-expect-error - This is a type confusion because of the overlap between browser oriented code and Node.js oriented code. - const webSocketStream = Ws.createWebSocketStream(socket, options.wsOptions) + const urlOrPromise = buildUrl(options, client) + + if (typeof urlOrPromise === 'string') { + const url = urlOrPromise + const socket = createWebSocket(client, url, options) + const webSocketStream = Ws.createWebSocketStream( + socket, + + options.wsOptions as any, + ) + + webSocketStream['url'] = url + socket.on('close', () => { + webSocketStream.destroy() + }) + return webSocketStream + } - webSocketStream['url'] = url - socket.on('close', () => { - webSocketStream.destroy() + // async case: buffer data until the URL promise resolves then create the WebSocket + const writeQueue: Array<{ + chunk: any + encoding: string + cb: (err?: Error) => void + }> = [] + let wsStream: ReturnType | null = null + let deferredDestroyed = false + + const deferredStream = new Duplex({ + read() { + // push model - data is pushed from wsStream events once + // the WebSocket connection is established + }, + write(chunk, encoding, cb) { + if (wsStream) { + wsStream.write(chunk, encoding, cb) + } else { + writeQueue.push({ chunk, encoding, cb }) + } + }, + destroy(err, cb) { + deferredDestroyed = true + if (wsStream) wsStream.destroy(err) + cb(err) + }, }) - return webSocketStream + + urlOrPromise + .then((url) => { + if (deferredDestroyed) return + const socket = createWebSocket(client, url, options) + + wsStream = Ws.createWebSocketStream( + socket, + options.wsOptions as any, + ) + deferredStream['url'] = url + + wsStream.on('data', (chunk) => { + deferredStream.push(chunk) + }) + wsStream.on('end', () => { + deferredStream.push(null) + }) + wsStream.on('error', (err) => { + deferredStream.destroy(err) + }) + socket.on('close', () => { + wsStream!.destroy() + }) + + // flush buffered writes + const queue = writeQueue.splice(0) + for (const { chunk, encoding, cb } of queue) { + wsStream.write(chunk, encoding as BufferEncoding, cb) + } + }) + .catch((err) => { + deferredStream.destroy(err) + }) + + return deferredStream } /* istanbul ignore next */ @@ -168,8 +253,8 @@ const browserStreamBuilder: StreamBuilder = (client, opts) => { const coerceToBuffer = !opts.objectMode - // the websocket connection - const socket = createBrowserWebSocket(client, opts) + // mutable socket reference - set once the socket is available + let socketRef: WebSocket | null = null // the proxy is a transform stream that forwards data to the socket // it ensures data written to socket is a Buffer @@ -178,36 +263,9 @@ const browserStreamBuilder: StreamBuilder = (client, opts) => { if (!opts.objectMode) { proxy._writev = writev.bind(proxy) } - proxy.on('close', () => { - socket.close() - }) - - const eventListenerSupport = typeof socket.addEventListener !== 'undefined' - // was already open when passed in - if (socket.readyState === socket.OPEN) { - stream = proxy - stream.socket = socket - } else { - // socket is not open. Use this to buffer writes until it is opened - stream = new BufferedDuplex(opts, proxy, socket) - - if (eventListenerSupport) { - socket.addEventListener('open', onOpen) - } else { - socket.onopen = onOpen - } - } - - if (eventListenerSupport) { - socket.addEventListener('close', onClose) - socket.addEventListener('error', onError) - socket.addEventListener('message', onMessage) - } else { - socket.onclose = onClose - socket.onerror = onError - socket.onmessage = onMessage - } + // the websocket connection (may be a Promise if transformWsUrl is async) + const socketOrPromise = createBrowserWebSocket(client, opts) // methods for browserStreamBuilder @@ -226,6 +284,86 @@ const browserStreamBuilder: StreamBuilder = (client, opts) => { return _proxy } + function attachSocketHandlers(socket: WebSocket) { + socketRef = socket + proxy.on('close', () => { + socket.close() + }) + + const eventListenerSupport = + typeof socket.addEventListener !== 'undefined' + + // was already open when passed in + if (socket.readyState === socket.OPEN) { + stream = proxy + stream.socket = socket + } else { + // socket is not open. Use this to buffer writes until it is opened + stream = new BufferedDuplex(opts, proxy, socket) + + if (eventListenerSupport) { + socket.addEventListener('open', onOpen) + } else { + socket.onopen = onOpen + } + } + + if (eventListenerSupport) { + socket.addEventListener('close', onClose) + socket.addEventListener('error', onError) + socket.addEventListener('message', onMessage) + } else { + socket.onclose = onClose + socket.onerror = onError + socket.onmessage = onMessage + } + } + + if (socketOrPromise instanceof Promise) { + // async case: create a BufferedDuplex immediately to buffer writes, + // then wire up the real socket once the URL promise resolves. + // Note: BufferedDuplex only stores the socket reference and does not + // call any methods on it, so an empty placeholder is safe here. + const placeholderSocket = { + close() {}, + } as unknown as WebSocket + stream = new BufferedDuplex(opts, proxy, placeholderSocket) + socketOrPromise + .then((socket) => { + socketRef = socket + ;(stream as BufferedDuplex).socket = socket + + const eventListenerSupport = + typeof socket.addEventListener !== 'undefined' + + if (eventListenerSupport) { + socket.addEventListener('open', onOpen) + socket.addEventListener('close', onClose) + socket.addEventListener('error', onError) + socket.addEventListener('message', onMessage) + } else { + socket.onopen = onOpen + socket.onclose = onClose + socket.onerror = onError + socket.onmessage = onMessage + } + + if (socket.readyState === socket.OPEN) { + onOpen() + } + + // wire up proxy close to close the real socket + proxy.on('close', () => { + socket.close() + }) + }) + .catch((err) => { + stream.destroy(err) + }) + } else { + attachSocketHandlers(socketOrPromise) + } + function onOpen() { debug('WebSocket onOpen') if (stream instanceof BufferedDuplex) { @@ -272,7 +410,11 @@ const browserStreamBuilder: StreamBuilder = (client, opts) => { enc: string, next: (err?: Error) => void, ) { - if (socket.bufferedAmount > bufferSize) { + if (!socketRef) { + next(new Error('WebSocket is not yet available')) + return + } + if (socketRef.bufferedAmount > bufferSize) { // throttle data until buffered amount is reduced. setTimeout(socketWriteBrowser, bufferTimeout, chunk, enc, next) return @@ -284,7 +426,7 @@ const browserStreamBuilder: StreamBuilder = (client, opts) => { try { // https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send (note this doesn't have a cb as second arg) - socket.send(chunk) + socketRef.send(chunk) } catch (err) { return next(err) } @@ -293,7 +435,9 @@ const browserStreamBuilder: StreamBuilder = (client, opts) => { } function socketEndBrowser(done: (error?: Error, data?: any) => void) { - socket.close() + if (socketRef) { + socketRef.close() + } done() } diff --git a/src/lib/connect/wx.ts b/src/lib/connect/wx.ts index b4b3ec1ed..4783430ea 100644 --- a/src/lib/connect/wx.ts +++ b/src/lib/connect/wx.ts @@ -58,7 +58,7 @@ function buildUrl(opts: IClientOptions, client: MqttClient) { url = `${protocol}://${opts.hostname}:${opts.port}${opts.path}` } if (typeof opts.transformWsUrl === 'function') { - url = opts.transformWsUrl(url, opts, client) + return opts.transformWsUrl(url, opts, client) } return url } @@ -102,39 +102,51 @@ const buildStream: StreamBuilder = (client, opts) => { setDefaultOpts(opts) - const url = buildUrl(opts, client) - // https://github.com/wechat-miniprogram/api-typings/blob/master/types/wx/lib.wx.api.d.ts#L20984 - socketTask = wx.connectSocket({ - url, - protocols: [websocketSubProtocol], - }) - proxy = buildProxy() - stream = new BufferedDuplex(opts, proxy, socketTask) - stream._destroy = (err, cb) => { - socketTask.close({ - success() { - if (cb) cb(err) - }, + // use a temporary placeholder for socketTask until the URL is resolved + stream = new BufferedDuplex(opts, proxy, null as any) + + const connectSocket = (url: string) => { + // https://github.com/wechat-miniprogram/api-typings/blob/master/types/wx/lib.wx.api.d.ts#L20984 + socketTask = wx.connectSocket({ + url, + protocols: [websocketSubProtocol], }) - } + stream.socket = socketTask - const destroyRef = stream.destroy - stream.destroy = (err, cb) => { - stream.destroy = destroyRef - - setTimeout(() => { + stream._destroy = (err, cb) => { socketTask.close({ - fail() { - stream._destroy(err, cb) + success() { + if (cb) cb(err) }, }) - }, 0) + } + + const destroyRef = stream.destroy + stream.destroy = (err, cb) => { + stream.destroy = destroyRef + + setTimeout(() => { + socketTask.close({ + fail() { + stream._destroy(err, cb) + }, + }) + }, 0) - return stream + return stream + } + + bindEventHandler() } - bindEventHandler() + const urlOrPromise = buildUrl(opts, client) + + if (urlOrPromise instanceof Promise) { + urlOrPromise.then(connectSocket).catch((err) => stream.destroy(err)) + } else { + connectSocket(urlOrPromise) + } return stream } diff --git a/test/node/websocket_client.ts b/test/node/websocket_client.ts index c38c5dbf0..69ebf0b4d 100644 --- a/test/node/websocket_client.ts +++ b/test/node/websocket_client.ts @@ -147,6 +147,32 @@ describe('Websocket Client', () => { }) }) + it('should be able to transform the url using an async function', function _test(t, done) { + const baseUrl = 'ws://localhost:9999/mqtt' + const sig = '?AUTH=token' + const expected = baseUrl + sig + let actual: string + const opts = makeOptions({ + path: '/mqtt', + async transformWsUrl(url, opt, client) { + assert.equal(url, baseUrl) + assert.strictEqual(opt, opts) + assert.strictEqual(client.options, opts) + assert(client instanceof mqtt.MqttClient) + actual = url + sig + return actual + }, + }) + const client = mqtt.connect(opts) + + client.on('connect', () => { + // `url` is set in `connect/ws.ts` `streamBuilder` + assert.equal((client.stream as any).url, expected) + assert.equal(actual, expected) + client.end(true, (err) => done(err)) + }) + }) + it('should be able to create custom Websocket instance', function _test(t, done) { const baseUrl = 'ws://localhost:9999/mqtt' let urlInCallback: string