diff --git a/checks/checks.go b/checks/checks.go index ebc64d5..e33df4c 100644 --- a/checks/checks.go +++ b/checks/checks.go @@ -12,6 +12,7 @@ type Checks struct { BlockList *BlockList Carbon *Carbon Headers *Headers + Hsts *Hsts IpAddress *Ip LegacyRank *LegacyRank LinkedPages *LinkedPages @@ -28,6 +29,7 @@ func NewChecks() *Checks { BlockList: NewBlockList(&ip.NetDNSLookup{}), Carbon: NewCarbon(client), Headers: NewHeaders(client), + Hsts: NewHsts(client), IpAddress: NewIp(NewNetIp()), LegacyRank: NewLegacyRank(legacyrank.NewInMemoryStore()), LinkedPages: NewLinkedPages(client), diff --git a/checks/hsts.go b/checks/hsts.go new file mode 100644 index 0000000..19ace02 --- /dev/null +++ b/checks/hsts.go @@ -0,0 +1,87 @@ +package checks + +import ( + "context" + "net/http" + "strconv" + "strings" + "unicode" +) + +type HSTSResponse struct { + Message string `json:"message"` + Compatible bool `json:"compatible"` + HSTSHeader string `json:"hstsHeader"` +} + +type Hsts struct { + client *http.Client +} + +func NewHsts(client *http.Client) *Hsts { + return &Hsts{client: client} +} + +func (h *Hsts) Validate(ctx context.Context, url string) (*HSTSResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err != nil { + return nil, err + } + + resp, err := h.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + hstsHeader := resp.Header.Get("Strict-Transport-Security") + if hstsHeader == "" { + return &HSTSResponse{Message: "Site does not serve any HSTS headers."}, nil + } + + if !strings.Contains(hstsHeader, "max-age") { + return &HSTSResponse{Message: "HSTS max-age is less than 10886400."}, nil + } + + var maxAgeString string + for _, h := range strings.Split(hstsHeader, " ") { + if strings.Contains(h, "max-age=") { + maxAgeString = extractMaxAgeFromHeader(h) + } + } + + maxAge, err := strconv.Atoi(maxAgeString) + if err != nil { + return nil, err + } + + if maxAge < 10886400 { + return &HSTSResponse{Message: "HSTS max-age is less than 10886400."}, nil + } + + if !strings.Contains(hstsHeader, "includeSubDomains") { + return &HSTSResponse{Message: "HSTS header does not include all subdomains."}, nil + } + + if !strings.Contains(hstsHeader, "preload") { + return &HSTSResponse{Message: "HSTS header does not contain the preload directive."}, nil + } + + return &HSTSResponse{ + Message: "Site is compatible with the HSTS preload list!", + Compatible: true, + HSTSHeader: hstsHeader, + }, nil +} + +func extractMaxAgeFromHeader(header string) string { + var maxAge strings.Builder + + for _, b := range header { + if unicode.IsDigit(b) { + maxAge.WriteRune(b) + } + } + + return maxAge.String() +} diff --git a/checks/hsts_test.go b/checks/hsts_test.go new file mode 100644 index 0000000..e64eb59 --- /dev/null +++ b/checks/hsts_test.go @@ -0,0 +1,125 @@ +package checks + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xray-web/web-check-api/testutils" +) + +func TestValidate(t *testing.T) { + t.Parallel() + + t.Run("given an empty header", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient(&http.Response{ + Header: http.Header{"Strict-Transport-Security": []string{""}}}) + h := NewHsts(client) + + actual, err := h.Validate(context.Background(), "test.com") + assert.NoError(t, err) + + assert.Equal(t, "Site does not serve any HSTS headers.", actual.Message) + assert.False(t, actual.Compatible) + assert.Empty(t, actual.HSTSHeader) + }) + + t.Run("given a header without max age", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient(&http.Response{ + Header: http.Header{"Strict-Transport-Security": []string{"includeSubDomains; preload"}}}) + h := NewHsts(client) + + actual, err := h.Validate(context.Background(), "test.com") + assert.NoError(t, err) + + assert.Equal(t, "HSTS max-age is less than 10886400.", actual.Message) + assert.False(t, actual.Compatible) + assert.Empty(t, actual.HSTSHeader) + }) + + t.Run("given max age less than 10886400", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient(&http.Response{ + Header: http.Header{"Strict-Transport-Security": []string{"max-age=47; includeSubDomains; preload"}}}) + h := NewHsts(client) + + actual, err := h.Validate(context.Background(), "test.com") + assert.NoError(t, err) + + assert.Equal(t, "HSTS max-age is less than 10886400.", actual.Message) + assert.False(t, actual.Compatible) + assert.Empty(t, actual.HSTSHeader) + }) + + t.Run("given a header without includeSubDomains", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient(&http.Response{ + Header: http.Header{"Strict-Transport-Security": []string{"max-age=47474747; preload"}}}) + h := NewHsts(client) + + actual, err := h.Validate(context.Background(), "test.com") + assert.NoError(t, err) + + assert.Equal(t, "HSTS header does not include all subdomains.", actual.Message) + assert.False(t, actual.Compatible) + assert.Empty(t, actual.HSTSHeader) + }) + + t.Run("given a header without preload", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient(&http.Response{ + Header: http.Header{"Strict-Transport-Security": []string{"max-age=47474747; includeSubDomains"}}}) + h := NewHsts(client) + + actual, err := h.Validate(context.Background(), "test.com") + assert.NoError(t, err) + + assert.Equal(t, "HSTS header does not contain the preload directive.", actual.Message) + assert.False(t, actual.Compatible) + assert.Empty(t, actual.HSTSHeader) + }) + + t.Run("given a valid header", func(t *testing.T) { + t.Parallel() + + client := testutils.MockClient(&http.Response{ + Header: http.Header{"Strict-Transport-Security": []string{"max-age=47474747; includeSubDomains; preload"}}}) + h := NewHsts(client) + + actual, err := h.Validate(context.Background(), "test.com") + assert.NoError(t, err) + + assert.Equal(t, "Site is compatible with the HSTS preload list!", actual.Message) + assert.True(t, actual.Compatible) + assert.NotEmpty(t, actual.HSTSHeader) + }) +} + +func TestExtractMaxAgeFromHeader(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + header string + expected string + }{ + {"give valid header", "max-age=47474747;", "47474747"}, + {"given an empty header", "", ""}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + actual := extractMaxAgeFromHeader(tc.header) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/handlers/hsts.go b/handlers/hsts.go index 647cba7..83c3eb1 100644 --- a/handlers/hsts.go +++ b/handlers/hsts.go @@ -1,54 +1,12 @@ package handlers import ( - "fmt" "net/http" - "regexp" - "strings" -) - -type HSTSResponse struct { - Message string `json:"message"` - Compatible bool `json:"compatible"` - HSTSHeader string `json:"hstsHeader"` -} - -func checkHSTS(url string) (HSTSResponse, error) { - client := &http.Client{} - - req, err := http.NewRequest("HEAD", url, nil) - if err != nil { - return HSTSResponse{}, fmt.Errorf("error creating request: %s", err.Error()) - } - - resp, err := client.Do(req) - if err != nil { - return HSTSResponse{}, fmt.Errorf("error making request: %s", err.Error()) - } - defer resp.Body.Close() - hstsHeader := resp.Header.Get("strict-transport-security") - if hstsHeader == "" { - return HSTSResponse{Message: "Site does not serve any HSTS headers."}, nil - } - - maxAgeMatch := regexp.MustCompile(`max-age=(\d+)`).FindStringSubmatch(hstsHeader) - if maxAgeMatch == nil || len(maxAgeMatch) < 2 || maxAgeMatch[1] == "" || maxAgeMatch[1] < "10886400" { - return HSTSResponse{Message: "HSTS max-age is less than 10886400."}, nil - } - - if !strings.Contains(hstsHeader, "includeSubDomains") { - return HSTSResponse{Message: "HSTS header does not include all subdomains."}, nil - } - - if !strings.Contains(hstsHeader, "preload") { - return HSTSResponse{Message: "HSTS header does not contain the preload directive."}, nil - } - - return HSTSResponse{Message: "Site is compatible with the HSTS preload list!", Compatible: true, HSTSHeader: hstsHeader}, nil -} + "github.com/xray-web/web-check-api/checks" +) -func HandleHsts() http.Handler { +func HandleHsts(h *checks.Hsts) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rawURL, err := extractURL(r) if err != nil { @@ -56,7 +14,7 @@ func HandleHsts() http.Handler { return } - result, err := checkHSTS(rawURL.String()) + result, err := h.Validate(r.Context(), rawURL.String()) if err != nil { JSONError(w, err, http.StatusInternalServerError) return diff --git a/handlers/hsts_test.go b/handlers/hsts_test.go index 983ae22..c9ba27d 100644 --- a/handlers/hsts_test.go +++ b/handlers/hsts_test.go @@ -1,7 +1,6 @@ package handlers import ( - "encoding/json" "net/http" "net/http/httptest" "testing" @@ -11,18 +10,15 @@ import ( func TestHandleHsts(t *testing.T) { t.Parallel() - req := httptest.NewRequest("GET", "/check-hsts?url=example.com", nil) - rec := httptest.NewRecorder() - HandleHsts().ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) + t.Run("missing URL parameter", func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodGet, "/check-hsts", nil) + rec := httptest.NewRecorder() - var response HSTSResponse - err := json.Unmarshal(rec.Body.Bytes(), &response) - assert.NoError(t, err) + HandleHsts(nil).ServeHTTP(rec, req) - assert.NotNil(t, response) - assert.Equal(t, "Site does not serve any HSTS headers.", response.Message) - assert.False(t, response.Compatible) - assert.Empty(t, response.HSTSHeader) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.JSONEq(t, `{"error": "missing URL parameter"}`, rec.Body.String()) + }) } diff --git a/server/server.go b/server/server.go index 5ee1493..9513a86 100644 --- a/server/server.go +++ b/server/server.go @@ -41,7 +41,7 @@ func (s *Server) routes() { s.mux.Handle("GET /api/firewall", handlers.HandleFirewall()) s.mux.Handle("GET /api/get-ip", handlers.HandleGetIP(s.checks.IpAddress)) s.mux.Handle("GET /api/headers", handlers.HandleGetHeaders(s.checks.Headers)) - s.mux.Handle("GET /api/hsts", handlers.HandleHsts()) + s.mux.Handle("GET /api/hsts", handlers.HandleHsts(s.checks.Hsts)) s.mux.Handle("GET /api/http-security", handlers.HandleHttpSecurity()) s.mux.Handle("GET /api/legacy-rank", handlers.HandleLegacyRank(s.checks.LegacyRank)) s.mux.Handle("GET /api/linked-pages", handlers.HandleGetLinks(s.checks.LinkedPages))