Skip to content

Commit 0b45b9f

Browse files
committed
tls
1 parent 8adbc16 commit 0b45b9f

File tree

4 files changed

+113
-45
lines changed

4 files changed

+113
-45
lines changed

apps/browser-proxy/package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
"dependencies": {
1010
"@aws-sdk/client-s3": "^3.645.0",
1111
"debug": "^4.3.7",
12+
"expiry-map": "^2.0.0",
1213
"findhit-proxywrap": "^0.3.13",
13-
"pg-gateway": "0.3.0-alpha.7",
14+
"p-memoize": "^7.1.1",
15+
"pg-gateway": "^0.3.0-alpha.7",
1416
"ws": "^8.18.0"
1517
},
1618
"devDependencies": {

apps/browser-proxy/src/index.ts

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,35 @@
11
import * as nodeNet from 'node:net'
22
import * as https from 'node:https'
3-
import { PostgresConnection } from 'pg-gateway'
3+
import { BackendError, PostgresConnection } from 'pg-gateway'
4+
import { fromNodeSocket } from 'pg-gateway/node'
45
import { WebSocketServer, type WebSocket } from 'ws'
56
import makeDebug from 'debug'
67
import * as tls from 'node:tls'
78
import { extractDatabaseId, isValidServername } from './servername.ts'
8-
import { getTls } from './tls.ts'
9+
import { getTls, setSecureContext } from './tls.ts'
910
import { createStartupMessage } from './create-message.ts'
1011
import { extractIP } from './extract-ip.ts'
1112

1213
const debug = makeDebug('browser-proxy')
1314

14-
const tcpConnections = new Map<string, nodeNet.Socket>()
15+
const tcpConnections = new Map<string, PostgresConnection>()
1516
const websocketConnections = new Map<string, WebSocket>()
1617

17-
let tlsOptions = await getTls()
18-
19-
// refresh the TLS certificate every week
20-
setInterval(
21-
async () => {
22-
tlsOptions = await getTls()
23-
httpsServer.setSecureContext(tlsOptions)
24-
},
25-
1000 * 60 * 60 * 24 * 7
26-
)
27-
2818
const httpsServer = https.createServer({
29-
...tlsOptions,
3019
SNICallback: (servername, callback) => {
3120
debug('SNICallback', servername)
3221
if (isValidServername(servername)) {
3322
debug('SNICallback', 'valid')
34-
callback(null, tls.createSecureContext(tlsOptions))
23+
callback(null)
3524
} else {
3625
debug('SNICallback', 'invalid')
3726
callback(new Error('invalid SNI'))
3827
}
3928
},
4029
})
30+
await setSecureContext(httpsServer)
31+
// reset the secure context every week to pick up any new TLS certificates
32+
setInterval(() => setSecureContext(httpsServer), 1000 * 60 * 60 * 24 * 7)
4133

4234
const websocketServer = new WebSocketServer({
4335
server: httpsServer,
@@ -70,8 +62,8 @@ websocketServer.on('connection', (socket, request) => {
7062

7163
socket.on('message', (data: Buffer) => {
7264
debug('websocket message', data.toString('hex'))
73-
const tcpSocket = tcpConnections.get(databaseId)
74-
tcpSocket?.write(data)
65+
const tcpConnection = tcpConnections.get(databaseId)
66+
tcpConnection?.streamWriter?.write(data)
7567
})
7668

7769
socket.on('close', () => {
@@ -86,50 +78,41 @@ const net = (
8678

8779
const tcpServer = net.createServer()
8880

89-
tcpServer.on('connection', (socket) => {
81+
tcpServer.on('connection', async (socket) => {
9082
let databaseId: string | undefined
9183

92-
const connection = new PostgresConnection(socket, {
93-
tls: tlsOptions,
84+
const connection = await fromNodeSocket(socket, {
85+
tls: getTls,
9486
onTlsUpgrade(state) {
95-
if (!state.tlsInfo?.sniServerName || !isValidServername(state.tlsInfo.sniServerName)) {
96-
// connection.detach()
97-
connection.sendError({
87+
if (!state.tlsInfo?.serverName || !isValidServername(state.tlsInfo.serverName)) {
88+
throw BackendError.create({
9889
code: '08006',
9990
message: 'invalid SNI',
10091
severity: 'FATAL',
10192
})
102-
connection.end()
103-
return
10493
}
10594

106-
const _databaseId = extractDatabaseId(state.tlsInfo.sniServerName!)
95+
const _databaseId = extractDatabaseId(state.tlsInfo.serverName!)
10796

10897
if (!websocketConnections.has(_databaseId!)) {
109-
// connection.detach()
110-
connection.sendError({
98+
throw BackendError.create({
11199
code: 'XX000',
112100
message: 'the browser is not sharing the database',
113101
severity: 'FATAL',
114102
})
115-
connection.end()
116-
return
117103
}
118104

119105
if (tcpConnections.has(_databaseId)) {
120-
// connection.detach()
121-
connection.sendError({
106+
throw BackendError.create({
122107
code: '53300',
123108
message: 'sorry, too many clients already',
124109
severity: 'FATAL',
125110
})
126-
connection.end()
127-
return
128111
}
129112

130113
// only set the databaseId after we've verified the connection
131114
databaseId = _databaseId
132-
tcpConnections.set(databaseId!, connection.socket)
115+
tcpConnections.set(databaseId!, connection)
133116
},
134117
serverVersion() {
135118
return '16.3'
@@ -138,13 +121,11 @@ tcpServer.on('connection', (socket) => {
138121
const websocket = websocketConnections.get(databaseId!)
139122

140123
if (!websocket) {
141-
connection.sendError({
124+
throw BackendError.create({
142125
code: 'XX000',
143126
message: 'the browser is not sharing the database',
144127
severity: 'FATAL',
145128
})
146-
connection.end()
147-
return
148129
}
149130

150131
const clientIpMessage = createStartupMessage('postgres', 'postgres', {
@@ -160,13 +141,11 @@ tcpServer.on('connection', (socket) => {
160141
const websocket = websocketConnections.get(databaseId!)
161142

162143
if (!websocket) {
163-
connection.sendError({
144+
throw BackendError.create({
164145
code: 'XX000',
165146
message: 'the browser is not sharing the database',
166147
severity: 'FATAL',
167148
})
168-
connection.end()
169-
return
170149
}
171150

172151
debug('tcp message', { message })

apps/browser-proxy/src/tls.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import { Buffer } from 'node:buffer'
22
import { GetObjectCommand, S3Client } from '@aws-sdk/client-s3'
3+
import pMemoize from 'p-memoize'
4+
import ExpiryMap from 'expiry-map'
5+
import type { Server } from 'node:https'
36

47
const s3Client = new S3Client({ forcePathStyle: true })
58

6-
export async function getTls() {
9+
async function _getTls() {
710
const cert = await s3Client
811
.send(
912
new GetObjectCommand({
@@ -31,3 +34,12 @@ export async function getTls() {
3134
key: Buffer.from(key),
3235
}
3336
}
37+
38+
// cache the TLS certificate for 1 week
39+
const cache = new ExpiryMap(1000 * 60 * 60 * 24 * 7)
40+
export const getTls = pMemoize(_getTls, { cache })
41+
42+
export async function setSecureContext(httpsServer: Server) {
43+
const tlsOptions = await getTls()
44+
httpsServer.setSecureContext(tlsOptions)
45+
}

package-lock.json

Lines changed: 76 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)