diff --git a/README.md b/README.md index 2ef0ef4..53279e5 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,15 @@ $ sqlite3 bookstore.sqlite3 < examples/bookstore/data.sql ``` $ echo -n "topsecret" > test.token $ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3 -{"level":"info","ts":1672528510.825417,"logger":"db-server","caller":"sqlite-rest/server.go:121","msg":"server started","addr":":8080"} +{"level":"info","ts":1672528510.825417,"logger":"db-server","msg":"server started","addr":":8080"} +... +``` + +### Start server with Unix domain socket + +``` +$ sqlite-rest serve --auth-token-file test.token --security-allow-table books --db-dsn ./bookstore.sqlite3 --http-socket /tmp/sqlite-rest.sock +{"level":"info","ts":1672528510.825417,"logger":"db-server","msg":"server started","socket":"/tmp/sqlite-rest.sock"} ... ``` @@ -178,4 +186,4 @@ $ sqlite-rest migrate --db-dsn ./bookstore.sqlite3 --direction down --step 1 ./e ## License -MIT \ No newline at end of file +MIT diff --git a/server.go b/server.go index 6be2857..c10458d 100644 --- a/server.go +++ b/server.go @@ -5,9 +5,11 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/http" "os" "os/signal" + "path/filepath" "syscall" "time" @@ -27,6 +29,7 @@ const ( type ServerOptions struct { Logger logr.Logger Addr string + SocketPath string AuthOptions ServerAuthOptions SecurityOptions ServerSecurityOptions Queryer sqlx.QueryerContext @@ -35,6 +38,7 @@ type ServerOptions struct { func (opts *ServerOptions) bindCLIFlags(fs *pflag.FlagSet) { fs.StringVar(&opts.Addr, "http-addr", ":8080", "server listen address") + fs.StringVar(&opts.SocketPath, "http-socket", "", "server listen unix socket path. If set, http-addr will be ignored") opts.AuthOptions.bindCLIFlags(fs) opts.SecurityOptions.bindCLIFlags(fs) @@ -52,7 +56,11 @@ func (opts *ServerOptions) defaults() error { opts.Logger = logr.Discard() } - if opts.Addr == "" { + if opts.SocketPath != "" { + opts.Addr = "" + } + + if opts.Addr == "" && opts.SocketPath == "" { opts.Addr = ":8080" } @@ -68,10 +76,12 @@ func (opts *ServerOptions) defaults() error { } type dbServer struct { - logger logr.Logger - server *http.Server - queryer sqlx.QueryerContext - execer sqlx.ExecerContext + logger logr.Logger + server *http.Server + listener net.Listener + socket string + queryer sqlx.QueryerContext + execer sqlx.ExecerContext } func NewServer(opts *ServerOptions) (*dbServer, error) { @@ -86,6 +96,7 @@ func NewServer(opts *ServerOptions) (*dbServer, error) { // TODO: make it configurable ReadHeaderTimeout: 5 * time.Second, }, + socket: opts.SocketPath, queryer: opts.Queryer, execer: opts.Execer, } @@ -128,15 +139,53 @@ func NewServer(opts *ServerOptions) (*dbServer, error) { } func (server *dbServer) Start(done <-chan struct{}) { - go server.server.ListenAndServe() + if server.socket != "" { + sockDir := filepath.Dir(server.socket) + if sockDir != "" && sockDir != "." { + if err := os.MkdirAll(sockDir, 0755); err != nil { + server.logger.Error(err, "failed to ensure unix socket directory", "socket", server.socket) + return + } + } + + if err := os.RemoveAll(server.socket); err != nil { + server.logger.Error(err, "failed to remove stale unix socket", "socket", server.socket) + return + } + + l, err := net.Listen("unix", server.socket) + if err != nil { + server.logger.Error(err, "failed to listen on unix socket", "socket", server.socket) + return + } + server.listener = l + + go server.server.Serve(l) + server.logger.Info("server started", "socket", server.socket) + } else { + l, err := net.Listen("tcp", server.server.Addr) + if err != nil { + server.logger.Error(err, "failed to listen on tcp address", "addr", server.server.Addr) + return + } + server.listener = l + + go server.server.Serve(l) + server.logger.Info("server started", "addr", server.server.Addr) + } - server.logger.Info("server started", "addr", server.server.Addr) <-done server.logger.Info("shutting down server") shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() server.server.Shutdown(shutdownCtx) + + if server.socket != "" { + if err := os.Remove(server.socket); err != nil && !errors.Is(err, os.ErrNotExist) { + server.logger.Error(err, "failed to clean up unix socket", "socket", server.socket) + } + } } func (server *dbServer) responseHeader(w http.ResponseWriter, statusCode int) { diff --git a/server_socket_test.go b/server_socket_test.go new file mode 100644 index 0000000..bbc5c9e --- /dev/null +++ b/server_socket_test.go @@ -0,0 +1,108 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/assert" +) + +func TestServerWithUnixSocket(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + socketPath := filepath.Join(dir, "sqlite-rest.sock") + + dbPath := filepath.Join(dir, "test.db") + db, err := sqlx.Open("sqlite3", dbPath) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE test (id int)") + if err != nil { + t.Fatal(err) + } + _, err = db.Exec(`INSERT INTO test (id) VALUES (1)`) + if err != nil { + t.Fatal(err) + } + + serverOpts := &ServerOptions{ + Logger: createTestLogger(t).WithName("test"), + Queryer: db, + Execer: db, + SocketPath: socketPath, + } + serverOpts.AuthOptions.disableAuth = true + serverOpts.SecurityOptions.EnabledTableOrViews = []string{"test"} + + server, err := NewServer(serverOpts) + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + serverDone := make(chan struct{}) + go func() { + server.Start(done) + close(serverDone) + }() + + assert.Eventually(t, func() bool { + _, err := os.Stat(socketPath) + return err == nil + }, 5*time.Second, 100*time.Millisecond) + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) + }, + }, + } + + req, err := http.NewRequest(http.MethodGet, "http://unix/test", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var rows []map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&rows) + assert.NoError(t, err) + assert.Len(t, rows, 1) + assert.EqualValues(t, 1, rows[0]["id"]) + + close(done) + assert.Eventually(t, func() bool { + select { + case <-serverDone: + return true + default: + return false + } + }, 2*time.Second, 50*time.Millisecond) + + assert.Eventually(t, func() bool { + _, err := os.Stat(socketPath) + return errors.Is(err, os.ErrNotExist) + }, 5*time.Second, 100*time.Millisecond) +}