Skip to content

Commit 62a4bbc

Browse files
authored
Merge pull request #105 from supabase-community/feat/db-sharing-pg-dump
Fix pg_dump for databases created with PGlite < v0.2.9
2 parents 58fe271 + 68d37be commit 62a4bbc

File tree

10 files changed

+609
-146
lines changed

10 files changed

+609
-146
lines changed

apps/browser-proxy/src/index.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,22 @@ httpsServer.listen(443, () => {
1616
tcpServer.listen(5432, () => {
1717
console.log('tcp server listening on port 5432')
1818
})
19+
20+
const shutdown = async () => {
21+
await Promise.allSettled([
22+
new Promise<void>((res) =>
23+
httpsServer.close(() => {
24+
res()
25+
})
26+
),
27+
new Promise<void>((res) =>
28+
tcpServer.close(() => {
29+
res()
30+
})
31+
),
32+
])
33+
process.exit(0)
34+
}
35+
36+
process.on('SIGTERM', shutdown)
37+
process.on('SIGINT', shutdown)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export const VECTOR_OID = 99999
2+
export const FIRST_NORMAL_OID = 16384
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import { VECTOR_OID } from './constants.ts'
2+
import { parseDataRowFields, parseRowDescription } from './utils.ts'
3+
4+
export function isGetExtensionMembershipQuery(message: Uint8Array): boolean {
5+
// Check if it's a SimpleQuery message (starts with 'Q')
6+
if (message[0] !== 0x51) {
7+
// 'Q' in ASCII
8+
return false
9+
}
10+
11+
const query =
12+
"SELECT classid, objid, refobjid FROM pg_depend WHERE refclassid = 'pg_extension'::regclass AND deptype = 'e' ORDER BY 3"
13+
14+
// Skip the message type (1 byte) and message length (4 bytes)
15+
const messageString = new TextDecoder().decode(message.slice(5))
16+
17+
// Trim any trailing null character
18+
const trimmedMessage = messageString.replace(/\0+$/, '')
19+
20+
// Check if the message exactly matches the query
21+
return trimmedMessage === query
22+
}
23+
24+
export function patchGetExtensionMembershipResult(data: Uint8Array, vectorOid: string): Uint8Array {
25+
let offset = 0
26+
const messages: Uint8Array[] = []
27+
let isDependencyTable = false
28+
let objidIndex = -1
29+
let refobjidIndex = -1
30+
let patchedRowCount = 0
31+
let totalRowsProcessed = 0
32+
33+
const expectedColumns = ['classid', 'objid', 'refobjid']
34+
35+
while (offset < data.length) {
36+
const messageType = data[offset]
37+
const messageLength = new DataView(data.buffer, data.byteOffset + offset + 1, 4).getUint32(
38+
0,
39+
false
40+
)
41+
const message = data.subarray(offset, offset + messageLength + 1)
42+
43+
if (messageType === 0x54) {
44+
// RowDescription
45+
const columnNames = parseRowDescription(message)
46+
isDependencyTable =
47+
columnNames.length === 3 && columnNames.every((col) => expectedColumns.includes(col))
48+
if (isDependencyTable) {
49+
objidIndex = columnNames.indexOf('objid')
50+
refobjidIndex = columnNames.indexOf('refobjid')
51+
}
52+
} else if (messageType === 0x44 && isDependencyTable) {
53+
// DataRow
54+
const fields = parseDataRowFields(message)
55+
totalRowsProcessed++
56+
57+
if (fields.length === 3) {
58+
const refobjid = fields[refobjidIndex]!.value
59+
60+
if (refobjid === vectorOid) {
61+
const patchedMessage = patchDependencyRow(message, refobjidIndex)
62+
messages.push(patchedMessage)
63+
patchedRowCount++
64+
offset += messageLength + 1
65+
continue
66+
}
67+
}
68+
}
69+
70+
messages.push(message)
71+
offset += messageLength + 1
72+
}
73+
74+
return new Uint8Array(
75+
messages.reduce((acc, val) => {
76+
const combined = new Uint8Array(acc.length + val.length)
77+
combined.set(acc)
78+
combined.set(val, acc.length)
79+
return combined
80+
}, new Uint8Array())
81+
)
82+
}
83+
84+
function patchDependencyRow(message: Uint8Array, refobjidIndex: number): Uint8Array {
85+
const newArray = new Uint8Array(message)
86+
let offset = 7 // Start after message type (1 byte), message length (4 bytes), and field count (2 bytes)
87+
88+
// Navigate to the refobjid field
89+
for (let i = 0; i < refobjidIndex; i++) {
90+
const fieldLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
91+
offset += 4 // Skip the length field
92+
if (fieldLength > 0) {
93+
offset += fieldLength // Skip the field value
94+
}
95+
}
96+
97+
// Now we're at the start of the refobjid field
98+
const refobjidLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
99+
offset += 4 // Move past the length field
100+
101+
const encoder = new TextEncoder()
102+
103+
// Write the new OID value
104+
const newRefobjidBytes = encoder.encode(VECTOR_OID.toString().padStart(refobjidLength, '0'))
105+
newArray.set(newRefobjidBytes, offset)
106+
107+
return newArray
108+
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import { VECTOR_OID } from './constants.ts'
2+
import { parseDataRowFields, parseRowDescription } from './utils.ts'
3+
4+
export function isGetExtensionsQuery(message: Uint8Array): boolean {
5+
// Check if it's a SimpleQuery message (starts with 'Q')
6+
if (message[0] !== 0x51) {
7+
// 'Q' in ASCII
8+
return false
9+
}
10+
11+
const query =
12+
'SELECT x.tableoid, x.oid, x.extname, n.nspname, x.extrelocatable, x.extversion, x.extconfig, x.extcondition FROM pg_extension x JOIN pg_namespace n ON n.oid = x.extnamespace'
13+
14+
// Skip the message type (1 byte) and message length (4 bytes)
15+
const messageString = new TextDecoder().decode(message.slice(5))
16+
17+
// Trim any trailing null character
18+
const trimmedMessage = messageString.replace(/\0+$/, '')
19+
20+
// Check if the message exactly matches the query
21+
return trimmedMessage === query
22+
}
23+
24+
export function patchGetExtensionsResult(data: Uint8Array) {
25+
let offset = 0
26+
const messages: Uint8Array[] = []
27+
let isVectorExtensionTable = false
28+
let oidColumnIndex = -1
29+
let extnameColumnIndex = -1
30+
let vectorOid: string | null = null
31+
32+
const expectedColumns = [
33+
'tableoid',
34+
'oid',
35+
'extname',
36+
'nspname',
37+
'extrelocatable',
38+
'extversion',
39+
'extconfig',
40+
'extcondition',
41+
]
42+
43+
while (offset < data.length) {
44+
const messageType = data[offset]
45+
const messageLength = new DataView(data.buffer, data.byteOffset + offset + 1, 4).getUint32(
46+
0,
47+
false
48+
)
49+
50+
const message = data.subarray(offset, offset + messageLength + 1)
51+
52+
if (messageType === 0x54) {
53+
// RowDescription
54+
const columnNames = parseRowDescription(message)
55+
56+
isVectorExtensionTable =
57+
columnNames.length === expectedColumns.length &&
58+
columnNames.every((col) => expectedColumns.includes(col))
59+
60+
if (isVectorExtensionTable) {
61+
oidColumnIndex = columnNames.indexOf('oid')
62+
extnameColumnIndex = columnNames.indexOf('extname')
63+
}
64+
} else if (messageType === 0x44 && isVectorExtensionTable) {
65+
// DataRow
66+
const fields = parseDataRowFields(message)
67+
if (fields[extnameColumnIndex]?.value === 'vector') {
68+
vectorOid = fields[oidColumnIndex]!.value!
69+
const patchedMessage = patchOidField(message, oidColumnIndex, fields)
70+
messages.push(patchedMessage)
71+
offset += messageLength + 1
72+
continue
73+
}
74+
}
75+
76+
messages.push(message)
77+
offset += messageLength + 1
78+
}
79+
80+
return {
81+
message: Buffer.concat(messages),
82+
vectorOid,
83+
}
84+
}
85+
86+
function patchOidField(
87+
message: Uint8Array,
88+
oidIndex: number,
89+
fields: { value: string | null; length: number }[]
90+
): Uint8Array {
91+
const oldOidField = fields[oidIndex]!
92+
const newOid = VECTOR_OID.toString().padStart(oldOidField.length, '0')
93+
94+
const newArray = new Uint8Array(message)
95+
96+
let offset = 7 // Start after message type (1 byte), message length (4 bytes), and field count (2 bytes)
97+
98+
// Navigate to the OID field
99+
for (let i = 0; i < oidIndex; i++) {
100+
const fieldLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
101+
offset += 4 // Skip the length field
102+
if (fieldLength > 0) {
103+
offset += fieldLength // Skip the field value
104+
}
105+
}
106+
107+
// Now we're at the start of the OID field
108+
const oidLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
109+
offset += 4 // Move past the length field
110+
111+
// Ensure the new OID fits in the allocated space
112+
if (newOid.length !== oidLength) {
113+
console.warn(
114+
`New OID length (${newOid.length}) doesn't match the original length (${oidLength}). Skipping patch.`
115+
)
116+
return message
117+
}
118+
119+
// Write the new OID value
120+
for (let i = 0; i < oidLength; i++) {
121+
newArray[offset + i] = newOid.charCodeAt(i)
122+
}
123+
124+
return newArray
125+
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import type { ClientParameters } from 'pg-gateway'
2+
import { isGetExtensionsQuery, patchGetExtensionsResult } from './get-extensions-query.ts'
3+
import {
4+
isGetExtensionMembershipQuery,
5+
patchGetExtensionMembershipResult,
6+
} from './get-extension-membership-query.ts'
7+
import { FIRST_NORMAL_OID } from './constants.ts'
8+
import type { Socket } from 'node:net'
9+
10+
type ConnectionId = string
11+
12+
type State =
13+
| { step: 'wait-for-get-extensions-query' }
14+
| { step: 'get-extensions-query-received' }
15+
| { step: 'wait-for-get-extension-membership-query'; vectorOid: string }
16+
| { step: 'get-extension-membership-query-received'; vectorOid: string }
17+
| { step: 'complete' }
18+
19+
/**
20+
* Middleware to patch pg_dump results for PGlite < v0.2.8
21+
* PGlite < v0.2.8 has a bug in which userland extensions are not dumped because their oid is lower than FIRST_NORMAL_OID
22+
* This middleware patches the results of the get_extensions and get_extension_membership queries to increase the oid of the `vector` extension so it can be dumped
23+
* For more context, see: https://github.com/electric-sql/pglite/issues/352
24+
*/
25+
class PgDumpMiddleware {
26+
private state: Map<ConnectionId, State> = new Map()
27+
28+
constructor() {}
29+
30+
client(
31+
socket: Socket,
32+
connectionId: string,
33+
context: {
34+
clientParams?: ClientParameters
35+
},
36+
message: Uint8Array
37+
) {
38+
if (context.clientParams?.application_name !== 'pg_dump') {
39+
return message
40+
}
41+
42+
if (!this.state.has(connectionId)) {
43+
this.state.set(connectionId, { step: 'wait-for-get-extensions-query' })
44+
socket.on('close', () => {
45+
this.state.delete(connectionId)
46+
})
47+
}
48+
49+
const connectionState = this.state.get(connectionId)!
50+
51+
switch (connectionState.step) {
52+
case 'wait-for-get-extensions-query':
53+
// https://github.com/postgres/postgres/blob/a19f83f87966f763991cc76404f8e42a36e7e842/src/bin/pg_dump/pg_dump.c#L5834-L5837
54+
if (isGetExtensionsQuery(message)) {
55+
this.state.set(connectionId, { step: 'get-extensions-query-received' })
56+
}
57+
break
58+
case 'wait-for-get-extension-membership-query':
59+
// https://github.com/postgres/postgres/blob/a19f83f87966f763991cc76404f8e42a36e7e842/src/bin/pg_dump/pg_dump.c#L18173-L18178
60+
if (isGetExtensionMembershipQuery(message)) {
61+
this.state.set(connectionId, {
62+
step: 'get-extension-membership-query-received',
63+
vectorOid: connectionState.vectorOid,
64+
})
65+
}
66+
break
67+
}
68+
69+
return message
70+
}
71+
72+
server(
73+
connectionId: string,
74+
context: {
75+
clientParams?: ClientParameters
76+
},
77+
message: Uint8Array
78+
) {
79+
if (context.clientParams?.application_name !== 'pg_dump' || !this.state.has(connectionId)) {
80+
return message
81+
}
82+
83+
const connectionState = this.state.get(connectionId)!
84+
85+
switch (connectionState.step) {
86+
case 'get-extensions-query-received':
87+
const patched = patchGetExtensionsResult(message)
88+
if (patched.vectorOid) {
89+
if (parseInt(patched.vectorOid) >= FIRST_NORMAL_OID) {
90+
this.state.set(connectionId, {
91+
step: 'complete',
92+
})
93+
} else {
94+
this.state.set(connectionId, {
95+
step: 'wait-for-get-extension-membership-query',
96+
vectorOid: patched.vectorOid,
97+
})
98+
}
99+
}
100+
return patched.message
101+
case 'get-extension-membership-query-received':
102+
const patchedMessage = patchGetExtensionMembershipResult(message, connectionState.vectorOid)
103+
this.state.set(connectionId, { step: 'complete' })
104+
return patchedMessage
105+
default:
106+
return message
107+
}
108+
}
109+
}
110+
111+
export const pgDumpMiddleware = new PgDumpMiddleware()

0 commit comments

Comments
 (0)