Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 94 additions & 45 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,49 +53,7 @@ func (s *Server) Serve(file string) error {
}

dir := http.Dir(directory)
chttp := http.NewServeMux()
chttp.Handle("/static/", http.FileServer(http.FS(defaults.StaticFiles)))
chttp.Handle("/", http.FileServer(dir))

// Regex for markdown
regex := regexp.MustCompile(`(?i)\.md$`)

// Serve website with rendered markdown
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
f, err := dir.Open(r.URL.Path)
if err == nil {
//nolint:errcheck
defer f.Close()
}

if err == nil && regex.MatchString(r.URL.Path) {
// Open file and convert to html
bytes, err := readToString(dir, r.URL.Path)
if err != nil {
log.Fatal(err)
return
}
htmlContent, err := s.parser.MdToHTML(bytes)
if err != nil {
log.Fatal(err)
return
}

// Serve
err = serveTemplate(w, htmlStruct{
Content: string(htmlContent),
BoundingBox: s.boundingBox,
CssCodeLight: getCssCode("github"),
CssCodeDark: getCssCode("github-dark"),
})
if err != nil {
log.Fatal(err)
return
}
} else {
chttp.ServeHTTP(w, r)
}
})
handler := s.newHandler(dir)

addr := fmt.Sprintf("http://%s:%d/", s.host, s.port)
if file == "" {
Expand All @@ -122,16 +80,64 @@ func (s *Server) Serve(file string) error {
}
}

var handler http.Handler = http.DefaultServeMux
if s.enableReload {
handler = reloadMiddleware.Handle(http.DefaultServeMux)
handler = reloadMiddleware.Handle(handler)
fmt.Printf("📡 Auto-reload enabled. Files will trigger browser refresh.\n")
} else {
fmt.Printf("🔄 Auto-reload disabled. Use F5 to manually refresh.\n")
}
return http.ListenAndServe(fmt.Sprintf(":%d", s.port), handler)
}

func (s *Server) newHandler(dir http.Dir) http.Handler {
fileServer := http.FileServer(dir)
mux := http.NewServeMux()
mux.Handle("/static/", http.FileServer(http.FS(defaults.StaticFiles)))

regex := regexp.MustCompile(`(?i)\.md$`)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if regex.MatchString(r.URL.Path) {
isFile, err := isRegularFile(dir, r.URL.Path)
if err == nil && isFile {
setNoCacheHeaders(w)

bytes, err := readToString(dir, r.URL.Path)
if err != nil {
log.Fatal(err)
return
}
htmlContent, err := s.parser.MdToHTML(bytes)
if err != nil {
log.Fatal(err)
return
}

err = serveTemplate(w, htmlStruct{
Content: string(htmlContent),
BoundingBox: s.boundingBox,
CssCodeLight: getCssCode("github"),
CssCodeDark: getCssCode("github-dark"),
})
if err != nil {
log.Fatal(err)
return
}
return
}
}

isDirectory, err := isDirectory(dir, r.URL.Path)
if err == nil && isDirectory {
setNoCacheHeaders(w)
stripCacheValidators(r)
}

fileServer.ServeHTTP(w, r)
})

return mux
}

func readToString(dir http.Dir, filename string) ([]byte, error) {
f, err := dir.Open(filename)
if err != nil {
Expand Down Expand Up @@ -172,3 +178,46 @@ func getCssCode(style string) string {
_ = formatter.WriteCSS(buf, s)
return buf.String()
}

func setNoCacheHeaders(w http.ResponseWriter) {
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
}

func stripCacheValidators(r *http.Request) {
r.Header.Del("If-Modified-Since")
r.Header.Del("If-None-Match")
}

func isDirectory(dir http.Dir, name string) (bool, error) {
file, err := dir.Open(name)
if err != nil {
return false, err
}
//nolint:errcheck
defer file.Close()

info, err := file.Stat()
if err != nil {
return false, err
}

return info.IsDir(), nil
}

func isRegularFile(dir http.Dir, name string) (bool, error) {
file, err := dir.Open(name)
if err != nil {
return false, err
}
//nolint:errcheck
defer file.Close()

info, err := file.Stat()
if err != nil {
return false, err
}

return !info.IsDir(), nil
}
92 changes: 92 additions & 0 deletions internal/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package internal

import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
)

func TestDirectoryListingIgnoresCacheValidators(t *testing.T) {
t.Parallel()

tmpDir := t.TempDir()
if err := os.WriteFile(filepath.Join(tmpDir, "README.md"), []byte("# Hello\n"), 0o644); err != nil {
t.Fatalf("write README.md: %v", err)
}

server := NewServer("localhost", 6419, false, false, false, NewParser())
handler := server.newHandler(http.Dir(tmpDir))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("If-Modified-Since", time.Now().Add(24*time.Hour).UTC().Format(http.TimeFormat))

recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)

if recorder.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
}
if got := recorder.Header().Get("Cache-Control"); !strings.Contains(got, "no-store") {
t.Fatalf("expected Cache-Control to disable storage, got %q", got)
}
if !strings.Contains(recorder.Body.String(), "README.md") {
t.Fatalf("expected directory listing body to mention README.md, got %q", recorder.Body.String())
}
}

func TestRegularFileStillSupportsConditionalRequests(t *testing.T) {
t.Parallel()

tmpDir := t.TempDir()
if err := os.WriteFile(filepath.Join(tmpDir, "plain.txt"), []byte("hello\n"), 0o644); err != nil {
t.Fatalf("write plain.txt: %v", err)
}

server := NewServer("localhost", 6419, false, false, false, NewParser())
handler := server.newHandler(http.Dir(tmpDir))

req := httptest.NewRequest(http.MethodGet, "/plain.txt", nil)
req.Header.Set("If-Modified-Since", time.Now().Add(24*time.Hour).UTC().Format(http.TimeFormat))

recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)

if recorder.Code != http.StatusNotModified {
t.Fatalf("expected status %d, got %d", http.StatusNotModified, recorder.Code)
}
}

func TestMarkdownResponsesDisableCaching(t *testing.T) {
t.Parallel()

tmpDir := t.TempDir()
if err := os.WriteFile(filepath.Join(tmpDir, "README.md"), []byte("# Hello\n"), 0o644); err != nil {
t.Fatalf("write README.md: %v", err)
}

server := NewServer("localhost", 6419, false, false, false, NewParser())
handler := server.newHandler(http.Dir(tmpDir))

req := httptest.NewRequest(http.MethodGet, "/README.md", nil)
req.Header.Set("If-Modified-Since", time.Now().Add(24*time.Hour).UTC().Format(http.TimeFormat))

recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)

if recorder.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code)
}
if got := recorder.Header().Get("Cache-Control"); !strings.Contains(got, "no-store") {
t.Fatalf("expected Cache-Control to disable storage, got %q", got)
}
if got := recorder.Header().Get("Content-Type"); got != "text/html" {
t.Fatalf("expected text/html response, got %q", got)
}
if !strings.Contains(recorder.Body.String(), "Hello") {
t.Fatalf("expected rendered markdown response to contain document content, got %q", recorder.Body.String())
}
}
Loading