Skip to content

Commit 691672b

Browse files
tables
1 parent d51f20b commit 691672b

File tree

4 files changed

+96
-12
lines changed

4 files changed

+96
-12
lines changed

table.go

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,94 @@ import (
1010
"github.com/jackc/pgx/v5"
1111
)
1212

13+
// NewTable creates a new Table instance for a specific record type.
14+
//
15+
// T: the concrete struct type representing a single record (e.g., Account).
16+
// PT: the pointer type to T (e.g., *Account), which must implement tableP[T, IDT].
17+
// IDT: the type of the primary key (e.g., int64), which must be comparable.
18+
//
19+
// db: a pointer to the DB instance
20+
//
21+
// The function introspects a zero value of T to extract metadata if it
22+
// implements the Base[IDT], specifically:
23+
//
24+
// - Validate() error -> validation function, which si called everytime on saving
25+
// - DBTableName() string -> returns the database table name.
26+
// - GetID() IDT -> returns primary key ID
27+
// - GetIDColumn() string -> returns the name of the primary key column, which is then used for default sorting
28+
//
29+
// Additionally, records may optionally implement the following interfaces to
30+
// allow automatic timestamp management:
31+
//
32+
// - hasSetCreatedAt -> SetCreatedAt(time.Time), called automatically on record creation
33+
// - hasSetUpdatedAt -> SetUpdatedAt(time.Time), called automatically on record creation and update
34+
// - hasSetDeletedAt -> SetDeletedAt(time.Time), called automatically on soft deletion
35+
//
36+
// Example usage:
37+
//
38+
// type accountsTable struct {
39+
// *pgkit.Table[Account, *Account, int64]
40+
// }
41+
//
42+
// at := accountsTable{
43+
// Table: pgkit.NewTable[Account, *Account, int64](db),
44+
// }
45+
//
46+
// type Account struct {
47+
// ID int64 `db:"id,omitempty"`
48+
// Name string `db:"name"`
49+
// CreatedAt time.Time `db:"created_at,omitempty"` // ,omitempty will rely on Postgres DEFAULT
50+
// UpdatedAt time.Time `db:"updated_at,omitempty"` // ,omitempty will rely on Postgres DEFAULT
51+
// }
52+
//
53+
// func (a *Account) DBTableName() string { return "accounts" }
54+
// func (a *Account) GetIDColumn() string { return "id" }
55+
// func (a *Account) GetID() int64 { return a.ID }
56+
// func (a *Account) SetUpdatedAt(t time.Time) { a.UpdatedAt = t }
57+
//
58+
// func (a *Account) Validate() error {
59+
// if a.Name == "" {
60+
// return fmt.Errorf("name is required")
61+
// }
62+
// return nil
63+
// }
64+
func NewTable[T any, PT TableP[T, IDT], IDT comparable](db *DB, name string) *Table[T, PT, IDT] {
65+
var t T
66+
67+
idColumn := ""
68+
if v, ok := any(&t).(Base[IDT]); ok {
69+
idColumn = v.GetIDColumn()
70+
}
71+
72+
return &Table[T, PT, IDT]{
73+
DB: db,
74+
Name: name,
75+
IDColumn: idColumn,
76+
}
77+
}
78+
1379
// Table provides basic CRUD operations for database records.
1480
// Records must implement GetID() and Validate() methods.
15-
type Table[T any, PT interface {
16-
*T // Enforce T is a pointer.
17-
GetID() IDT
18-
Validate() error
19-
}, IDT comparable] struct {
81+
type Table[T any, PT TableP[T, IDT], IDT comparable] struct {
2082
*DB
2183
Name string
2284
IDColumn string
2385
}
2486

87+
type TableP[T any, IDT comparable] interface {
88+
*T // Enforce that T is a pointer.
89+
Base[IDT]
90+
}
91+
92+
type Base[IDT comparable] interface {
93+
Validate() error
94+
95+
GetID() IDT
96+
GetIDColumn() string
97+
}
98+
99+
func (t *Table[T, PT, IDT]) DBTableName() string { return t.Name }
100+
25101
type hasSetCreatedAt interface {
26102
SetCreatedAt(time.Time)
27103
}
@@ -68,7 +144,7 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error {
68144
Suffix("RETURNING *")
69145

70146
if err := t.Query.GetOne(ctx, q, record); err != nil {
71-
return fmt.Errorf("save: insert record: %w", err)
147+
return fmt.Errorf("insert record: %w", err)
72148
}
73149

74150
return nil
@@ -77,7 +153,7 @@ func (t *Table[T, PT, IDT]) saveOne(ctx context.Context, record PT) error {
77153
// Update
78154
q := t.SQL.UpdateRecord(record, sq.Eq{t.IDColumn: record.GetID()}, t.Name)
79155
if _, err := t.Query.Exec(ctx, q); err != nil {
80-
return fmt.Errorf("save: update record: %w", err)
156+
return fmt.Errorf("update record: %w", err)
81157
}
82158

83159
return nil

tests/database_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ package pgkit_test
33
import (
44
"context"
55

6-
"github.com/goware/pgkit/v2"
76
"github.com/jackc/pgx/v5"
7+
8+
"github.com/goware/pgkit/v2"
89
)
910

1011
type Database struct {
@@ -18,9 +19,9 @@ type Database struct {
1819
func initDB(db *pgkit.DB) *Database {
1920
return &Database{
2021
DB: db,
21-
Accounts: &accountsTable{Table: &pgkit.Table[Account, *Account, int64]{DB: db, Name: "accounts", IDColumn: "id"}},
22-
Articles: &articlesTable{Table: &pgkit.Table[Article, *Article, uint64]{DB: db, Name: "articles", IDColumn: "id"}},
23-
Reviews: &reviewsTable{Table: &pgkit.Table[Review, *Review, uint64]{DB: db, Name: "reviews", IDColumn: "id"}},
22+
Accounts: &accountsTable{Table: pgkit.NewTable[Account, *Account](db, "accounts")},
23+
Articles: &articlesTable{Table: pgkit.NewTable[Article, *Article](db, "articles")},
24+
Reviews: &reviewsTable{Table: pgkit.NewTable[Review, *Review](db, "reviews")},
2425
}
2526
}
2627

tests/schema_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type Account struct {
1616
}
1717

1818
func (a *Account) DBTableName() string { return "accounts" }
19+
func (a *Account) GetIDColumn() string { return "id" }
1920
func (a *Account) GetID() int64 { return a.ID }
2021
func (a *Account) SetUpdatedAt(t time.Time) { a.UpdatedAt = t }
2122

@@ -38,6 +39,7 @@ type Article struct {
3839
DeletedAt *time.Time `db:"deleted_at"`
3940
}
4041

42+
func (a *Article) GetIDColumn() string { return "id" }
4143
func (a *Article) GetID() uint64 { return a.ID }
4244
func (a *Article) SetUpdatedAt(t time.Time) { a.UpdatedAt = t }
4345
func (a *Article) SetDeletedAt(t time.Time) { a.DeletedAt = &t }
@@ -69,6 +71,7 @@ type Review struct {
6971
DeletedAt *time.Time `db:"deleted_at"`
7072
}
7173

74+
func (r *Review) GetIDColumn() string { return "id" }
7275
func (r *Review) GetID() uint64 { return r.ID }
7376
func (r *Review) SetUpdatedAt(t time.Time) { r.UpdatedAt = t }
7477
func (r *Review) SetDeletedAt(t time.Time) { r.DeletedAt = &t }

tests/table_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ func TestTable(t *testing.T) {
6262
account := &Account{Name: "Save Multiple Account"}
6363
err := db.Accounts.Save(ctx, account)
6464
require.NoError(t, err, "Create account failed")
65+
6566
articles := []*Article{
6667
{Author: "FirstNew", AccountID: account.ID},
6768
{Author: "SecondNew", AccountID: account.ID},
@@ -74,6 +75,7 @@ func TestTable(t *testing.T) {
7475
require.NotZero(t, articles[1].ID, "ID should be set")
7576
require.Equal(t, uint64(10001), articles[2].ID, "ID should be same")
7677
require.Equal(t, uint64(10002), articles[3].ID, "ID should be same")
78+
7779
// test update for multiple records
7880
updateArticles := []*Article{
7981
articles[0],
@@ -83,9 +85,11 @@ func TestTable(t *testing.T) {
8385
updateArticles[1].Author = "Updated Author Name 2"
8486
err = db.Articles.Save(ctx, updateArticles...)
8587
require.NoError(t, err, "Save articles")
88+
8689
updateArticle0, err := db.Articles.GetByID(ctx, articles[0].ID)
8790
require.NoError(t, err, "Get By ID")
8891
require.Equal(t, updateArticles[0].Author, updateArticle0.Author, "Author should be same")
92+
8993
updateArticle1, err := db.Articles.GetByID(ctx, articles[1].ID)
9094
require.NoError(t, err, "Get By ID")
9195
require.Equal(t, updateArticles[1].Author, updateArticle1.Author, "Author should be same")
@@ -217,7 +221,7 @@ func TestLockForUpdates(t *testing.T) {
217221
err = db.Reviews.Save(ctx, reviews...)
218222
require.NoError(t, err, "create review")
219223

220-
var ids [][]uint64 = make([][]uint64, 10)
224+
ids := make([][]uint64, 10)
221225
var wg sync.WaitGroup
222226

223227
for range 10 {

0 commit comments

Comments
 (0)