Skip to content

Commit 9285eac

Browse files
committed
use atomic bool for logic and/or aggregators
1 parent 3f44a7f commit 9285eac

File tree

2 files changed

+281
-34
lines changed

2 files changed

+281
-34
lines changed

internal/routing/aggregator.go

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggreg
6565
}
6666
case RespAggLogicalAnd:
6767
andAgg := &AggLogicalAndAggregator{}
68-
andAgg.res.Add(1)
68+
andAgg.res.Store(true)
6969

7070
return andAgg
7171
case RespAggLogicalOr:
@@ -466,7 +466,7 @@ func (a *AggMaxAggregator) Result() (interface{}, error) {
466466
// AggLogicalAndAggregator performs logical AND on boolean values.
467467
type AggLogicalAndAggregator struct {
468468
err atomic.Value
469-
res atomic.Int64
469+
res atomic.Bool
470470
hasResult atomic.Bool
471471
}
472472

@@ -482,21 +482,9 @@ func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error {
482482
return e
483483
}
484484

485-
// Atomic AND operation using CompareAndSwap loop (Go 1.21 compatible)
486-
// TODO: Once minimum Go version is upgraded to 1.23+, replace this with:
487-
// if val { a.res.And(1) } else { a.res.And(0) }
488-
var newVal int64
489-
if val {
490-
newVal = 1
491-
} else {
492-
newVal = 0
493-
}
494-
for {
495-
old := a.res.Load()
496-
desired := old & newVal
497-
if a.res.CompareAndSwap(old, desired) {
498-
break
499-
}
485+
// Atomic AND operation: if val is false, result is always false
486+
if !val {
487+
a.res.Store(false)
500488
}
501489

502490
a.hasResult.Store(true)
@@ -555,13 +543,13 @@ func (a *AggLogicalAndAggregator) Result() (interface{}, error) {
555543
if !a.hasResult.Load() {
556544
return nil, ErrAndAggregation
557545
}
558-
return a.res.Load() != 0, nil
546+
return a.res.Load(), nil
559547
}
560548

561549
// AggLogicalOrAggregator performs logical OR on boolean values.
562550
type AggLogicalOrAggregator struct {
563551
err atomic.Value
564-
res atomic.Int64
552+
res atomic.Bool
565553
hasResult atomic.Bool
566554
}
567555

@@ -577,21 +565,9 @@ func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error {
577565
return e
578566
}
579567

580-
// Atomic OR operation using CompareAndSwap loop (Go 1.21 compatible)
581-
// TODO: Once minimum Go version is upgraded to 1.23+, replace this with:
582-
// if val { a.res.Or(1) } else { a.res.Or(0) }
583-
var newVal int64
568+
// Atomic OR operation: if val is true, result is always true
584569
if val {
585-
newVal = 1
586-
} else {
587-
newVal = 0
588-
}
589-
for {
590-
old := a.res.Load()
591-
desired := old | newVal
592-
if a.res.CompareAndSwap(old, desired) {
593-
break
594-
}
570+
a.res.Store(true)
595571
}
596572

597573
a.hasResult.Store(true)
@@ -650,7 +626,7 @@ func (a *AggLogicalOrAggregator) Result() (interface{}, error) {
650626
if !a.hasResult.Load() {
651627
return nil, ErrOrAggregation
652628
}
653-
return a.res.Load() != 0, nil
629+
return a.res.Load(), nil
654630
}
655631

656632
func toInt64(val interface{}) (int64, error) {
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
package routing
2+
3+
import (
4+
"errors"
5+
"testing"
6+
)
7+
8+
func TestAggLogicalAndAggregator(t *testing.T) {
9+
t.Run("all true values", func(t *testing.T) {
10+
agg := NewResponseAggregator(RespAggLogicalAnd, "")
11+
12+
err := agg.Add(true, nil)
13+
if err != nil {
14+
t.Fatalf("unexpected error: %v", err)
15+
}
16+
17+
err = agg.Add(int64(1), nil)
18+
if err != nil {
19+
t.Fatalf("unexpected error: %v", err)
20+
}
21+
22+
err = agg.Add(1, nil)
23+
if err != nil {
24+
t.Fatalf("unexpected error: %v", err)
25+
}
26+
27+
result, err := agg.Result()
28+
if err != nil {
29+
t.Fatalf("unexpected error: %v", err)
30+
}
31+
32+
if result != true {
33+
t.Errorf("expected true, got %v", result)
34+
}
35+
})
36+
37+
t.Run("one false value", func(t *testing.T) {
38+
agg := NewResponseAggregator(RespAggLogicalAnd, "")
39+
40+
err := agg.Add(true, nil)
41+
if err != nil {
42+
t.Fatalf("unexpected error: %v", err)
43+
}
44+
45+
err = agg.Add(false, nil)
46+
if err != nil {
47+
t.Fatalf("unexpected error: %v", err)
48+
}
49+
50+
err = agg.Add(true, nil)
51+
if err != nil {
52+
t.Fatalf("unexpected error: %v", err)
53+
}
54+
55+
result, err := agg.Result()
56+
if err != nil {
57+
t.Fatalf("unexpected error: %v", err)
58+
}
59+
60+
if result != false {
61+
t.Errorf("expected false, got %v", result)
62+
}
63+
})
64+
65+
t.Run("no results", func(t *testing.T) {
66+
agg := NewResponseAggregator(RespAggLogicalAnd, "")
67+
68+
_, err := agg.Result()
69+
if err != ErrAndAggregation {
70+
t.Errorf("expected ErrAndAggregation, got %v", err)
71+
}
72+
})
73+
74+
t.Run("with error", func(t *testing.T) {
75+
agg := NewResponseAggregator(RespAggLogicalAnd, "")
76+
77+
testErr := errors.New("test error")
78+
err := agg.Add(nil, testErr)
79+
if err != nil {
80+
t.Fatalf("unexpected error: %v", err)
81+
}
82+
83+
_, err = agg.Result()
84+
if err != testErr {
85+
t.Errorf("expected test error, got %v", err)
86+
}
87+
})
88+
}
89+
90+
func TestAggLogicalOrAggregator(t *testing.T) {
91+
t.Run("all false values", func(t *testing.T) {
92+
agg := NewResponseAggregator(RespAggLogicalOr, "")
93+
94+
err := agg.Add(false, nil)
95+
if err != nil {
96+
t.Fatalf("unexpected error: %v", err)
97+
}
98+
99+
err = agg.Add(int64(0), nil)
100+
if err != nil {
101+
t.Fatalf("unexpected error: %v", err)
102+
}
103+
104+
err = agg.Add(0, nil)
105+
if err != nil {
106+
t.Fatalf("unexpected error: %v", err)
107+
}
108+
109+
result, err := agg.Result()
110+
if err != nil {
111+
t.Fatalf("unexpected error: %v", err)
112+
}
113+
114+
if result != false {
115+
t.Errorf("expected false, got %v", result)
116+
}
117+
})
118+
119+
t.Run("one true value", func(t *testing.T) {
120+
agg := NewResponseAggregator(RespAggLogicalOr, "")
121+
122+
err := agg.Add(false, nil)
123+
if err != nil {
124+
t.Fatalf("unexpected error: %v", err)
125+
}
126+
127+
err = agg.Add(true, nil)
128+
if err != nil {
129+
t.Fatalf("unexpected error: %v", err)
130+
}
131+
132+
err = agg.Add(false, nil)
133+
if err != nil {
134+
t.Fatalf("unexpected error: %v", err)
135+
}
136+
137+
result, err := agg.Result()
138+
if err != nil {
139+
t.Fatalf("unexpected error: %v", err)
140+
}
141+
142+
if result != true {
143+
t.Errorf("expected true, got %v", result)
144+
}
145+
})
146+
147+
t.Run("no results", func(t *testing.T) {
148+
agg := NewResponseAggregator(RespAggLogicalOr, "")
149+
150+
_, err := agg.Result()
151+
if err != ErrOrAggregation {
152+
t.Errorf("expected ErrOrAggregation, got %v", err)
153+
}
154+
})
155+
156+
t.Run("with error", func(t *testing.T) {
157+
agg := NewResponseAggregator(RespAggLogicalOr, "")
158+
159+
testErr := errors.New("test error")
160+
err := agg.Add(nil, testErr)
161+
if err != nil {
162+
t.Fatalf("unexpected error: %v", err)
163+
}
164+
165+
_, err = agg.Result()
166+
if err != testErr {
167+
t.Errorf("expected test error, got %v", err)
168+
}
169+
})
170+
}
171+
172+
func TestAggLogicalAndBatchAdd(t *testing.T) {
173+
t.Run("batch add all true", func(t *testing.T) {
174+
agg := NewResponseAggregator(RespAggLogicalAnd, "")
175+
176+
results := map[string]AggregatorResErr{
177+
"key1": {Result: true, Err: nil},
178+
"key2": {Result: int64(1), Err: nil},
179+
"key3": {Result: 1, Err: nil},
180+
}
181+
182+
err := agg.BatchAdd(results)
183+
if err != nil {
184+
t.Fatalf("unexpected error: %v", err)
185+
}
186+
187+
result, err := agg.Result()
188+
if err != nil {
189+
t.Fatalf("unexpected error: %v", err)
190+
}
191+
192+
if result != true {
193+
t.Errorf("expected true, got %v", result)
194+
}
195+
})
196+
197+
t.Run("batch add with false", func(t *testing.T) {
198+
agg := NewResponseAggregator(RespAggLogicalAnd, "")
199+
200+
results := map[string]AggregatorResErr{
201+
"key1": {Result: true, Err: nil},
202+
"key2": {Result: false, Err: nil},
203+
"key3": {Result: true, Err: nil},
204+
}
205+
206+
err := agg.BatchAdd(results)
207+
if err != nil {
208+
t.Fatalf("unexpected error: %v", err)
209+
}
210+
211+
result, err := agg.Result()
212+
if err != nil {
213+
t.Fatalf("unexpected error: %v", err)
214+
}
215+
216+
if result != false {
217+
t.Errorf("expected false, got %v", result)
218+
}
219+
})
220+
}
221+
222+
func TestAggLogicalOrBatchAdd(t *testing.T) {
223+
t.Run("batch add all false", func(t *testing.T) {
224+
agg := NewResponseAggregator(RespAggLogicalOr, "")
225+
226+
results := map[string]AggregatorResErr{
227+
"key1": {Result: false, Err: nil},
228+
"key2": {Result: int64(0), Err: nil},
229+
"key3": {Result: 0, Err: nil},
230+
}
231+
232+
err := agg.BatchAdd(results)
233+
if err != nil {
234+
t.Fatalf("unexpected error: %v", err)
235+
}
236+
237+
result, err := agg.Result()
238+
if err != nil {
239+
t.Fatalf("unexpected error: %v", err)
240+
}
241+
242+
if result != false {
243+
t.Errorf("expected false, got %v", result)
244+
}
245+
})
246+
247+
t.Run("batch add with true", func(t *testing.T) {
248+
agg := NewResponseAggregator(RespAggLogicalOr, "")
249+
250+
results := map[string]AggregatorResErr{
251+
"key1": {Result: false, Err: nil},
252+
"key2": {Result: true, Err: nil},
253+
"key3": {Result: false, Err: nil},
254+
}
255+
256+
err := agg.BatchAdd(results)
257+
if err != nil {
258+
t.Fatalf("unexpected error: %v", err)
259+
}
260+
261+
result, err := agg.Result()
262+
if err != nil {
263+
t.Fatalf("unexpected error: %v", err)
264+
}
265+
266+
if result != true {
267+
t.Errorf("expected true, got %v", result)
268+
}
269+
})
270+
}
271+

0 commit comments

Comments
 (0)