Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,15 @@ $ sqlite3 bookstore.sqlite3 < examples/bookstore/data.sql
```
$ echo -n "topsecret" > test.token
$ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3
{"level":"info","ts":1672528510.825417,"logger":"db-server","caller":"sqlite-rest/server.go:121","msg":"server started","addr":":8080"}
{"level":"info","ts":1672528510.825417,"logger":"db-server","msg":"server started","addr":":8080"}
... <omitted logs>
```

### Start server with Unix domain socket

```
$ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3 --http-socket /tmp/sqlite-rest.sock
{"level":"info","ts":1672528510.825417,"logger":"db-server","msg":"server started","socket":"/tmp/sqlite-rest.sock"}
... <omitted logs>
```

Expand Down Expand Up @@ -178,4 +186,4 @@ $ sqlite-rest migrate --db-dsn ./bookstore.sqlite3 --direction down --step 1 ./e

## License

MIT
MIT
63 changes: 56 additions & 7 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"

Expand All @@ -27,6 +29,7 @@ const (
type ServerOptions struct {
Logger logr.Logger
Addr string
SocketPath string
AuthOptions ServerAuthOptions
SecurityOptions ServerSecurityOptions
Queryer sqlx.QueryerContext
Expand All @@ -35,6 +38,7 @@ type ServerOptions struct {

func (opts *ServerOptions) bindCLIFlags(fs *pflag.FlagSet) {
fs.StringVar(&opts.Addr, "http-addr", ":8080", "server listen address")
fs.StringVar(&opts.SocketPath, "http-socket", "", "server listen unix socket path. If set, http-addr will be ignored")

opts.AuthOptions.bindCLIFlags(fs)
opts.SecurityOptions.bindCLIFlags(fs)
Expand All @@ -52,7 +56,11 @@ func (opts *ServerOptions) defaults() error {
opts.Logger = logr.Discard()
}

if opts.Addr == "" {
if opts.SocketPath != "" {
opts.Addr = ""
}

if opts.Addr == "" && opts.SocketPath == "" {
opts.Addr = ":8080"
}

Expand All @@ -68,10 +76,12 @@ func (opts *ServerOptions) defaults() error {
}

type dbServer struct {
logger logr.Logger
server *http.Server
queryer sqlx.QueryerContext
execer sqlx.ExecerContext
logger logr.Logger
server *http.Server
listener net.Listener
socket string
queryer sqlx.QueryerContext
execer sqlx.ExecerContext
}

func NewServer(opts *ServerOptions) (*dbServer, error) {
Expand All @@ -86,6 +96,7 @@ func NewServer(opts *ServerOptions) (*dbServer, error) {
// TODO: make it configurable
ReadHeaderTimeout: 5 * time.Second,
},
socket: opts.SocketPath,
queryer: opts.Queryer,
execer: opts.Execer,
}
Expand Down Expand Up @@ -128,15 +139,53 @@ func NewServer(opts *ServerOptions) (*dbServer, error) {
}

func (server *dbServer) Start(done <-chan struct{}) {
go server.server.ListenAndServe()
if server.socket != "" {
sockDir := filepath.Dir(server.socket)
if sockDir != "" && sockDir != "." {
if err := os.MkdirAll(sockDir, 0755); err != nil {
server.logger.Error(err, "failed to ensure unix socket directory", "socket", server.socket)
return
}
}

if err := os.RemoveAll(server.socket); err != nil {
Copy link

Copilot AI Jan 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using RemoveAll to clean up a socket file is dangerous and could accidentally delete an entire directory if the socket path is misconfigured. Use os.Remove instead, which will fail safely if the path is a directory.

Suggested change
if err := os.RemoveAll(server.socket); err != nil {
if err := os.Remove(server.socket); err != nil && !errors.Is(err, os.ErrNotExist) {

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Jan 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using os.RemoveAll on a socket path could be dangerous if the path points to a directory. Consider using os.Remove instead, which will fail safely if the path is a directory. This prevents accidentally deleting an entire directory structure if misconfigured.

Suggested change
if err := os.RemoveAll(server.socket); err != nil {
if err := os.Remove(server.socket); err != nil && !errors.Is(err, os.ErrNotExist) {

Copilot uses AI. Check for mistakes.
server.logger.Error(err, "failed to remove stale unix socket", "socket", server.socket)
return
}

l, err := net.Listen("unix", server.socket)
Copy link

Copilot AI Jan 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The socket file permissions (0666 by default for Unix sockets) could be a security concern. Consider setting explicit restrictive permissions after creating the socket using os.Chmod to limit access to the socket file (e.g., 0600 for owner-only or 0660 for owner and group).

Copilot uses AI. Check for mistakes.
if err != nil {
server.logger.Error(err, "failed to listen on unix socket", "socket", server.socket)
return
}
server.listener = l

go server.server.Serve(l)
Copy link

Copilot AI Jan 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error returned by server.Serve is not captured or logged. While Serve is called in a goroutine, consider capturing and logging the error for better observability, especially to distinguish between expected shutdown (http.ErrServerClosed) and unexpected errors.

Copilot uses AI. Check for mistakes.
server.logger.Info("server started", "socket", server.socket)
} else {
l, err := net.Listen("tcp", server.server.Addr)
if err != nil {
server.logger.Error(err, "failed to listen on tcp address", "addr", server.server.Addr)
return
}
server.listener = l

go server.server.Serve(l)
Copy link

Copilot AI Jan 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error returned by server.Serve is not captured or logged. While Serve is called in a goroutine, consider capturing and logging the error for better observability, especially to distinguish between expected shutdown (http.ErrServerClosed) and unexpected errors.

Copilot uses AI. Check for mistakes.
server.logger.Info("server started", "addr", server.server.Addr)
}

server.logger.Info("server started", "addr", server.server.Addr)
<-done

server.logger.Info("shutting down server")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
Copy link

Copilot AI Jan 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The listener is not explicitly closed during shutdown. Although server.Shutdown will close the listener, explicitly closing it ensures proper cleanup and makes the code more maintainable. Add a check and close the listener before or after the Shutdown call.

Suggested change
defer cancel()
defer cancel()
if server.listener != nil {
if err := server.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
server.logger.Error(err, "failed to close listener")
}
}

Copilot uses AI. Check for mistakes.
server.server.Shutdown(shutdownCtx)

if server.socket != "" {
if err := os.Remove(server.socket); err != nil && !errors.Is(err, os.ErrNotExist) {
server.logger.Error(err, "failed to clean up unix socket", "socket", server.socket)
}
}
}

func (server *dbServer) responseHeader(w http.ResponseWriter, statusCode int) {
Expand Down
108 changes: 108 additions & 0 deletions server_socket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package main

import (
"context"
"encoding/json"
"errors"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"

"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
)

func TestServerWithUnixSocket(t *testing.T) {
t.Parallel()

dir := t.TempDir()
socketPath := filepath.Join(dir, "sqlite-rest.sock")

dbPath := filepath.Join(dir, "test.db")
db, err := sqlx.Open("sqlite3", dbPath)
if err != nil {
t.Fatal(err)
}
defer db.Close()

_, err = db.Exec("CREATE TABLE test (id int)")
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(`INSERT INTO test (id) VALUES (1)`)
if err != nil {
t.Fatal(err)
}

serverOpts := &ServerOptions{
Logger: createTestLogger(t).WithName("test"),
Queryer: db,
Execer: db,
SocketPath: socketPath,
}
serverOpts.AuthOptions.disableAuth = true
serverOpts.SecurityOptions.EnabledTableOrViews = []string{"test"}

server, err := NewServer(serverOpts)
if err != nil {
t.Fatal(err)
}

done := make(chan struct{})
serverDone := make(chan struct{})
go func() {
server.Start(done)
close(serverDone)
}()

assert.Eventually(t, func() bool {
_, err := os.Stat(socketPath)
return err == nil
}, 5*time.Second, 100*time.Millisecond)

client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
},
},
}

req, err := http.NewRequest(http.MethodGet, "http://unix/test", nil)
if err != nil {
t.Fatal(err)
}

resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()

assert.Equal(t, http.StatusOK, resp.StatusCode)

var rows []map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&rows)
assert.NoError(t, err)
assert.Len(t, rows, 1)
assert.EqualValues(t, 1, rows[0]["id"])

close(done)
assert.Eventually(t, func() bool {
select {
case <-serverDone:
return true
default:
return false
}
}, 2*time.Second, 50*time.Millisecond)

assert.Eventually(t, func() bool {
_, err := os.Stat(socketPath)
return errors.Is(err, os.ErrNotExist)
}, 5*time.Second, 100*time.Millisecond)
}