-
Notifications
You must be signed in to change notification settings - Fork 0
Add unix socket listening option #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,9 +5,11 @@ import ( | |||||||||||||||
| "encoding/json" | ||||||||||||||||
| "errors" | ||||||||||||||||
| "fmt" | ||||||||||||||||
| "net" | ||||||||||||||||
| "net/http" | ||||||||||||||||
| "os" | ||||||||||||||||
| "os/signal" | ||||||||||||||||
| "path/filepath" | ||||||||||||||||
| "syscall" | ||||||||||||||||
| "time" | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -27,6 +29,7 @@ const ( | |||||||||||||||
| type ServerOptions struct { | ||||||||||||||||
| Logger logr.Logger | ||||||||||||||||
| Addr string | ||||||||||||||||
| SocketPath string | ||||||||||||||||
| AuthOptions ServerAuthOptions | ||||||||||||||||
| SecurityOptions ServerSecurityOptions | ||||||||||||||||
| Queryer sqlx.QueryerContext | ||||||||||||||||
|
|
@@ -35,6 +38,7 @@ type ServerOptions struct { | |||||||||||||||
|
|
||||||||||||||||
| func (opts *ServerOptions) bindCLIFlags(fs *pflag.FlagSet) { | ||||||||||||||||
| fs.StringVar(&opts.Addr, "http-addr", ":8080", "server listen address") | ||||||||||||||||
| fs.StringVar(&opts.SocketPath, "http-socket", "", "server listen unix socket path. If set, http-addr will be ignored") | ||||||||||||||||
|
|
||||||||||||||||
| opts.AuthOptions.bindCLIFlags(fs) | ||||||||||||||||
| opts.SecurityOptions.bindCLIFlags(fs) | ||||||||||||||||
|
|
@@ -52,7 +56,11 @@ func (opts *ServerOptions) defaults() error { | |||||||||||||||
| opts.Logger = logr.Discard() | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if opts.Addr == "" { | ||||||||||||||||
| if opts.SocketPath != "" { | ||||||||||||||||
| opts.Addr = "" | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if opts.Addr == "" && opts.SocketPath == "" { | ||||||||||||||||
| opts.Addr = ":8080" | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -68,10 +76,12 @@ func (opts *ServerOptions) defaults() error { | |||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| type dbServer struct { | ||||||||||||||||
| logger logr.Logger | ||||||||||||||||
| server *http.Server | ||||||||||||||||
| queryer sqlx.QueryerContext | ||||||||||||||||
| execer sqlx.ExecerContext | ||||||||||||||||
| logger logr.Logger | ||||||||||||||||
| server *http.Server | ||||||||||||||||
| listener net.Listener | ||||||||||||||||
| socket string | ||||||||||||||||
| queryer sqlx.QueryerContext | ||||||||||||||||
| execer sqlx.ExecerContext | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| func NewServer(opts *ServerOptions) (*dbServer, error) { | ||||||||||||||||
|
|
@@ -86,6 +96,7 @@ func NewServer(opts *ServerOptions) (*dbServer, error) { | |||||||||||||||
| // TODO: make it configurable | ||||||||||||||||
| ReadHeaderTimeout: 5 * time.Second, | ||||||||||||||||
| }, | ||||||||||||||||
| socket: opts.SocketPath, | ||||||||||||||||
| queryer: opts.Queryer, | ||||||||||||||||
| execer: opts.Execer, | ||||||||||||||||
| } | ||||||||||||||||
|
|
@@ -128,15 +139,53 @@ func NewServer(opts *ServerOptions) (*dbServer, error) { | |||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| func (server *dbServer) Start(done <-chan struct{}) { | ||||||||||||||||
| go server.server.ListenAndServe() | ||||||||||||||||
| if server.socket != "" { | ||||||||||||||||
| sockDir := filepath.Dir(server.socket) | ||||||||||||||||
| if sockDir != "" && sockDir != "." { | ||||||||||||||||
| if err := os.MkdirAll(sockDir, 0755); err != nil { | ||||||||||||||||
| server.logger.Error(err, "failed to ensure unix socket directory", "socket", server.socket) | ||||||||||||||||
| return | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| if err := os.RemoveAll(server.socket); err != nil { | ||||||||||||||||
|
||||||||||||||||
| if err := os.RemoveAll(server.socket); err != nil { | |
| if err := os.Remove(server.socket); err != nil && !errors.Is(err, os.ErrNotExist) { |
Copilot
AI
Jan 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The socket file permissions (0666 by default for Unix sockets) could be a security concern. Consider setting explicit restrictive permissions after creating the socket using os.Chmod to limit access to the socket file (e.g., 0600 for owner-only or 0660 for owner and group).
Copilot
AI
Jan 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error returned by server.Serve is not captured or logged. While Serve is called in a goroutine, consider capturing and logging the error for better observability, especially to distinguish between expected shutdown (http.ErrServerClosed) and unexpected errors.
Copilot
AI
Jan 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error returned by server.Serve is not captured or logged. While Serve is called in a goroutine, consider capturing and logging the error for better observability, especially to distinguish between expected shutdown (http.ErrServerClosed) and unexpected errors.
Copilot
AI
Jan 2, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The listener is not explicitly closed during shutdown. Although server.Shutdown will close the listener, explicitly closing it ensures proper cleanup and makes the code more maintainable. Add a check and close the listener before or after the Shutdown call.
| defer cancel() | |
| defer cancel() | |
| if server.listener != nil { | |
| if err := server.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { | |
| server.logger.Error(err, "failed to close listener") | |
| } | |
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| package main | ||
|
|
||
| import ( | ||
| "context" | ||
| "encoding/json" | ||
| "errors" | ||
| "net" | ||
| "net/http" | ||
| "os" | ||
| "path/filepath" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/jmoiron/sqlx" | ||
| "github.com/stretchr/testify/assert" | ||
| ) | ||
|
|
||
| func TestServerWithUnixSocket(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| dir := t.TempDir() | ||
| socketPath := filepath.Join(dir, "sqlite-rest.sock") | ||
|
|
||
| dbPath := filepath.Join(dir, "test.db") | ||
| db, err := sqlx.Open("sqlite3", dbPath) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| defer db.Close() | ||
|
|
||
| _, err = db.Exec("CREATE TABLE test (id int)") | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| _, err = db.Exec(`INSERT INTO test (id) VALUES (1)`) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
|
|
||
| serverOpts := &ServerOptions{ | ||
| Logger: createTestLogger(t).WithName("test"), | ||
| Queryer: db, | ||
| Execer: db, | ||
| SocketPath: socketPath, | ||
| } | ||
| serverOpts.AuthOptions.disableAuth = true | ||
| serverOpts.SecurityOptions.EnabledTableOrViews = []string{"test"} | ||
|
|
||
| server, err := NewServer(serverOpts) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
|
|
||
| done := make(chan struct{}) | ||
| serverDone := make(chan struct{}) | ||
| go func() { | ||
| server.Start(done) | ||
| close(serverDone) | ||
| }() | ||
|
|
||
| assert.Eventually(t, func() bool { | ||
| _, err := os.Stat(socketPath) | ||
| return err == nil | ||
| }, 5*time.Second, 100*time.Millisecond) | ||
|
|
||
| client := &http.Client{ | ||
| Transport: &http.Transport{ | ||
| DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { | ||
| var d net.Dialer | ||
| return d.DialContext(ctx, "unix", socketPath) | ||
| }, | ||
| }, | ||
| } | ||
|
|
||
| req, err := http.NewRequest(http.MethodGet, "http://unix/test", nil) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
|
|
||
| resp, err := client.Do(req) | ||
| if err != nil { | ||
| t.Fatal(err) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| assert.Equal(t, http.StatusOK, resp.StatusCode) | ||
|
|
||
| var rows []map[string]interface{} | ||
| err = json.NewDecoder(resp.Body).Decode(&rows) | ||
| assert.NoError(t, err) | ||
| assert.Len(t, rows, 1) | ||
| assert.EqualValues(t, 1, rows[0]["id"]) | ||
|
|
||
| close(done) | ||
| assert.Eventually(t, func() bool { | ||
| select { | ||
| case <-serverDone: | ||
| return true | ||
| default: | ||
| return false | ||
| } | ||
| }, 2*time.Second, 50*time.Millisecond) | ||
|
|
||
| assert.Eventually(t, func() bool { | ||
| _, err := os.Stat(socketPath) | ||
| return errors.Is(err, os.ErrNotExist) | ||
| }, 5*time.Second, 100*time.Millisecond) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using RemoveAll to clean up a socket file is dangerous and could accidentally delete an entire directory if the socket path is misconfigured. Use os.Remove instead, which will fail safely if the path is a directory.