Skip to content

Commit 50e6664

Browse files
authored
pgstmt: union (#27)
* pgstmt: add union * pgstmt: add join union to select statement * add offset * add nested union * update test
1 parent 0c02819 commit 50e6664

File tree

5 files changed

+274
-0
lines changed

5 files changed

+274
-0
lines changed

pgstmt/select.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,31 @@ type SelectStatement interface {
1515
From(table ...string)
1616
FromSelect(f func(b SelectStatement), as string)
1717
FromValues(f func(b Values), as string)
18+
1819
Join(table string) Join
1920
InnerJoin(table string) Join
2021
FullOuterJoin(table string) Join
2122
LeftJoin(table string) Join
2223
RightJoin(table string) Join
24+
2325
JoinSelect(f func(b SelectStatement), as string) Join
2426
InnerJoinSelect(f func(b SelectStatement), as string) Join
2527
FullOuterJoinSelect(f func(b SelectStatement), as string) Join
2628
LeftJoinSelect(f func(b SelectStatement), as string) Join
2729
RightJoinSelect(f func(b SelectStatement), as string) Join
30+
2831
JoinLateralSelect(f func(b SelectStatement), as string) Join
2932
InnerJoinLateralSelect(f func(b SelectStatement), as string) Join
3033
FullOuterJoinLateralSelect(f func(b SelectStatement), as string) Join
3134
LeftJoinLateralSelect(f func(b SelectStatement), as string) Join
3235
RightJoinLateralSelect(f func(b SelectStatement), as string) Join
36+
37+
JoinUnion(f func(b UnionStatement), as string) Join
38+
InnerJoinUnion(f func(b UnionStatement), as string) Join
39+
FullOuterJoinUnion(f func(b UnionStatement), as string) Join
40+
LeftJoinUnion(f func(b UnionStatement), as string) Join
41+
RightJoinUnion(f func(b UnionStatement), as string) Join
42+
3343
Where(f func(b Cond))
3444
GroupBy(col ...string)
3545
Having(f func(b Cond))
@@ -217,6 +227,44 @@ func (st *selectStmt) RightJoinLateralSelect(f func(b SelectStatement), as strin
217227
return st.joinSelect("right join lateral", f, as)
218228
}
219229

230+
func (st *selectStmt) joinUnion(typ string, f func(b UnionStatement), as string) Join {
231+
var x unionStmt
232+
f(&x)
233+
234+
var b buffer
235+
b.push(paren(x.make()))
236+
if as != "" {
237+
b.push(as)
238+
}
239+
240+
j := join{
241+
typ: typ,
242+
table: &b,
243+
}
244+
st.joins.push(&j)
245+
return &j
246+
}
247+
248+
func (st *selectStmt) JoinUnion(f func(b UnionStatement), as string) Join {
249+
return st.joinUnion("join", f, as)
250+
}
251+
252+
func (st *selectStmt) InnerJoinUnion(f func(b UnionStatement), as string) Join {
253+
return st.joinUnion("inner join", f, as)
254+
}
255+
256+
func (st *selectStmt) FullOuterJoinUnion(f func(b UnionStatement), as string) Join {
257+
return st.joinUnion("full outer join", f, as)
258+
}
259+
260+
func (st *selectStmt) LeftJoinUnion(f func(b UnionStatement), as string) Join {
261+
return st.joinUnion("left join", f, as)
262+
}
263+
264+
func (st *selectStmt) RightJoinUnion(f func(b UnionStatement), as string) Join {
265+
return st.joinUnion("right join", f, as)
266+
}
267+
220268
func (st *selectStmt) Where(f func(b Cond)) {
221269
f(&st.where)
222270
}

pgstmt/select_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,37 @@ func TestSelect(t *testing.T) {
452452
`,
453453
nil,
454454
},
455+
{
456+
"inner join union",
457+
pgstmt.Select(func(b pgstmt.SelectStatement) {
458+
b.Columns("id")
459+
b.From("table1")
460+
b.InnerJoinUnion(func(b pgstmt.UnionStatement) {
461+
b.Select(func(b pgstmt.SelectStatement) {
462+
b.Columns("id")
463+
b.From("table2")
464+
})
465+
b.AllSelect(func(b pgstmt.SelectStatement) {
466+
b.Columns("id")
467+
b.From("table3")
468+
})
469+
b.OrderBy("id").Desc()
470+
b.Limit(100)
471+
}, "t").Using("id")
472+
}),
473+
`
474+
select id
475+
from table1
476+
inner join (
477+
(select id from table2)
478+
union all
479+
(select id from table3)
480+
order by id desc
481+
limit 100
482+
) t using (id)
483+
`,
484+
nil,
485+
},
455486
}
456487

457488
for _, tC := range cases {

pgstmt/union.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package pgstmt
2+
3+
func Union(f func(b UnionStatement)) *Result {
4+
var st unionStmt
5+
f(&st)
6+
return newResult(build(st.make()))
7+
}
8+
9+
type UnionStatement interface {
10+
Select(f func(b SelectStatement))
11+
AllSelect(f func(b SelectStatement))
12+
Union(f func(b UnionStatement))
13+
AllUnion(f func(b UnionStatement))
14+
OrderBy(col string) OrderBy
15+
Limit(n int64)
16+
Offset(n int64)
17+
}
18+
19+
type unionStmt struct {
20+
b buffer
21+
orderBy group
22+
limit *int64
23+
offset *int64
24+
}
25+
26+
func (st *unionStmt) Select(f func(b SelectStatement)) {
27+
var x selectStmt
28+
f(&x)
29+
30+
if st.b.empty() {
31+
st.b.push(paren(x.make()))
32+
} else {
33+
st.b.push("union", paren(x.make()))
34+
}
35+
}
36+
37+
func (st *unionStmt) AllSelect(f func(b SelectStatement)) {
38+
var x selectStmt
39+
f(&x)
40+
41+
if st.b.empty() {
42+
st.b.push(paren(x.make()))
43+
} else {
44+
st.b.push("union all", paren(x.make()))
45+
}
46+
}
47+
48+
func (st *unionStmt) Union(f func(b UnionStatement)) {
49+
var x unionStmt
50+
f(&x)
51+
52+
if st.b.empty() {
53+
st.b.push(paren(x.make()))
54+
} else {
55+
st.b.push("union", paren(x.make()))
56+
}
57+
}
58+
59+
func (st *unionStmt) AllUnion(f func(b UnionStatement)) {
60+
var x unionStmt
61+
f(&x)
62+
63+
if st.b.empty() {
64+
st.b.push(paren(x.make()))
65+
} else {
66+
st.b.push("union all", paren(x.make()))
67+
}
68+
}
69+
70+
func (st *unionStmt) OrderBy(col string) OrderBy {
71+
p := orderBy{
72+
col: col,
73+
}
74+
st.orderBy.push(&p)
75+
return &p
76+
}
77+
78+
func (st *unionStmt) Limit(n int64) {
79+
st.limit = &n
80+
}
81+
82+
func (st *unionStmt) Offset(n int64) {
83+
st.offset = &n
84+
}
85+
86+
func (st *unionStmt) make() *buffer {
87+
var b buffer
88+
b.push(&st.b)
89+
if !st.orderBy.empty() {
90+
b.push("order by", &st.orderBy)
91+
}
92+
if st.limit != nil {
93+
b.push("limit", *st.limit)
94+
}
95+
if st.offset != nil {
96+
b.push("offset", *st.offset)
97+
}
98+
return &b
99+
}

pgstmt/union_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package pgstmt_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
8+
"github.com/acoshift/pgsql/pgstmt"
9+
)
10+
11+
func TestUnion(t *testing.T) {
12+
t.Parallel()
13+
14+
cases := []struct {
15+
name string
16+
result *pgstmt.Result
17+
query string
18+
args []interface{}
19+
}{
20+
{
21+
"union select",
22+
pgstmt.Union(func(b pgstmt.UnionStatement) {
23+
b.Select(func(b pgstmt.SelectStatement) {
24+
b.Columns("id")
25+
b.From("table1")
26+
})
27+
b.AllSelect(func(b pgstmt.SelectStatement) {
28+
b.Columns("id")
29+
b.From("table2")
30+
})
31+
b.OrderBy("id")
32+
b.Limit(10)
33+
b.Offset(2)
34+
}),
35+
`
36+
(select id from table1)
37+
union all (select id from table2)
38+
order by id
39+
limit 10 offset 2
40+
`,
41+
nil,
42+
},
43+
{
44+
"union nested",
45+
pgstmt.Union(func(b pgstmt.UnionStatement) {
46+
b.Union(func(b pgstmt.UnionStatement) {
47+
b.Select(func(b pgstmt.SelectStatement) {
48+
b.Columns("id")
49+
b.From("table1")
50+
})
51+
b.Select(func(b pgstmt.SelectStatement) {
52+
b.Columns("id")
53+
b.From("table2")
54+
})
55+
})
56+
b.Select(func(b pgstmt.SelectStatement) {
57+
b.Columns("id")
58+
b.From("table3")
59+
})
60+
b.AllUnion(func(b pgstmt.UnionStatement) {
61+
b.Select(func(b pgstmt.SelectStatement) {
62+
b.Columns("id")
63+
b.From("table4")
64+
})
65+
b.Select(func(b pgstmt.SelectStatement) {
66+
b.Columns("id")
67+
b.From("table5")
68+
})
69+
})
70+
}),
71+
`
72+
(
73+
(select id from table1)
74+
union (select id from table2)
75+
)
76+
union (select id from table3)
77+
union all (
78+
(select id from table4)
79+
union
80+
(select id from table5)
81+
)
82+
`,
83+
nil,
84+
},
85+
}
86+
87+
for _, tC := range cases {
88+
t.Run(tC.name, func(t *testing.T) {
89+
q, args := tC.result.SQL()
90+
assert.Equal(t, stripSpace(tC.query), q)
91+
assert.EqualValues(t, tC.args, args)
92+
})
93+
}
94+
}

pgstmt/util_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ func stripSpace(s string) string {
1515
}
1616
s = p
1717
}
18+
s = strings.ReplaceAll(s, "( ", "(")
19+
s = strings.ReplaceAll(s, " )", ")")
1820
return s
1921
}

0 commit comments

Comments
 (0)