Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,16 @@ func (b *bucket) Flush(reinsert func(*Timer)) {

b.SetExpiration(-1)
}

// Clear timer list
func (b *bucket) Clear() {
b.mu.Lock()
defer b.mu.Unlock()
b.timers = b.timers.Init()
}

func (b *bucket) Len() int {
b.mu.Lock()
defer b.mu.Unlock()
return b.timers.Len()
}
17 changes: 17 additions & 0 deletions delayqueue/delayqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ func (pq *priorityQueue) PeekAndShift(max int64) (*item, int64) {
return item, 0
}

// Clear the priorityQueue
func (pq *priorityQueue) Clear() {
if pq.Len() != 0 {
*pq = nil
}
}

// The end of PriorityQueue implementation.

// DelayQueue is an unbounded blocking queue of *Delayed* elements, in which
Expand Down Expand Up @@ -184,3 +191,13 @@ exit:
// Reset the states
atomic.StoreInt32(&dq.sleeping, 0)
}

func (dq *DelayQueue) Len() int {
return dq.pq.Len()
}

func (dq *DelayQueue) Clear() {
dq.mu.Lock()
dq.pq.Clear()
dq.mu.Unlock()
}
53 changes: 53 additions & 0 deletions timingwheel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package timingwheel

import (
"errors"
"sync"
"sync/atomic"
"time"
"unsafe"
Expand All @@ -16,6 +17,7 @@ type TimingWheel struct {

interval int64 // in milliseconds
currentTime int64 // in milliseconds
mu sync.RWMutex
buckets []*bucket
queue *delayqueue.DelayQueue

Expand Down Expand Up @@ -71,7 +73,13 @@ func (tw *TimingWheel) add(t *Timer) bool {
} else if t.expiration < currentTime+tw.interval {
// Put it into its own bucket
virtualID := t.expiration / tw.tick
tw.mu.RLock()
if tw.buckets == nil {
tw.mu.RUnlock()
return false
}
b := tw.buckets[virtualID%tw.wheelSize]
tw.mu.RUnlock()
b.Add(t)

// Set the bucket expiration time
Expand Down Expand Up @@ -109,6 +117,9 @@ func (tw *TimingWheel) add(t *Timer) bool {
// addOrRun inserts the timer t into the current timing wheel, or run the
// timer's task if it has already expired.
func (tw *TimingWheel) addOrRun(t *Timer) {
if tw.IsStopped() {
return
}
if !tw.add(t) {
// Already expired

Expand Down Expand Up @@ -160,8 +171,41 @@ func (tw *TimingWheel) Start() {
// not wait for the task to complete before returning. If the caller needs to
// know whether the task is completed, it must coordinate with the task explicitly.
func (tw *TimingWheel) Stop() {
if tw.IsStopped() {
return
}
close(tw.exitC)
tw.waitGroup.Wait()
tw.clear()
}

func (tw *TimingWheel) clear() {
tw.queue.Clear()
tw.mu.Lock()
for _, b := range tw.buckets {
b.Clear()
}
tw.buckets = nil
Comment thread
JasonSongHoho marked this conversation as resolved.
tw.mu.Unlock()
// Try to clear the overflow wheel if present
overflowWheel := atomic.LoadPointer(&tw.overflowWheel)
if overflowWheel != nil {
(*TimingWheel)(overflowWheel).clear()
}
}

func (tw *TimingWheel) Len() int {
l := 0
tw.mu.Lock()
Comment thread
JasonSongHoho marked this conversation as resolved.
for i := 0; i < len(tw.buckets); i++ {
l += tw.buckets[i].Len()
}
tw.mu.Unlock()
overflowWheel := atomic.LoadPointer(&tw.overflowWheel)
if overflowWheel != nil {
l += (*TimingWheel)(overflowWheel).Len()
}
return l
}

// AfterFunc waits for the duration to elapse and then calls f in its own goroutine.
Expand Down Expand Up @@ -224,3 +268,12 @@ func (tw *TimingWheel) ScheduleFunc(s Scheduler, f func()) (t *Timer) {

return
}

func (tw *TimingWheel) IsStopped() bool {
select {
case <-tw.exitC:
return true
default:
}
return false
}
16 changes: 16 additions & 0 deletions timingwheel_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,19 @@ func BenchmarkStandardTimer_StartStop(b *testing.B) {
})
}
}

func BenchmarkTimingWheel_KeepStartStop(b *testing.B) {
var tw *timingwheel.TimingWheel
for j := 0; j < 10; j++ {
b.ResetTimer()
tw = timingwheel.NewTimingWheel(1*time.Minute, 20)
tw.Start()
l := 100
for i := 0; i < l; i++ {
tw.AfterFunc(time.Duration(i+1)*time.Minute, func() {
})
}
tw.Stop()
b.StopTimer()
}
}
98 changes: 98 additions & 0 deletions timingwheel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,101 @@ func TestTimingWheel_ScheduleFunc(t *testing.T) {
}
}
}

func TestTimingWheel_IsStopped(t *testing.T) {
tw := timingwheel.NewTimingWheel(time.Millisecond, 20)
tw.Start()
if tw.IsStopped() {
t.Errorf("IsStopped() = true before stop")
}
tw.Stop()
if !tw.IsStopped() {
t.Errorf("IsStopped() = false after stop")
}
// test stop 2 times
tw.Stop()
}

func TestTimingWheel_Len(t *testing.T) {
type fields struct {
tw *timingwheel.TimingWheel
len int
}
tests := []struct {
name string
fields fields
want int
}{
{
name: "",
fields: fields{
tw: timingwheel.NewTimingWheel(1*time.Millisecond, 20),
len: 0,
},
want: 0,
},
{
name: "",
fields: fields{
tw: timingwheel.NewTimingWheel(1*time.Second, 20),
len: 100,
},
want: 100,
},
{
name: "",
fields: fields{
tw: timingwheel.NewTimingWheel(1*time.Minute, 20),
len: 100,
},
want: 100,
},
{
name: "",
fields: fields{
tw: timingwheel.NewTimingWheel(1*time.Minute, 20),
len: 10000,
},
want: 10000,
},
{
name: "",
fields: fields{
tw: timingwheel.NewTimingWheel(1*time.Minute, 200),
len: 100,
},
want: 100,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tw := tt.fields.tw
tw.Start()
defer tw.Stop()
for i := 0; i < tt.fields.len; i++ {
tw.AfterFunc(time.Duration(i+1)*time.Minute, func() {
})
}
if got := tw.Len(); got != tt.want {
t.Errorf("Len() = %v, want %v", got, tt.want)
}
})
}
}

func TestTimingWheel_clear(t *testing.T) {
Comment thread
JasonSongHoho marked this conversation as resolved.
tw := timingwheel.NewTimingWheel(1*time.Minute, 20)
tw.Start()
l := 10000
for i := 0; i < l; i++ {
tw.AfterFunc(time.Duration(i+1)*time.Minute, func() {
})
}
if tw.Len() != l {
t.Errorf("add events fail")
}
tw.Stop()
if tw.Len() != 0 {
t.Errorf("clear events fail. tw.Len(): %d", tw.Len())
}
}