-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic_firewall_test.go
More file actions
119 lines (102 loc) · 3.25 KB
/
basic_firewall_test.go
File metadata and controls
119 lines (102 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package example_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/philiphil/restman/errors"
"github.com/philiphil/restman/orm"
"github.com/philiphil/restman/orm/entity"
"github.com/philiphil/restman/orm/gormrepository"
"github.com/philiphil/restman/route"
"github.com/philiphil/restman/router"
"github.com/philiphil/restman/security"
)
// ProtectedEntity is an example entity whose API routes require an
// authenticated user. Its only state is the embedded entity.BaseEntity.
type ProtectedEntity struct {
	entity.BaseEntity
}

// GetId returns the entity's identifier (the Id field promoted from the
// embedded entity.BaseEntity).
func (e ProtectedEntity) GetId() entity.ID {
	return e.Id
}

// SetId returns a copy of the entity whose Id is entity.CastId(id).
// The receiver is a value, so the caller's entity is not modified.
func (e ProtectedEntity) SetId(id any) entity.Entity {
	e.Id = entity.CastId(id)
	return e
}

// ToEntity returns the value unchanged: ProtectedEntity serves as its own
// entity model, so no conversion is needed.
func (e ProtectedEntity) ToEntity() ProtectedEntity {
	return e
}

// FromEntity returns src unchanged, boxed as any.
// The parameter is deliberately not named "entity" so that it does not shadow
// the imported entity package.
func (e ProtectedEntity) FromEntity(src ProtectedEntity) any {
	return src
}

// GetReadingRights supplies the authorization check for reading this entity.
// It returns security.AuthenticationRequired, the default policy that only
// requires an authenticated user, in order to demonstrate the firewall.
//
// Only reading rights are declared here. The library also recognizes a
// GetWritingRights counterpart for write access, but this example does not
// define one.
// NOTE(review): confirm how write routes are authorized when
// GetWritingRights is absent.
func (e ProtectedEntity) GetReadingRights() security.AuthorizationFunction {
	return security.AuthenticationRequired
}
// TestFirewall is the firewall attached to the example router. A firewall
// implements the security.Firewall interface and resolves the calling user
// from the incoming request (headers, cookies, and so on). This one looks
// only at the Authorization header.
type TestFirewall struct{}

// GetUser resolves the current user from the request's Authorization header.
//
// A missing or empty header yields errors.ErrUnauthorized. Otherwise the raw
// header value is converted with entity.CastId and used as the user's id.
// The value is trusted as-is, so this is a test double only and must never
// be used in production.
func (TestFirewall) GetUser(c *gin.Context) (security.User, error) {
	header := c.GetHeader("Authorization")
	if len(header) == 0 {
		return nil, errors.ErrUnauthorized
	}

	var user TestUser
	return user.SetId(entity.CastId(header)), nil
}
// TestUser is the user returned by TestFirewall. It must implement the
// security.User interface, which for convenience is essentially the same as
// entity.Entity. It embeds entity.BaseEntity and exposes that Id.
type TestUser struct {
	entity.BaseEntity
}

// GetId reports the user's identifier.
func (u TestUser) GetId() entity.ID {
	id := u.Id
	return id
}

// SetId returns a copy of the user whose Id is entity.CastId(id).
// The receiver itself is left untouched.
func (u TestUser) SetId(id any) entity.Entity {
	updated := u
	updated.Id = entity.CastId(id)
	return updated
}
// SetupFireWallRouter returns a gin engine that serves the default API routes
// for ProtectedEntity behind TestFirewall.
//
// In order, it:
//   - installs gin.Recovery,
//   - builds an ApiRouter over a gorm repository obtained from getDB(),
//   - attaches TestFirewall,
//   - auto-migrates the ProtectedEntity table,
//   - mounts the routes on the engine.
//
// NOTE(review): the value returned by AutoMigrate is discarded, so a failed
// migration presumably only shows up later as failing requests.
func SetupFireWallRouter() *gin.Engine {
	engine := gin.New()
	engine.Use(gin.Recovery())

	repository := gormrepository.NewRepository[ProtectedEntity](getDB())
	apiRouter := router.NewApiRouter(*orm.NewORM(repository), route.DefaultApiRoutes())
	apiRouter.AddFirewall(TestFirewall{})

	getDB().AutoMigrate(&ProtectedEntity{})
	apiRouter.AllowRoutes(engine)

	return engine
}
// TestMain_FireWall checks that reading a ProtectedEntity through the API is
// gated by TestFirewall:
//   - A request without an Authorization header is rejected with 401.
//   - A request with a non-empty header, which TestFirewall.GetUser always
//     turns into a user, is expected to succeed with 200.
func TestMain_FireWall(t *testing.T) {
	r := SetupFireWallRouter()

	// Seed the row that the requests below read. The path hard-codes id 1,
	// which assumes this is the first ProtectedEntity stored in the database
	// returned by getDB().
	// NOTE(review): Create's result is not checked here; consider failing the
	// test if the insert fails.
	repo := orm.NewORM(gormrepository.NewRepository[ProtectedEntity](getDB()))
	repo.Create(&ProtectedEntity{})

	tests := []struct {
		name          string
		authorization string // sent as the Authorization header when non-empty
		want          int
	}{
		{name: "missing Authorization header", want: http.StatusUnauthorized},
		{name: "user id 1", authorization: "1", want: http.StatusOK},
		// GetUser accepts any non-empty header, so id "0" is also expected
		// to authenticate.
		{name: "user id 0", authorization: "0", want: http.StatusOK},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := http.NewRequest(http.MethodGet, "/api/protected_entity/1", nil)
			if err != nil {
				t.Fatalf("building request: %v", err)
			}
			if tt.authorization != "" {
				req.Header.Add("Authorization", tt.authorization)
			}

			w := httptest.NewRecorder()
			r.ServeHTTP(w, req)
			if w.Code != tt.want {
				t.Errorf("GET /api/protected_entity/1: got status %d, want %d", w.Code, tt.want)
			}
		})
	}
}