Skip to content

Commit c47a434

Browse files
authored
Merge pull request #103 from supabase-community/feat/db-sharing-connectionid
Database sharing enhancements
2 parents 294f9c8 + 62a4bbc commit c47a434

22 files changed

+1054
-367
lines changed

apps/browser-proxy/.env.example

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ AWS_REGION=us-east-1
66
LOGFLARE_SOURCE_URL="<logflare-source-url>"
77
# enable PROXY protocol support
88
#PROXIED=true
9-
WILDCARD_DOMAIN=browser.staging.db.build
9+
SUPABASE_URL="<supabase-url>"
10+
SUPABASE_ANON_KEY="<supabase-anon-key>"
11+
WILDCARD_DOMAIN=browser.staging.db.build

apps/browser-proxy/package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
},
99
"dependencies": {
1010
"@aws-sdk/client-s3": "^3.645.0",
11+
"@supabase/supabase-js": "^2.45.4",
1112
"debug": "^4.3.7",
1213
"expiry-map": "^2.0.0",
1314
"findhit-proxywrap": "^0.3.13",
15+
"nanoid": "^5.0.7",
1416
"p-memoize": "^7.1.1",
1517
"pg-gateway": "^0.3.0-beta.3",
1618
"ws": "^8.18.0"
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import type { PostgresConnection } from 'pg-gateway'
2+
import type { WebSocket } from 'ws'
3+
4+
type DatabaseId = string
5+
type ConnectionId = string
6+
7+
class ConnectionManager {
8+
private socketsByDatabase: Map<DatabaseId, ConnectionId> = new Map()
9+
private sockets: Map<ConnectionId, PostgresConnection> = new Map()
10+
private websockets: Map<DatabaseId, WebSocket> = new Map()
11+
12+
constructor() {}
13+
14+
public hasSocketForDatabase(databaseId: DatabaseId) {
15+
return this.socketsByDatabase.has(databaseId)
16+
}
17+
18+
public getSocket(connectionId: ConnectionId) {
19+
return this.sockets.get(connectionId)
20+
}
21+
22+
public setSocket(databaseId: DatabaseId, connectionId: ConnectionId, socket: PostgresConnection) {
23+
this.sockets.set(connectionId, socket)
24+
this.socketsByDatabase.set(databaseId, connectionId)
25+
}
26+
27+
public deleteSocketForDatabase(databaseId: DatabaseId) {
28+
const connectionId = this.socketsByDatabase.get(databaseId)
29+
this.socketsByDatabase.delete(databaseId)
30+
if (connectionId) {
31+
this.sockets.delete(connectionId)
32+
}
33+
}
34+
35+
public hasWebsocket(databaseId: DatabaseId) {
36+
return this.websockets.has(databaseId)
37+
}
38+
39+
public getWebsocket(databaseId: DatabaseId) {
40+
return this.websockets.get(databaseId)
41+
}
42+
43+
public setWebsocket(databaseId: DatabaseId, websocket: WebSocket) {
44+
this.websockets.set(databaseId, websocket)
45+
}
46+
47+
public deleteWebsocket(databaseId: DatabaseId) {
48+
this.websockets.delete(databaseId)
49+
this.deleteSocketForDatabase(databaseId)
50+
}
51+
}
52+
53+
export const connectionManager = new ConnectionManager()

apps/browser-proxy/src/create-message.ts

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ export function createStartupMessage(
22
user: string,
33
database: string,
44
additionalParams: Record<string, string> = {}
5-
): ArrayBuffer {
5+
): Uint8Array {
66
const encoder = new TextEncoder()
77

88
// Protocol version number (3.0)
@@ -22,9 +22,8 @@ export function createStartupMessage(
2222
}
2323
messageLength += 1 // Null terminator
2424

25-
const message = new ArrayBuffer(4 + messageLength)
26-
const view = new DataView(message)
27-
const uint8Array = new Uint8Array(message)
25+
const uint8Array = new Uint8Array(4 + messageLength)
26+
const view = new DataView(uint8Array.buffer)
2827

2928
let offset = 0
3029
view.setInt32(offset, messageLength + 4, false) // Total message length (including itself)
@@ -44,5 +43,13 @@ export function createStartupMessage(
4443

4544
uint8Array.set([0], offset) // Final null terminator
4645

47-
return message
46+
return uint8Array
47+
}
48+
49+
export function createTerminateMessage(): Uint8Array {
50+
const uint8Array = new Uint8Array(5)
51+
const view = new DataView(uint8Array.buffer)
52+
view.setUint8(0, 'X'.charCodeAt(0))
53+
view.setUint32(1, 4, false)
54+
return uint8Array
4855
}

apps/browser-proxy/src/debug.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import createDebug from 'debug'
2+
3+
createDebug.formatters.e = (fn) => fn()
4+
5+
export const debug = createDebug('browser-proxy')

apps/browser-proxy/src/index.ts

Lines changed: 25 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -1,179 +1,12 @@
1-
import * as nodeNet from 'node:net'
2-
import * as https from 'node:https'
3-
import { BackendError, PostgresConnection } from 'pg-gateway'
4-
import { fromNodeSocket } from 'pg-gateway/node'
5-
import { WebSocketServer, type WebSocket } from 'ws'
6-
import makeDebug from 'debug'
7-
import { extractDatabaseId, isValidServername } from './servername.ts'
8-
import { getTls, setSecureContext } from './tls.ts'
9-
import { createStartupMessage } from './create-message.ts'
10-
import { extractIP } from './extract-ip.ts'
11-
import {
12-
DatabaseShared,
13-
DatabaseUnshared,
14-
logEvent,
15-
UserConnected,
16-
UserDisconnected,
17-
} from './telemetry.ts'
1+
import { httpsServer } from './websocket-server.ts'
2+
import { tcpServer } from './tcp-server.ts'
183

19-
const debug = makeDebug('browser-proxy')
20-
21-
const tcpConnections = new Map<string, PostgresConnection>()
22-
const websocketConnections = new Map<string, WebSocket>()
23-
24-
const httpsServer = https.createServer({
25-
SNICallback: (servername, callback) => {
26-
debug('SNICallback', servername)
27-
if (isValidServername(servername)) {
28-
debug('SNICallback', 'valid')
29-
callback(null)
30-
} else {
31-
debug('SNICallback', 'invalid')
32-
callback(new Error('invalid SNI'))
33-
}
34-
},
4+
process.on('unhandledRejection', (reason, promise) => {
5+
console.error({ location: 'unhandledRejection', reason, promise })
356
})
36-
await setSecureContext(httpsServer)
37-
// reset the secure context every week to pick up any new TLS certificates
38-
setInterval(() => setSecureContext(httpsServer), 1000 * 60 * 60 * 24 * 7)
39-
40-
const websocketServer = new WebSocketServer({
41-
server: httpsServer,
42-
})
43-
44-
websocketServer.on('error', (error) => {
45-
debug('websocket server error', error)
46-
})
47-
48-
websocketServer.on('connection', (socket, request) => {
49-
debug('websocket connection')
50-
51-
const host = request.headers.host
52-
53-
if (!host) {
54-
debug('No host header present')
55-
socket.close()
56-
return
57-
}
587

59-
const databaseId = extractDatabaseId(host)
60-
61-
if (websocketConnections.has(databaseId)) {
62-
socket.send('sorry, too many clients already')
63-
socket.close()
64-
return
65-
}
66-
67-
websocketConnections.set(databaseId, socket)
68-
69-
logEvent(new DatabaseShared({ databaseId }))
70-
71-
socket.on('message', (data: Buffer) => {
72-
debug('websocket message', data.toString('hex'))
73-
const tcpConnection = tcpConnections.get(databaseId)
74-
tcpConnection?.streamWriter?.write(data)
75-
})
76-
77-
socket.on('close', () => {
78-
websocketConnections.delete(databaseId)
79-
logEvent(new DatabaseUnshared({ databaseId }))
80-
})
81-
})
82-
83-
// we need to use proxywrap to make our tcp server to enable the PROXY protocol support
84-
const net = (
85-
process.env.PROXIED ? (await import('findhit-proxywrap')).default.proxy(nodeNet) : nodeNet
86-
) as typeof nodeNet
87-
88-
const tcpServer = net.createServer()
89-
90-
tcpServer.on('connection', async (socket) => {
91-
let databaseId: string | undefined
92-
93-
const connection = await fromNodeSocket(socket, {
94-
tls: getTls,
95-
onTlsUpgrade(state) {
96-
if (!state.tlsInfo?.serverName || !isValidServername(state.tlsInfo.serverName)) {
97-
throw BackendError.create({
98-
code: '08006',
99-
message: 'invalid SNI',
100-
severity: 'FATAL',
101-
})
102-
}
103-
104-
const _databaseId = extractDatabaseId(state.tlsInfo.serverName!)
105-
106-
if (!websocketConnections.has(_databaseId!)) {
107-
throw BackendError.create({
108-
code: 'XX000',
109-
message: 'the browser is not sharing the database',
110-
severity: 'FATAL',
111-
})
112-
}
113-
114-
if (tcpConnections.has(_databaseId)) {
115-
throw BackendError.create({
116-
code: '53300',
117-
message: 'sorry, too many clients already',
118-
severity: 'FATAL',
119-
})
120-
}
121-
122-
// only set the databaseId after we've verified the connection
123-
databaseId = _databaseId
124-
tcpConnections.set(databaseId!, connection)
125-
logEvent(new UserConnected({ databaseId }))
126-
},
127-
serverVersion() {
128-
return '16.3'
129-
},
130-
onAuthenticated() {
131-
const websocket = websocketConnections.get(databaseId!)
132-
133-
if (!websocket) {
134-
throw BackendError.create({
135-
code: 'XX000',
136-
message: 'the browser is not sharing the database',
137-
severity: 'FATAL',
138-
})
139-
}
140-
141-
const clientIpMessage = createStartupMessage('postgres', 'postgres', {
142-
client_ip: extractIP(socket.remoteAddress!),
143-
})
144-
websocket.send(clientIpMessage)
145-
},
146-
onMessage(message, state) {
147-
if (!state.isAuthenticated) {
148-
return
149-
}
150-
151-
const websocket = websocketConnections.get(databaseId!)
152-
153-
if (!websocket) {
154-
throw BackendError.create({
155-
code: 'XX000',
156-
message: 'the browser is not sharing the database',
157-
severity: 'FATAL',
158-
})
159-
}
160-
161-
debug('tcp message', { message })
162-
websocket.send(message)
163-
164-
// return an empty buffer to indicate that the message has been handled
165-
return new Uint8Array()
166-
},
167-
})
168-
169-
socket.on('close', () => {
170-
if (databaseId) {
171-
tcpConnections.delete(databaseId)
172-
logEvent(new UserDisconnected({ databaseId }))
173-
const websocket = websocketConnections.get(databaseId)
174-
websocket?.send(createStartupMessage('postgres', 'postgres', { client_ip: '' }))
175-
}
176-
})
8+
process.on('uncaughtException', (error) => {
9+
console.error({ location: 'uncaughtException', error })
17710
})
17811

17912
httpsServer.listen(443, () => {
@@ -183,3 +16,22 @@ httpsServer.listen(443, () => {
18316
tcpServer.listen(5432, () => {
18417
console.log('tcp server listening on port 5432')
18518
})
19+
20+
const shutdown = async () => {
21+
await Promise.allSettled([
22+
new Promise<void>((res) =>
23+
httpsServer.close(() => {
24+
res()
25+
})
26+
),
27+
new Promise<void>((res) =>
28+
tcpServer.close(() => {
29+
res()
30+
})
31+
),
32+
])
33+
process.exit(0)
34+
}
35+
36+
process.on('SIGTERM', shutdown)
37+
process.on('SIGINT', shutdown)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export const VECTOR_OID = 99999
2+
export const FIRST_NORMAL_OID = 16384

0 commit comments

Comments
 (0)