diff --git a/.gitignore b/.gitignore index 2394557..a84776a 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,10 @@ cover.out examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js + +*.pdf +*.png +*.stderr +*.jsonl +*.json +__pycache__/ \ No newline at end of file diff --git a/gcc/arrival_group_accumulator.go b/gcc/arrival_group_accumulator.go new file mode 100644 index 0000000..9cc8719 --- /dev/null +++ b/gcc/arrival_group_accumulator.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "time" +) + +type arrivalGroupItem struct { + SequenceNumber uint64 + Departure time.Time + Arrival time.Time + Size int +} + +type arrivalGroup []arrivalGroupItem + +type arrivalGroupAccumulator struct { + next arrivalGroup + burstInterval time.Duration + maxBurstDuration time.Duration +} + +func newArrivalGroupAccumulator() *arrivalGroupAccumulator { + return &arrivalGroupAccumulator{ + next: make([]arrivalGroupItem, 0), + burstInterval: 5 * time.Millisecond, + maxBurstDuration: 5 * time.Millisecond, + } +} + +func (a *arrivalGroupAccumulator) onPacketAcked( + sequenceNumber uint64, + size int, + departure, arrival time.Time, +) arrivalGroup { + if len(a.next) == 0 { + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) + + return nil + } + + sendTimeDelta := departure.Sub(a.next[0].Departure) + if sendTimeDelta < a.burstInterval { + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) + + return nil + } + + arrivalTimeDeltaFirst := arrival.Sub(a.next[0].Arrival) + propagationDelta := arrivalTimeDeltaFirst - sendTimeDelta + + if propagationDelta < 0 && arrivalTimeDeltaFirst < a.maxBurstDuration { + a.next = append(a.next, arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }) + + return nil + } + + group := make(arrivalGroup, len(a.next)) + copy(group, a.next) + a.next = arrivalGroup{arrivalGroupItem{ + SequenceNumber: sequenceNumber, + Size: size, + Departure: departure, + Arrival: arrival, + }} + + return group +} diff --git a/gcc/arrival_group_accumulator_test.go b/gcc/arrival_group_accumulator_test.go new file mode 100644 index 0000000..213856b --- /dev/null +++ b/gcc/arrival_group_accumulator_test.go @@ -0,0 +1,244 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestArrivalGroupAccumulator(t *testing.T) { + type logItem struct { + SequenceNumber uint64 + Departure time.Time + Arrival time.Time + } + triggerNewGroupElement := logItem{ + Departure: time.Time{}.Add(time.Second), + Arrival: time.Time{}.Add(time.Second), + } + cases := []struct { + name string + log []logItem + exp []arrivalGroup + }{ + { + name: "emptyCreatesNoGroups", + log: []logItem{}, + exp: []arrivalGroup{}, + }, + { + name: "createsSingleElementGroup", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(time.Millisecond), + }, + }, + }, + }, + { + name: "createsTwoElementGroup", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }}, + }, + { + name: "createsTwoArrivalGroups1", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(20 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(24 * time.Millisecond), + }, + }, + }, + }, + { + name: "ignoresOutOfOrderPackets", + log: []logItem{ + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + Departure: time.Time{}, + Arrival: time.Time{}.Add(15 * time.Millisecond), + }, + }, + { + { + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(34 * time.Millisecond), + }, + { + Departure: time.Time{}.Add(8 * time.Millisecond), + Arrival: time.Time{}.Add(30 * time.Millisecond), + }, + }, + }, + }, + { + name: "newGroupBecauseOfInterDepartureTime", + log: []logItem{ + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SequenceNumber: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(3 * time.Millisecond), + Arrival: time.Time{}.Add(4 * time.Millisecond), + }, + }, + { + { + SequenceNumber: 2, + Departure: time.Time{}.Add(6 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 3, + Departure: time.Time{}.Add(9 * time.Millisecond), + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + }, + }, + }, + { + name: "createsSingleGroupArrivalBurst", + log: []logItem{ + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(10 * time.Millisecond), + Arrival: time.Time{}.Add(12 * time.Millisecond), + }, + triggerNewGroupElement, + }, + exp: []arrivalGroup{ + { + { + SequenceNumber: 0, + Departure: time.Time{}, + Arrival: time.Time{}.Add(10 * time.Millisecond), + }, + { + SequenceNumber: 1, + Departure: time.Time{}.Add(10 * time.Millisecond), + Arrival: time.Time{}.Add(12 * time.Millisecond), + }, + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + aga := newArrivalGroupAccumulator() + received := []arrivalGroup{} + for _, ack := range tc.log { + next := aga.onPacketAcked(ack.SequenceNumber, 0, ack.Departure, ack.Arrival) + if next != nil { + received = append(received, next) + } + } + assert.Equal(t, tc.exp, received) + }) + } +} diff --git a/gcc/delay_rate_controller.go b/gcc/delay_rate_controller.go new file mode 100644 index 0000000..ebcba6b --- /dev/null +++ b/gcc/delay_rate_controller.go @@ -0,0 +1,147 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "math" + "time" + + "github.com/pion/logging" +) + +const ( + defaultDecreaseFactor = 0.85 +) + +type delayRateController struct { + log logging.LeveledLogger + decreaseFactor float64 + arrivalGroups *arrivalGroupAccumulator + lastArrivalGroup arrivalGroup + trend *trendlineEstimator + overuse *overuseDetector + samples int + usage usage + state state + lastDecreaseRate *ewma + lastUpdate time.Time + targetRate int + minTarget int + maxTarget int +} + +func newDelayRateController(initialRate, minRate, maxRate int, logger logging.LeveledLogger) *delayRateController { + return &delayRateController{ + log: logger, + decreaseFactor: defaultDecreaseFactor, + arrivalGroups: newArrivalGroupAccumulator(), + lastArrivalGroup: []arrivalGroupItem{}, + trend: newTrendlineEstimator(), + overuse: newOveruseDetector(false), + usage: 0, + samples: 0, + state: 0, + lastDecreaseRate: newEWMA(0.95), + targetRate: initialRate, + minTarget: minRate, + maxTarget: maxRate, + } +} + +func (c *delayRateController) onPacketAcked(sequenceNumber uint64, size int, departure, arrival time.Time) { + next := c.arrivalGroups.onPacketAcked( + sequenceNumber, + size, + departure, + arrival, + ) + if next == nil { + return + } + if len(next) == 0 { + // ignore empty groups, should never occur + return + } + if len(c.lastArrivalGroup) == 0 { + c.lastArrivalGroup = next + + return + } + + interArrivalTime := next[len(next)-1].Arrival.Sub(c.lastArrivalGroup[len(c.lastArrivalGroup)-1].Arrival) + interDepartureTime := next[len(next)-1].Departure.Sub(c.lastArrivalGroup[len(c.lastArrivalGroup)-1].Departure) + interGroupDelay := interArrivalTime - interDepartureTime + + trend := c.trend.update(arrival, interGroupDelay) + c.samples++ + c.usage = c.overuse.update(arrival, trend, c.samples) + c.lastArrivalGroup = next + + c.log.Tracef( + "ts=%v.%06d, seq=%v, interArrivalTime=%v, interDepartureTime=%v, interGroupDelay=%v, estimate=%f, threshold=%f, usage=%v, state=%v", // nolint + c.lastArrivalGroup[0].Departure.UTC().Format("2006/01/02 15:04:05"), + c.lastArrivalGroup[0].Departure.UTC().Nanosecond()/1e3, + next[0].SequenceNumber, + interArrivalTime.Microseconds(), + interDepartureTime.Microseconds(), + interGroupDelay.Microseconds(), + trend, + c.overuse.delayThreshold, + int(c.usage), + int(c.state), + ) +} + +func (c *delayRateController) update(ts time.Time, deliveryRate int, rtt time.Duration) int { + deliveredRate := float64(deliveryRate) + c.state = c.state.transition(c.usage) + if c.state == stateIncrease { + window := ts.Sub(c.lastUpdate) + if c.canIncreaseMultiplicatively(deliveredRate) { + c.targetRate = max(c.targetRate, multiplicativeIncrease(c.targetRate, window)) + } else { + c.targetRate = additiveIncrease(c.targetRate, rtt, window) + } + c.targetRate = min(c.targetRate, int(1.5*deliveredRate)) + } + if c.state == stateDecrease { + c.lastDecreaseRate.update(float64(deliveryRate)) + c.targetRate = int(c.decreaseFactor * float64(deliveryRate)) + } + c.lastUpdate = ts + + c.targetRate = max(c.targetRate, c.minTarget) + c.targetRate = min(c.targetRate, c.maxTarget) + + return c.targetRate +} + +func (c *delayRateController) canIncreaseMultiplicatively(deliveredRate float64) bool { + avg := c.lastDecreaseRate.avg() + if avg == 0 { + return true + } + stdDev := math.Sqrt(c.lastDecreaseRate.varr()) + lower := avg - 3*stdDev + upper := avg + 3*stdDev + + return deliveredRate < lower || deliveredRate > upper +} + +func multiplicativeIncrease(rate int, window time.Duration) int { + exponent := min(window.Seconds(), 1.0) + eta := math.Pow(1.08, exponent) + + return int(eta * float64(rate)) +} + +func additiveIncrease(rate int, rtt, window time.Duration) int { + responseTime := 100 + rtt.Milliseconds() + alpha := 0.5 * min(float64(window.Milliseconds())/float64(responseTime), 1.0) + bitsPerFrame := float64(rate) / 30.0 + packetsPerFrame := math.Ceil(bitsPerFrame / (1200 * 8)) + expectedPacketSizeBits := bitsPerFrame / packetsPerFrame + + return rate + max(1000, int(alpha*float64(expectedPacketSizeBits))) +} diff --git a/gcc/delay_rate_controller_test.go b/gcc/delay_rate_controller_test.go new file mode 100644 index 0000000..74ad75e --- /dev/null +++ b/gcc/delay_rate_controller_test.go @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDelayRateController(t *testing.T) { + t.Run("init", func(t *testing.T) { + controller := newDelayRateController(1_000_000, 500_000, 2_000_000, nil) + assert.Nil(t, controller.log) + assert.Equal(t, controller.decreaseFactor, defaultDecreaseFactor) + assert.NotNil(t, controller.arrivalGroups) + assert.NotNil(t, controller.lastArrivalGroup) + assert.NotNil(t, controller.trend) + assert.NotNil(t, controller.overuse) + assert.Equal(t, controller.samples, 0) + assert.Equal(t, controller.usage, usage(0)) + assert.Equal(t, controller.state, state(0)) + assert.NotNil(t, controller.lastDecreaseRate) + assert.Zero(t, controller.lastUpdate) + assert.Equal(t, controller.minTarget, 500_000) + assert.Equal(t, controller.maxTarget, 2_000_000) + assert.Equal(t, controller.targetRate, 1_000_000) + }) + + t.Run("canIncreaseMultiplicatively", func(t *testing.T) { + cases := []struct { + deliveredRate float64 + decreaseRate ewma + expected bool + }{ + {deliveredRate: 1000, decreaseRate: ewma{average: 0, variance: 0}, expected: true}, + {deliveredRate: 1000, decreaseRate: ewma{average: 1500, variance: 100}, expected: true}, + {deliveredRate: 1000, decreaseRate: ewma{average: 1020, variance: 100}, expected: false}, + {deliveredRate: 1000, decreaseRate: ewma{average: 800, variance: 50}, expected: true}, + {deliveredRate: 1000, decreaseRate: ewma{average: 995, variance: 100}, expected: false}, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + controller := newDelayRateController(1000, 500, 2000, nil) + controller.lastDecreaseRate = &c.decreaseRate + assert.Equal(t, c.expected, controller.canIncreaseMultiplicatively(c.deliveredRate)) + }) + } + }) + + t.Run("multiplicativeIncrease", func(t *testing.T) { + cases := []struct { + initialRate int + rate int + window time.Duration + expected float64 + }{ + {initialRate: 1000, rate: 1000, window: 100 * time.Millisecond, expected: 1007}, + } + for i, c := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := multiplicativeIncrease(c.rate, c.window) + assert.InDelta(t, res, c.expected, 1) + }) + } + }) + + t.Run("additiveIncrease", func(t *testing.T) { + cases := []struct { + initialRate int + rate int + window time.Duration + expected int + }{ + {initialRate: 1000, rate: 1000, window: 100 * time.Millisecond, expected: 2000}, + {initialRate: 1_000_000, rate: 1_500_000, window: 100 * time.Millisecond, expected: 1_500_000 + 2083}, + } + for i, c := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + res := additiveIncrease(c.rate, 100*time.Millisecond, c.window) + assert.InDelta(t, res, c.expected, 1) + }) + } + }) +} diff --git a/gcc/overuse_detector.go b/gcc/overuse_detector.go new file mode 100644 index 0000000..53b60a9 --- /dev/null +++ b/gcc/overuse_detector.go @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "math" + "time" +) + +const ( + kUp = 0.0087 + kDown = 0.039 + + minNumDeltas = 60 +) + +const ( + defaultThresholdGain = 4.0 + defaultOveruseTimeThreshold = 5 * time.Millisecond +) + +type overuseDetector struct { + adaptiveThreshold bool + thresholdGain float64 + overUseTimeThreshold time.Duration + delayThreshold float64 + lastUpdate time.Time + firstOverUse time.Time + overUseCounter int + previousTrend float64 +} + +func newOveruseDetector(adaptive bool) *overuseDetector { + return &overuseDetector{ + adaptiveThreshold: adaptive, + thresholdGain: defaultThresholdGain, + overUseTimeThreshold: defaultOveruseTimeThreshold, + delayThreshold: 6, + lastUpdate: time.Time{}, + firstOverUse: time.Time{}, + overUseCounter: 0, + previousTrend: 0, + } +} + +func (d *overuseDetector) update(ts time.Time, trend float64, numDeltas int) usage { + if d.lastUpdate.IsZero() { + d.lastUpdate = ts + } + if numDeltas < 2 { + return usageNormal + } + modifiedTrend := float64(min(numDeltas, minNumDeltas)) * trend * d.thresholdGain + + var currentUsage usage + switch { + case modifiedTrend > d.delayThreshold: + if d.firstOverUse.IsZero() { + delta := ts.Sub(d.lastUpdate) + d.firstOverUse = ts.Add(-delta / 2) + } + d.overUseCounter++ + if ts.Sub(d.firstOverUse) > d.overUseTimeThreshold && d.overUseCounter > 1 && trend >= d.previousTrend { + d.firstOverUse = time.Time{} + d.overUseCounter = 0 + currentUsage = usageOver + } + case modifiedTrend < -d.delayThreshold: + d.firstOverUse = time.Time{} + d.overUseCounter = 0 + currentUsage = usageUnder + default: + d.firstOverUse = time.Time{} + d.overUseCounter = 0 + currentUsage = usageNormal + } + d.adaptThreshold(ts, modifiedTrend) + d.previousTrend = trend + d.lastUpdate = ts + + return currentUsage +} + +func (d *overuseDetector) adaptThreshold(ts time.Time, modifiedTrend float64) { + if !d.adaptiveThreshold { + return + } + if math.Abs(modifiedTrend) > d.delayThreshold+15 { + return + } + k := kUp + if math.Abs(modifiedTrend) < d.delayThreshold { + k = kDown + } + delta := min(ts.Sub(d.lastUpdate), 100*time.Millisecond) + d.delayThreshold += k * (math.Abs(modifiedTrend) - d.delayThreshold) * float64(delta.Milliseconds()) + d.delayThreshold = min(d.delayThreshold, 600.0) + d.delayThreshold = max(d.delayThreshold, 6.0) +} diff --git a/gcc/overuse_detector_test.go b/gcc/overuse_detector_test.go new file mode 100644 index 0000000..58d4555 --- /dev/null +++ b/gcc/overuse_detector_test.go @@ -0,0 +1,194 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestOveruseDetectorUpdate(t *testing.T) { + type estimate struct { + ts time.Time + estimate float64 + numDeltas int + } + cases := []struct { + name string + adaptive bool + values []estimate + expected []usage + }{ + { + name: "noEstimateNoUsageStatic", + adaptive: false, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseStatic", + adaptive: false, + values: []estimate{ + {time.Time{}, 1.0, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseStatic", + adaptive: false, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseStatic", + adaptive: false, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedStatic", + adaptive: false, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + { + name: "noEstimateNoUsageAdaptive", + adaptive: true, + values: []estimate{}, + expected: []usage{}, + }, + { + name: "overuseAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}, 1, 1}, + {time.Time{}.Add(5 * time.Millisecond), 20, 2}, + {time.Time{}.Add(20 * time.Millisecond), 30, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "normaluseAdaptive", + adaptive: true, + values: []estimate{{estimate: 0}}, + expected: []usage{usageNormal}, + }, + { + name: "underuseAdaptive", + adaptive: true, + values: []estimate{{time.Time{}, -20, 2}}, + expected: []usage{usageUnder}, + }, + { + name: "noOverUseBeforeDelayAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(2 * time.Millisecond), 30, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + }, + expected: []usage{usageNormal, usageNormal, usageOver}, + }, + { + name: "noOverUseIfEstimateDecreasedAdaptive", + adaptive: true, + values: []estimate{ + {time.Time{}.Add(time.Millisecond), 20, 1}, + {time.Time{}.Add(10 * time.Millisecond), 40, 2}, + {time.Time{}.Add(30 * time.Millisecond), 50, 3}, + {time.Time{}.Add(35 * time.Millisecond), 3, 4}, + }, + expected: []usage{usageNormal, usageNormal, usageOver, usageNormal}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + od := newOveruseDetector(tc.adaptive) + received := []usage{} + for _, e := range tc.values { + u := od.update(e.ts, e.estimate, e.numDeltas) + received = append(received, u) + } + assert.Equal(t, tc.expected, received) + }) + } +} + +func TestOveruseDetectorAdaptThreshold(t *testing.T) { + cases := []struct { + name string + od *overuseDetector + ts time.Time + estimate float64 + expectedThreshold float64 + }{ + { + name: "minThreshold", + od: &overuseDetector{ + adaptiveThreshold: true, + }, + ts: time.Time{}, + estimate: 0, + expectedThreshold: 6, + }, + { + name: "increase", + od: &overuseDetector{ + adaptiveThreshold: true, + delayThreshold: 12.5, + lastUpdate: time.Time{}.Add(time.Second), + }, + ts: time.Time{}.Add(2 * time.Second), + estimate: 25, + expectedThreshold: 23.375, + }, + { + name: "maxThreshold", + od: &overuseDetector{ + adaptiveThreshold: true, + delayThreshold: 600, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(time.Second), + estimate: 610, + expectedThreshold: 600, + }, + { + name: "decrease", + od: &overuseDetector{ + adaptiveThreshold: true, + delayThreshold: 12.5, + lastUpdate: time.Time{}, + }, + ts: time.Time{}.Add(10 * time.Millisecond), + estimate: 1, + expectedThreshold: 8.015, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tc.od.adaptThreshold(tc.ts, tc.estimate) + assert.Equal(t, tc.expectedThreshold, tc.od.delayThreshold) + }) + } +} diff --git a/gcc/send_side_bwe.go b/gcc/send_side_bwe.go new file mode 100644 index 0000000..06b0bef --- /dev/null +++ b/gcc/send_side_bwe.go @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +package gcc + +import ( + "time" + + "github.com/pion/logging" +) + +// Option is a functional option for a SendSideController. +type Option func(*SendSideController) error + +// WithLoggerFactory configures a custom logger factory for a +// SendSideController. +func WithLoggerFactory(lf logging.LoggerFactory) Option { + return func(ssc *SendSideController) error { + ssc.logFactory = lf + + return nil + } +} + +// SendSideController is a sender side congestion controller. +type SendSideController struct { + logFactory logging.LoggerFactory + log logging.LeveledLogger + dre *deliveryRateEstimator + lrc *lossRateController + drc *delayRateController + targetRate int +} + +// NewSendSideController creates a new SendSideController with initial, min and +// max rates. +func NewSendSideController(initialRate, minRate, maxRate int, opts ...Option) (*SendSideController, error) { + ssc := &SendSideController{ + logFactory: logging.NewDefaultLoggerFactory(), + dre: newDeliveryRateEstimator(time.Second), + lrc: newLossRateController(initialRate, minRate, maxRate), + targetRate: initialRate, + } + for _, opt := range opts { + if err := opt(ssc); err != nil { + return nil, err + } + } + ssc.log = ssc.logFactory.NewLogger("bwe_send_side_controller") + ssc.drc = newDelayRateController(initialRate, minRate, maxRate, ssc.logFactory.NewLogger("bwe_delay_rate_controller")) + + return ssc, nil +} + +func (c *SendSideController) OnLoss() { + c.lrc.onPacketLost() +} + +// OnAck must be called when new acknowledgments arrive. Packets MUST not be +// acknowledged more than once. +func (c *SendSideController) OnAck(sequenceNumber uint64, size int, departure, arrival time.Time) { + c.lrc.onPacketAcked() + if !arrival.IsZero() { + c.dre.onPacketAcked(arrival, size) + c.drc.onPacketAcked( + sequenceNumber, + size, + departure, + arrival, + ) + } +} + +// OnFeedback must be called when a new feedback report arrives. ts is the +// arrival timestamp of the feedback report. rtt is the latest RTT sample. It +// returns the new target rate. +func (c *SendSideController) OnFeedback(ts time.Time, rtt time.Duration) int { + delivered := c.dre.getRate() + lossTarget := c.lrc.update(delivered) + delayTarget := c.drc.update(ts, delivered, rtt) + c.targetRate = min(lossTarget, delayTarget) + c.log.Tracef( + "rtt=%v, delivered=%v, lossTarget=%v, delayTarget=%v, target=%v", + rtt.Nanoseconds(), + delivered, + lossTarget, + delayTarget, + c.targetRate, + ) + + return c.targetRate +} diff --git a/go.mod b/go.mod index 2cc1f58..082d05b 100644 --- a/go.mod +++ b/go.mod @@ -7,28 +7,29 @@ require ( github.com/pion/logging v0.2.4 github.com/pion/rtcp v1.2.16 github.com/pion/rtp v1.10.2 - github.com/pion/transport/v3 v3.1.1 - github.com/pion/webrtc/v4 v4.1.4 + github.com/pion/sdp/v3 v3.0.18 + github.com/pion/transport/v4 v4.0.1 + github.com/pion/webrtc/v4 v4.2.11 github.com/stretchr/testify v1.11.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/pion/datachannel v1.5.10 // indirect - github.com/pion/dtls/v3 v3.0.7 // indirect - github.com/pion/ice/v4 v4.0.10 // indirect - github.com/pion/mdns/v2 v2.0.7 // indirect + github.com/pion/datachannel v1.6.0 // indirect + github.com/pion/dtls/v3 v3.1.2 // indirect + github.com/pion/ice/v4 v4.2.2 // indirect + github.com/pion/mdns/v2 v2.1.0 // indirect github.com/pion/randutil v0.1.0 // indirect - github.com/pion/sctp v1.8.39 // indirect - github.com/pion/sdp/v3 v3.0.15 // indirect - github.com/pion/srtp/v3 v3.0.7 // indirect - github.com/pion/stun/v3 v3.0.0 // indirect - github.com/pion/turn/v4 v4.1.1 // indirect + github.com/pion/sctp v1.9.4 // indirect + github.com/pion/srtp/v3 v3.0.10 // indirect + github.com/pion/stun/v3 v3.1.1 // indirect + github.com/pion/turn/v4 v4.1.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/wlynxg/anet v0.0.5 // indirect - golang.org/x/crypto v0.33.0 // indirect - golang.org/x/net v0.35.0 // indirect - golang.org/x/sys v0.30.0 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/net v0.50.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/time v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 062021e..cf5dc44 100644 --- a/go.sum +++ b/go.sum @@ -6,50 +6,54 @@ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o= -github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M= -github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= -github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= -github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4= -github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= +github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0= +github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk= +github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc= +github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo= +github.com/pion/ice/v4 v4.2.2 h1:dQJzzcgTFHDYyV3BoCfjPeX+JEtr58BWPi4PGyo6Vjg= +github.com/pion/ice/v4 v4.2.2/go.mod h1:2quLV1S5v1tAx3VvAJaH//KGitRXvo4RKlX6D3tnN+c= github.com/pion/interceptor v0.1.45 h1:6PUo/5829bIfRFIPPJQzuDn8EjxRTSB/CSD7QVCOaqo= github.com/pion/interceptor v0.1.45/go.mod h1:gNDYM/uFKcLe/B3gS2/7+aw6z+RDiMy2qKTnF1LO31w= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= -github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= -github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= +github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= +github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo= github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo= github.com/pion/rtp v1.10.2 h1:l+f6tTDcAH6xwepaAoW791ddhuYsJlqRATOzirO04Mo= github.com/pion/rtp v1.10.2/go.mod h1:Au8fc6cEByy8RLTwKTQTEeQqDB/SJDxwL4mZuxYA5Pk= -github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE= -github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= -github.com/pion/sdp/v3 v3.0.15 h1:F0I1zds+K/+37ZrzdADmx2Q44OFDOPRLhPnNTaUX9hk= -github.com/pion/sdp/v3 v3.0.15/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= -github.com/pion/srtp/v3 v3.0.7 h1:QUElw0A/FUg3MP8/KNMZB3i0m8F9XeMnTum86F7S4bs= -github.com/pion/srtp/v3 v3.0.7/go.mod h1:qvnHeqbhT7kDdB+OGB05KA/P067G3mm7XBfLaLiaNF0= -github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= -github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= +github.com/pion/sctp v1.9.4 h1:cMxEu0F5tbP4qH07bKf1Zjf4rUih9LIo0qQt424e258= +github.com/pion/sctp v1.9.4/go.mod h1:N20Dq6LY+JvJDAh9VVh1JELngb2rQ8dPgds5yBWiPgw= +github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI= +github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8= +github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ= +github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M= +github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw= +github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM= github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= -github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= -github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= -github.com/pion/webrtc/v4 v4.1.4 h1:/gK1ACGHXQmtyVVbJFQDxNoODg4eSRiFLB7t9r9pg8M= -github.com/pion/webrtc/v4 v4.1.4/go.mod h1:Oab9npu1iZtQRMic3K3toYq5zFPvToe/QBw7dMI2ok4= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= +github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ= +github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ= +github.com/pion/webrtc/v4 v4.2.11 h1:QUX1QZKlNIn4O7U5JxLPGP0sV5RTncZkzu9SPR3jVNU= +github.com/pion/webrtc/v4 v4.2.11/go.mod h1:s/rAiyy77GyRFrZMx+Ls6aua26dIBPudH8/ZHYbIRWY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= -golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/simulation/bwe_test.go b/simulation/bwe_test.go new file mode 100644 index 0000000..3b2af4c --- /dev/null +++ b/simulation/bwe_test.go @@ -0,0 +1,385 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js && go1.25 && simulation + +package simulation + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log" + "log/slog" + "math" + "os" + "path/filepath" + "strings" + "testing" + "testing/synctest" + "time" + + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/assert" +) + +var logDir string + +type vnetFactory func(*testing.T) *virtualNetwork + +func TestMain(m *testing.M) { + logDir = os.Getenv("BWE_LOG_DIR") + if logDir == "" { + logDir = "logs/" + } + if err := os.MkdirAll(logDir, 0o750); err != nil { + log.Printf("failed to create log dir %q: %v", logDir, err) + os.Exit(1) + } + if err := os.Setenv("PION_LOG_TRACE", "bwe_send_side_controller,bwe_delay_rate_controller,bwe_test_peer,perfect_codec,ccfb_interceptor,bwe_arrival_group_accumulator"); err != nil { + log.Printf("failed to set pion logger environment variable") + os.Exit(1) + } + + ec := m.Run() + + files, err := filepath.Glob(filepath.Join(logDir, "*.jsonl")) + if err != nil { + log.Printf("Failed to list JSONL files: %v", err) + } + + var names []string + for _, f := range files { + names = append(names, filepath.Base(f)) + } + + b, err := json.Marshal(names) + if err != nil { + log.Printf("Failed to marshal index.json: %v", err) + os.Exit(ec) + } + + indexPath := filepath.Join(logDir, "index.json") + if err := os.WriteFile(indexPath, b, 0600); err != nil { + log.Printf("Failed to write index.json: %v", err) + } else { + log.Printf("Generated index.json with %d files", len(names)) + } + + os.Exit(ec) +} + +func TestBWE(t *testing.T) { + networks := map[string]vnetFactory{ + "1mbps-1ms": createVirtualNetwork(1_000_000, 80_000, 1*time.Millisecond), + "5mbps-1ms": createVirtualNetwork(5_000_000, 80_000, 1*time.Millisecond), + "20mbps-1ms": createVirtualNetwork(20_000_000, 80_000, 1*time.Millisecond), + + "1mbps-10ms": createVirtualNetwork(1_000_000, 80_000, 10*time.Millisecond), + "5mbps-10ms": createVirtualNetwork(5_000_000, 80_000, 10*time.Millisecond), + "20mbps-10ms": createVirtualNetwork(20_000_000, 80_000, 10*time.Millisecond), + + "1mbps-50ms": createVirtualNetwork(1_000_000, 80_000, 50*time.Millisecond), + "5mbps-50ms": createVirtualNetwork(5_000_000, 80_000, 50*time.Millisecond), + "20mbps-50ms": createVirtualNetwork(20_000_000, 80_000, 50*time.Millisecond), + + "1mbps-150ms": createVirtualNetwork(1_000_000, 80_000, 150*time.Millisecond), + "5mbps-150ms": createVirtualNetwork(5_000_000, 80_000, 150*time.Millisecond), + "20mbps-150ms": createVirtualNetwork(20_000_000, 80_000, 150*time.Millisecond), + + "1mbps-300ms": createVirtualNetwork(1_000_000, 80_000, 300*time.Millisecond), + "5mbps-300ms": createVirtualNetwork(5_000_000, 80_000, 300*time.Millisecond), + "20mbps-300ms": createVirtualNetwork(20_000_000, 80_000, 300*time.Millisecond), + } + peerOptions := map[string]struct { + receiver []option + sender []option + codecMinRate int + codecMaxRate int + }{ + "gcc-ccfb": { + receiver: []option{ + registerCCFB(), + }, + sender: []option{ + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: math.MaxInt, + }, + "gcc-twcc": { + receiver: []option{ + registerTWCC(), + }, + sender: []option{ + registerTWCCHeaderExtension(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: math.MaxInt, + }, + "gcc-ccfb-applimited500": { + receiver: []option{ + registerCCFB(), + }, + sender: []option{ + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 500_000, + }, + "gcc-twcc-applimited500": { + receiver: []option{ + registerTWCC(), + }, + sender: []option{ + registerTWCCHeaderExtension(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 500_000, + }, + "gcc-ccfb-applimited1500": { + receiver: []option{ + registerCCFB(), + }, + sender: []option{ + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 1_500_000, + }, + "gcc-twcc-applimited1500": { + receiver: []option{ + registerTWCC(), + }, + sender: []option{ + registerTWCCHeaderExtension(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 1_500_000, + }, + "gcc-ccfb-paced": { + receiver: []option{ + registerCCFB(), + }, + sender: []option{ + registerPacer(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: math.MaxInt, + }, + "gcc-twcc-paced": { + receiver: []option{ + registerTWCC(), + }, + sender: []option{ + registerPacer(), + registerTWCCHeaderExtension(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: math.MaxInt, + }, + "gcc-ccfb-applimited500-paced": { + receiver: []option{ + registerCCFB(), + }, + sender: []option{ + registerPacer(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 500_000, + }, + "gcc-twcc-applimited500-paced": { + receiver: []option{ + registerTWCC(), + }, + sender: []option{ + registerPacer(), + registerTWCCHeaderExtension(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 500_000, + }, + "gcc-ccfb-applimited1500-paced": { + receiver: []option{ + registerCCFB(), + }, + sender: []option{ + registerPacer(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 1_500_000, + }, + "gcc-twcc-applimited1500-paced": { + receiver: []option{ + registerTWCC(), + }, + sender: []option{ + registerPacer(), + registerTWCCHeaderExtension(), + initGCC(), + }, + codecMinRate: 0, + codecMaxRate: 1_500_000, + }, + } + for netName, vnf := range networks { + for peerName, pos := range peerOptions { + t.Run(fmt.Sprintf("%v-%v", netName, peerName), func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + t.Helper() + + logger, cleanup := testLogger(t) + defer cleanup() + + onTrack := make(chan struct{}) + connected := make(chan struct{}) + done := make(chan struct{}) + + network := vnf(t) + + receiverOptions := []option{ + registerDefaultCodecs(), + setVNet(network.left, []string{"10.0.1.1"}), + onRemoteTrack(func(track *webrtc.TrackRemote) { + close(onTrack) + go func() { + buf := make([]byte, 1500) + for { + select { + case <-done: + return + default: + _, _, err := track.Read(buf) + if errors.Is(err, io.EOF) { + return + } + assert.NoError(t, err) + } + } + }() + }), + registerPacketLogger(logger.With("vantage-point", "receiver")), + } + receiverOptions = append(receiverOptions, pos.receiver...) + receiver, err := newPeer(receiverOptions...) + assert.NoError(t, err) + + err = receiver.addRemoteTrack() + assert.NoError(t, err) + + var codec *perfectCodec + senderOptions := []option{ + registerDefaultCodecs(), + onConnected(func() { close(connected) }), + setVNet(network.right, []string{"10.0.2.1"}), + registerPacketLogger(logger.With("vantage-point", "sender")), + registerRTPFB(), + setOnRateCallback(func(rate int) { + logger.Info("setting codec target bitrate", "rate", rate) + codec.setTargetBitrate(int(0.9 * float64(rate))) + }), + } + senderOptions = append(senderOptions, pos.sender...) + sender, err := newPeer(senderOptions...) + assert.NoError(t, err) + + track, err := sender.addLocalTrack() + assert.NoError(t, err) + + codec = newPerfectCodec( + track, + pos.codecMinRate, + pos.codecMaxRate, + 1_000_000, + ) + go func() { + <-connected + codec.start() + }() + + offer, err := sender.createOffer() + assert.NoError(t, err) + + err = receiver.setRemoteDescription(offer) + assert.NoError(t, err) + + answer, err := receiver.createAnswer() + assert.NoError(t, err) + + err = sender.setRemoteDescription(answer) + assert.NoError(t, err) + + synctest.Wait() + + select { + case <-onTrack: + case <-time.After(5 * time.Second): + assert.Fail(t, "on track not called") + } + + time.Sleep(100 * time.Second) + close(done) + + err = codec.Close() + assert.NoError(t, err) + + err = sender.pc.Close() + assert.NoError(t, err) + + err = receiver.pc.Close() + assert.NoError(t, err) + + err = network.Close() + assert.NoError(t, err) + + synctest.Wait() + }) + }) + } + } +} + +func testLogger(t *testing.T) (*slog.Logger, func()) { + t.Helper() + name, ok := strings.CutPrefix(t.Name(), "TestBWE/") + if !ok { + assert.FailNow(t, "test case with invalid name tried to create logfile") + } + name = strings.ReplaceAll(name, "/", "-") + filename := filepath.Join(logDir, fmt.Sprintf("%s.jsonl", name)) + file, err := os.Create(filename) + if err != nil { + assert.Failf(t, "failed to create log file %q: %v", filename, err) + } + + handler := slog.NewJSONHandler(file, &slog.HandlerOptions{Level: slog.LevelInfo}) + logger := slog.New(handler) + + // Also create a log file for stdout redirects to capture Pions builtin logs + stderrFileName := filepath.Join(logDir, fmt.Sprintf("%s.stderr", name)) + stderrFile, err := os.Create(stderrFileName) + if err != nil { + assert.Failf(t, "failed to create stdout file %q: %v", filename, err) + } + old := os.Stderr + os.Stderr = stderrFile + + cleanup := func() { + os.Stderr = old + assert.NoError(t, file.Sync()) + assert.NoError(t, file.Close()) + assert.NoError(t, stderrFile.Sync()) + assert.NoError(t, stderrFile.Close()) + } + + return logger, cleanup +} diff --git a/simulation/log_format_test.go b/simulation/log_format_test.go index d0dea10..e3bc376 100644 --- a/simulation/log_format_test.go +++ b/simulation/log_format_test.go @@ -1,35 +1,47 @@ // SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT -//go:build !js +//go:build !js && go1.25 && simulation package simulation import ( "fmt" "log/slog" - "time" "github.com/pion/interceptor" "github.com/pion/rtcp" "github.com/pion/rtp" ) +const ( + maxSequenceNumberPlusOne = int64(65536) + breakpoint = 32768 // half of max uint16 +) + type packetLogger struct { - vantagePoint string - direction string + logger *slog.Logger + direction string + seq *unwrapper +} + +func newPacketLogger(logger *slog.Logger, direction string) *packetLogger { + return &packetLogger{ + logger: logger, + direction: direction, + seq: &unwrapper{}, + } } func (l *packetLogger) LogRTPPacket(header *rtp.Header, payload []byte, attributes interceptor.Attributes) { - ts := time.Now() - slog.Info( + u := l.seq.Unwrap(header.SequenceNumber) + l.logger.Info( "rtp", - "vantage-point", l.vantagePoint, "direction", l.direction, - "ts", ts, "pt", header.PayloadType, "ssrc", header.SSRC, "sequence-number", header.SequenceNumber, + "unwrapped-sequence-number", u, "rtp-timestamp", header.Timestamp, "marker", header.Marker, "payload-size", len(payload), @@ -38,6 +50,48 @@ func (l *packetLogger) LogRTPPacket(header *rtp.Header, payload []byte, attribut func (l *packetLogger) LogRTCPPackets(pkts []rtcp.Packet, attributes interceptor.Attributes) { for _, pkt := range pkts { - slog.Info("rtcp", "vantage-point", l.vantagePoint, "direction", l.direction, "type", fmt.Sprintf("%T", pkt)) + l.logger.Info( + "rtcp", + "direction", l.direction, + "type", fmt.Sprintf("%T", pkt), + ) + } +} + +// Unwrapper stores an unwrapped sequence number. +type unwrapper struct { + init bool + lastUnwrapped int64 +} + +func isNewer(value, previous uint16) bool { + if value-previous == breakpoint { + return value > previous } + + return value != previous && (value-previous) < breakpoint +} + +// Unwrap unwraps the next sequencenumber. +func (u *unwrapper) Unwrap(i uint16) int64 { + if !u.init { + u.init = true + u.lastUnwrapped = int64(i) + + return u.lastUnwrapped + } + + lastWrapped := uint16(u.lastUnwrapped) //nolint:gosec // G115 + delta := int64(i - lastWrapped) + if isNewer(i, lastWrapped) { + if delta < 0 { + delta += maxSequenceNumberPlusOne + } + } else if delta > 0 && u.lastUnwrapped+delta-maxSequenceNumberPlusOne >= 0 { + delta -= maxSequenceNumberPlusOne + } + + u.lastUnwrapped += delta + + return u.lastUnwrapped } diff --git a/simulation/main.py b/simulation/main.py new file mode 100755 index 0000000..e02e44c --- /dev/null +++ b/simulation/main.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + + +import argparse +import plots + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('-i', '--input', + help='input directory containing json files', + default='logs') + parser.add_argument('-o', '--output', + help='output directory for generated plot files (png)', + default='logs') + args = parser.parse_args() + plots.plot_all(args.input, args.output) + + +if __name__ == "__main__": + main() diff --git a/simulation/peer_test.go b/simulation/peer_test.go index d979b83..e1604f8 100644 --- a/simulation/peer_test.go +++ b/simulation/peer_test.go @@ -1,20 +1,31 @@ // SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT -//go:build !js +//go:build !js && go1.25 && simulation package simulation import ( + "log/slog" + "time" + + "github.com/pion/bwe/gcc" "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/pacing" "github.com/pion/interceptor/pkg/packetdump" "github.com/pion/interceptor/pkg/rfc8888" "github.com/pion/interceptor/pkg/rtpfb" + "github.com/pion/interceptor/pkg/twcc" "github.com/pion/logging" - "github.com/pion/transport/v3/vnet" + "github.com/pion/sdp/v3" + "github.com/pion/transport/v4/vnet" "github.com/pion/webrtc/v4" ) +const ( + feedbackInterval = 20 * time.Millisecond +) + type option func(*peer) error func setVNet(vnet *vnet.Net, publicIPs []string) option { @@ -48,14 +59,14 @@ func registerDefaultCodecs() option { } } -func registerPacketLogger(vantagePoint string) option { +func registerPacketLogger(logger *slog.Logger) option { return func(p *peer) error { - ipl := &packetLogger{vantagePoint: vantagePoint, direction: "in"} + ipl := newPacketLogger(logger, "in") rd, err := packetdump.NewReceiverInterceptor(packetdump.PacketLog(ipl)) if err != nil { return err } - opl := &packetLogger{vantagePoint: vantagePoint, direction: "out"} + opl := newPacketLogger(logger, "out") sd, err := packetdump.NewSenderInterceptor(packetdump.PacketLog(opl)) if err != nil { return err @@ -79,33 +90,41 @@ func registerRTPFB() option { } } -// func registerTWCC() option { -// return func(p *peer) error { -// twcc, err := twcc.NewSenderInterceptor() -// if err != nil { -// return err -// } -// p.interceptorRegistry.Add(twcc) -// -// return nil -// } -// } -// -// func registerTWCCHeaderExtension() option { -// return func(p *peer) error { -// twccHdrExt, err := twcc.NewHeaderExtensionInterceptor() -// if err != nil { -// return err -// } -// p.interceptorRegistry.Add(twccHdrExt) -// -// return nil -// } -// } +func registerTWCC() option { + return func(p *peer) error { + p.mediaEngine.RegisterFeedback(webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}, webrtc.RTPCodecTypeVideo) + if err := p.mediaEngine.RegisterHeaderExtension( + webrtc.RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, webrtc.RTPCodecTypeVideo, + ); err != nil { + return err + } + + p.mediaEngine.RegisterFeedback(webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}, webrtc.RTPCodecTypeAudio) + if err := p.mediaEngine.RegisterHeaderExtension( + webrtc.RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}, webrtc.RTPCodecTypeAudio, + ); err != nil { + return err + } + + generator, err := twcc.NewSenderInterceptor(twcc.SendInterval(feedbackInterval)) + if err != nil { + return err + } + + p.interceptorRegistry.Add(generator) + return nil + } +} + +func registerTWCCHeaderExtension() option { + return func(p *peer) error { + return webrtc.ConfigureTWCCHeaderExtensionSender(p.mediaEngine, p.interceptorRegistry) + } +} func registerCCFB() option { return func(p *peer) error { - ccfb, err := rfc8888.NewSenderInterceptor() + ccfb, err := rfc8888.NewSenderInterceptor(rfc8888.SendInterval(feedbackInterval)) if err != nil { return err } @@ -115,6 +134,34 @@ func registerCCFB() option { } } +func initGCC() option { + return func(p *peer) (err error) { + p.estimator, err = gcc.NewSendSideController(500_000, 128_000, 50_000_000) + if err != nil { + return err + } + + return nil + } +} + +func setOnRateCallback(onRateUpdate func(int)) option { + return func(p *peer) error { + p.onRateUpdate = onRateUpdate + + return nil + } +} + +func registerPacer() option { + return func(p *peer) error { + p.pacer = pacing.NewInterceptor() + p.interceptorRegistry.Add(p.pacer) + + return nil + } +} + type peer struct { logger logging.LeveledLogger pc *webrtc.PeerConnection @@ -125,6 +172,10 @@ type peer struct { onRemoteTrack func(*webrtc.TrackRemote) onConnected func() + + pacer *pacing.InterceptorFactory + estimator *gcc.SendSideController + onRateUpdate func(int) } func newPeer(opts ...option) (*peer, error) { @@ -274,9 +325,38 @@ func (p *peer) addRemoteTrack() error { func (p *peer) readRTCP(r *webrtc.RTPSender) { for { - _, _, err := r.ReadRTCP() + _, attr, err := r.ReadRTCP() if err != nil { return } + report, ok := attr.Get(rtpfb.CCFBAttributesKey).(rtpfb.Report) + if ok { + p.updateTargetRate(report) + } + } +} + +func (p *peer) updateTargetRate(report rtpfb.Report) { + if p.estimator != nil { + for _, pr := range report.PacketReports { + if pr.Arrived { + p.estimator.OnAck( + pr.SequenceNumber, + pr.Size, + pr.Departure, + pr.Arrival, + ) + } else { + p.estimator.OnLoss() + } + } + rate := p.estimator.OnFeedback(report.Arrival, report.RTT) + p.logger.Infof("new target rate: %v", rate) + if p.onRateUpdate != nil { + p.onRateUpdate(rate) + } + if p.pacer != nil { + p.pacer.SetRate(p.pc.ID(), int(1.5*float64(rate))) + } } } diff --git a/simulation/perfect_codec_test.go b/simulation/perfect_codec_test.go index 5953601..d75fe9a 100644 --- a/simulation/perfect_codec_test.go +++ b/simulation/perfect_codec_test.go @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT -//go:build !js +//go:build !js && go1.25 && simulation package simulation @@ -25,6 +25,8 @@ type perfectCodec struct { writer sampleWriter + minTargetRateBps int + maxTargetRateBps int targetBitrateBps int fps int bitrateUpdateCh chan int @@ -34,26 +36,31 @@ type perfectCodec struct { } // newPerfectCodec creates a new PerfectCodec with the specified frame writer and target bitrate. -func newPerfectCodec(writer sampleWriter, targetBitrateBps int) *perfectCodec { +func newPerfectCodec(writer sampleWriter, minTargetRateBps, maxTargetRateBps, initTargetBitrateBps int) *perfectCodec { return &perfectCodec{ logger: logging.NewDefaultLoggerFactory().NewLogger("perfect_codec"), writer: writer, - targetBitrateBps: targetBitrateBps, + minTargetRateBps: minTargetRateBps, + maxTargetRateBps: maxTargetRateBps, + targetBitrateBps: initTargetBitrateBps, fps: 30, bitrateUpdateCh: make(chan int), done: make(chan struct{}), + wg: sync.WaitGroup{}, } } // setTargetBitrate sets the target bitrate to r bits per second. -// func (c *perfectCodec) setTargetBitrate(r int) { -// c.wg.Go(func() { -// select { -// case c.bitrateUpdateCh <- r: -// case <-c.done: -// } -// }) -// } +func (c *perfectCodec) setTargetBitrate(r int) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + select { + case c.bitrateUpdateCh <- r: + case <-c.done: + } + }() +} // start begins the codec operation, generating frames at the configured frame rate. func (c *perfectCodec) start() { @@ -81,6 +88,8 @@ func (c *perfectCodec) start() { continue } case nextRate := <-c.bitrateUpdateCh: + nextRate = max(nextRate, c.minTargetRateBps) + nextRate = min(nextRate, c.maxTargetRateBps) c.targetBitrateBps = nextRate case <-c.done: return diff --git a/simulation/plots.py b/simulation/plots.py new file mode 100644 index 0000000..9243ede --- /dev/null +++ b/simulation/plots.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: 2026 The Pion community +# SPDX-License-Identifier: MIT + +import pandas as pd +import datetime +import glob +import re +from concurrent.futures import ProcessPoolExecutor + +from pathlib import Path + +import matplotlib.pyplot as plt +import matplotlib.ticker as mticker + + +usage_and_state = { + -1: 'over / decrease', + 0: 'hold / normal', + 1: 'under / increase', +} + + +def read_json_file(file): + df = pd.read_json(file, lines=True) + df['time'] = pd.to_datetime(df['time'], format='mixed') + df['time'] = df['time'].dt.tz_localize(tz=None) + df['bits'] = df['payload-size'] * 8 + + rtp_tx = df[(df['msg'] == 'rtp') & (df['vantage-point'] == + 'sender')].dropna(axis=1, how='all') + rtp_rx = df[(df['msg'] == 'rtp') & (df['vantage-point'] == + 'receiver')].dropna(axis=1, how='all') + + latency = rtp_tx.merge(rtp_rx, on='unwrapped-sequence-number')[['time_x', + 'time_y']] + latency['latency'] = (latency['time_y'] - latency['time_x']) / \ + datetime.timedelta(milliseconds=1) / 1000.0 + + loss = rtp_tx.merge(rtp_rx, on='unwrapped-sequence-number', how='left', + indicator=True) + loss['lost'] = loss['_merge'] == 'left_only' + loss = loss[['time_x', 'unwrapped-sequence-number', 'lost']] + + p = Path(file) + return p.stem, df, rtp_tx, rtp_rx, latency, loss + + +def read_pion_log(file): + drc_data = [] + drc_pattern = re.compile(r'.*TRACE: (\d{2}:\d{2}:\d{2}\.\d{6}).* ts=(\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}.\d{6}), seq=(\d+), interArrivalTime=(\d+), interDepartureTime=(\d+), interGroupDelay=(-?\d+), estimate=(-?\d+.\d+), threshold=(\d+.\d+), usage=(-?\d+), state=(-?\d+)') + + sbwe_data = [] + sbwe_pattern = re.compile(r'.*TRACE: (\d{2}:\d{2}:\d{2}\.\d{6}).* rtt=(\d+), delivered=(\d+), lossTarget=(\d+), delayTarget=(\d+), target=(\d+)') + with open(file, 'r') as f: + for line in f: + match = drc_pattern.match(line) + if match: + drc_data.append({ + 'time': pd.to_datetime(f'2000-01-01 {match.group(1)}'), + 'ts': match.group(2), + 'seq': int(match.group(3)), + 'inter_arrival_time': int(match.group(4)), + 'inter_departure_time': int(match.group(5)), + 'inter_group_delay': int(match.group(6)), + 'estimate': float(match.group(7)), + 'threshold': float(match.group(8)), + 'usage': int(match.group(9)), + 'state': int(match.group(10)), + }) + match = sbwe_pattern.match(line) + if match: + sbwe_data.append({ + 'time': pd.to_datetime(f'2000-01-01 {match.group(1)}'), + 'rtt': int(match.group(2)), + 'delivered': int(match.group(3)), + 'loss-target': int(match.group(4)), + 'delay-target': int(match.group(5)), + 'target': int(match.group(6)), + }) + return pd.DataFrame(drc_data), pd.DataFrame(sbwe_data) + + +def plot_gcc_usage_and_state(ax, df): + df = df.dropna(subset=['usage', 'state']) + df['usage'] = -df['usage'] + ax.step(df.index, df['usage'], where='post', label='usage', linewidth=0.5) + ax.step(df.index, df['state'], where='post', label='state', linewidth=0.5) + ax.set_xlabel('Time') + ax.yaxis.set_major_formatter( + mticker.FuncFormatter(lambda x, pos: usage_and_state.get(x, ''))) + ax.legend(loc='upper right') + + +def plot_gcc_rtt(ax, df): + df['rtt'] = df['rtt']*1e-9 + ax.plot(df.index, df['rtt'], label='RTT', linewidth=0.5) + ax.yaxis.set_major_formatter(mticker.EngFormatter(unit='s')) + ax.legend(loc='upper right') + + +def plot_gcc_target_rates(ax, df): + ax.plot(df.index, df['loss-target'], label='loss-target', linewidth=0.5) + ax.plot(df.index, df['delay-target'], label='delay-target', linewidth=0.5) + ax.plot(df.index, df['target'], label='target', linewidth=0.5) + ax.yaxis.set_major_formatter(mticker.EngFormatter(unit='b/s')) + ax.legend(loc='upper right') + + +def plot_gcc_estimates(ax, df): + df['inter_group_delay'] = df['inter_group_delay'] * 1e-3 + df['estimate'] = df['estimate'] + df['scaled_estimate'] = df['estimate'] * 60 + ax.plot(df.index, df['inter_group_delay'], + label='inter_group_delay', linewidth=0.5) + ax.plot(df.index, df['estimate'], label='estimate', linewidth=0.5) + ax.plot(df.index, df['scaled_estimate'], label='scaled_estimate', linewidth=0.5) + ax.plot(df.index, df['threshold'], label='threshold', linewidth=0.5) + ax.plot(df.index, -df['threshold'], label='-threshold', linewidth=0.5) + ax.yaxis.set_major_formatter(mticker.EngFormatter(unit='s')) + ax.legend(loc='upper right') + + +def plot_target_rate(ax, df): + df = df[df['msg'] == 'setting codec target bitrate'] + ax.plot(df['time'], df['rate'], label='Target Rate', linewidth=0.5) + + +def plot_rate(ax, label, df): + df.set_index('time', inplace=True) + df['bits'] = df['bits'] * 5 + df = df.resample('200ms').sum(numeric_only=True) + ax.plot(df.index, df['bits'], label=label, linewidth=0.5) + + +def plot_latency(ax, df): + ax.plot(df['time_x'], df['latency'], linewidth=0.5) + + +def plot_loss(ax, df): + df.set_index('time_x', inplace=True) + df = df.resample('1s').agg({'lost': 'sum', 'unwrapped-sequence-number': + 'count'}) + df['ratio'] = df['lost'] / df['unwrapped-sequence-number'] + ax.plot(df.index, df['ratio'], linewidth=0.5) + + +def plot(output, json, stderr): + name, df, rtp_tx, rtp_rx, latency, loss = read_json_file(json) + gcc_drc, gcc_sbwe = read_pion_log(stderr) + gcc_drc.set_index('time', inplace=True) + gcc_sbwe.set_index('time', inplace=True) + + fig, ax = plt.subplots(7, 1, sharex=True, figsize=(10, 10), + constrained_layout=True) + + plot_rate(ax[0], 'Send Rate', rtp_tx) + plot_rate(ax[0], 'Receive Rate', rtp_rx) + plot_target_rate(ax[0], df) + ax[0].set_title('RTP Rates') + ax[0].yaxis.set_major_formatter(mticker.EngFormatter(unit='b/s')) + ax[0].legend(loc='upper right') + + plot_gcc_target_rates(ax[1], gcc_sbwe) + ax[1].set_title('GCC Target Rates') + + plot_latency(ax[2], latency) + ax[2].set_title('E2E Delay') + ax[2].yaxis.set_major_formatter(mticker.EngFormatter(unit='s')) + + plot_gcc_estimates(ax[3], gcc_drc) + ax[3].set_title('GCC Estimates') + + plot_gcc_usage_and_state(ax[4], gcc_drc) + ax[4].set_title('GCC Usage and State') + + plot_gcc_rtt(ax[5], gcc_sbwe) + ax[5].set_title('GCC RTT') + + plot_loss(ax[6], loss) + ax[6].set_title('Packet Loss') + ax[6].yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0)) + + fig.suptitle(name) + plt.savefig(f'{output}/{name}.png', dpi=450) + plt.close(fig) + + +def plot_all(input, output): + json_logs = sorted(glob.glob(f'{input}/*.jsonl')) + stderr_logs = sorted(glob.glob(f'{input}/*.stderr')) + with ProcessPoolExecutor() as executor: + results = list(executor.map(plot, [output] * len(json_logs), json_logs, stderr_logs)) diff --git a/simulation/simulation.go b/simulation/simulation.go deleted file mode 100644 index cb336be..0000000 --- a/simulation/simulation.go +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-FileCopyrightText: 2026 The Pion community -// SPDX-License-Identifier: MIT - -// Package simulation implements bandwidth estimation tests using the synctest -// package. -package simulation diff --git a/simulation/virtual_network_test.go b/simulation/virtual_network_test.go new file mode 100644 index 0000000..9391008 --- /dev/null +++ b/simulation/virtual_network_test.go @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: 2025 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js && go1.25 && simulation + +package simulation + +import ( + "errors" + "testing" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/v4/vnet" + "github.com/stretchr/testify/assert" +) + +type virtualNetwork struct { + wan *vnet.Router + left *vnet.Net + leftTBF *vnet.Queue + leftDelay *vnet.DelayFilter + + right *vnet.Net + rightTBF *vnet.Queue + rightDelay *vnet.DelayFilter +} + +func (n *virtualNetwork) Close() error { + return errors.Join( + n.leftTBF.Close(), + n.leftDelay.Close(), + n.rightTBF.Close(), + n.rightDelay.Close(), + n.wan.Stop(), + ) +} + +func createVirtualNetwork(rate, burst int, delay time.Duration) func(*testing.T) *virtualNetwork { + return func(t *testing.T) *virtualNetwork { + t.Helper() + + bdp := float64(rate) * delay.Seconds() + bottleneckQueueSize := int(max(bdp, 3000)) // allow at least two packets of MTU size 1500 in queue + + wan, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "0.0.0.0/0", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + + leftRouter, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.1.0/24", + StaticIPs: []string{ + "10.0.1.1/10.0.1.101", + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + NATType: &vnet.NATType{ + Mode: vnet.NATModeNAT1To1, + }, + }) + assert.NoError(t, err) + + leftTBF, err := vnet.NewQueue( + leftRouter, + vnet.NewTBFQueue(rate, burst, int64(bottleneckQueueSize)), + ) + assert.NoError(t, err) + + leftDelay, err := vnet.NewDelayFilter(leftTBF, vnet.WithDelay(delay)) + assert.NoError(t, err) + + err = wan.AddNet(leftDelay) + assert.NoError(t, err) + + err = wan.AddChildRouter(leftRouter) + assert.NoError(t, err) + + rightRouter, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.2.0/24", + StaticIPs: []string{ + "10.0.2.1/10.0.2.101", + }, + LoggerFactory: logging.NewDefaultLoggerFactory(), + NATType: &vnet.NATType{ + Mode: vnet.NATModeNAT1To1, + }, + }) + assert.NoError(t, err) + + rightTBF, err := vnet.NewQueue( + rightRouter, + vnet.NewTBFQueue(rate, burst, int64(bottleneckQueueSize)), + ) + assert.NoError(t, err) + + rightDelay, err := vnet.NewDelayFilter(rightTBF, vnet.WithDelay(delay)) + assert.NoError(t, err) + + err = wan.AddNet(rightDelay) + assert.NoError(t, err) + + err = wan.AddChildRouter(rightRouter) + assert.NoError(t, err) + + err = wan.Start() + assert.NoError(t, err) + + leftNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"10.0.1.101"}, + }) + assert.NoError(t, err) + err = leftRouter.AddNet(leftNet) + assert.NoError(t, err) + + rightNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{"10.0.2.101"}, + }) + assert.NoError(t, err) + err = rightRouter.AddNet(rightNet) + assert.NoError(t, err) + + return &virtualNetwork{ + wan: wan, + left: leftNet, + leftTBF: leftTBF, + leftDelay: leftDelay, + right: rightNet, + rightTBF: rightTBF, + rightDelay: rightDelay, + } + } +} diff --git a/simulation/vnet_test.go b/simulation/vnet_test.go deleted file mode 100644 index 5dd2595..0000000 --- a/simulation/vnet_test.go +++ /dev/null @@ -1,185 +0,0 @@ -// SPDX-FileCopyrightText: 2026 The Pion community -// SPDX-License-Identifier: MIT - -//go:build !js && go1.25 - -package simulation - -import ( - "errors" - "io" - "testing" - "testing/synctest" - "time" - - "github.com/pion/logging" - "github.com/pion/transport/v3/vnet" - "github.com/pion/webrtc/v4" - "github.com/stretchr/testify/assert" -) - -type network struct { - wan *vnet.Router - left *vnet.Net - right *vnet.Net -} - -func (n *network) Close() error { - return n.wan.Stop() -} - -func createVirtualNetwork(t *testing.T) *network { - t.Helper() - - wan, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "0.0.0.0/0", - LoggerFactory: logging.NewDefaultLoggerFactory(), - }) - assert.NoError(t, err) - - leftRouter, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "10.0.1.0/24", - StaticIPs: []string{ - "10.0.1.1/10.0.1.101", - }, - LoggerFactory: logging.NewDefaultLoggerFactory(), - NATType: &vnet.NATType{ - Mode: vnet.NATModeNAT1To1, - }, - }) - assert.NoError(t, err) - err = wan.AddRouter(leftRouter) - assert.NoError(t, err) - - rightRouter, err := vnet.NewRouter(&vnet.RouterConfig{ - CIDR: "10.0.2.0/24", - StaticIPs: []string{ - "10.0.2.1/10.0.2.101", - }, - LoggerFactory: logging.NewDefaultLoggerFactory(), - NATType: &vnet.NATType{ - Mode: vnet.NATModeNAT1To1, - }, - }) - assert.NoError(t, err) - err = wan.AddRouter(rightRouter) - assert.NoError(t, err) - - err = wan.Start() - assert.NoError(t, err) - - leftNet, err := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"10.0.1.101"}, - StaticIP: "", - }) - assert.NoError(t, err) - err = leftRouter.AddNet(leftNet) - assert.NoError(t, err) - - rightNet, err := vnet.NewNet(&vnet.NetConfig{ - StaticIPs: []string{"10.0.2.101"}, - StaticIP: "", - }) - assert.NoError(t, err) - err = rightRouter.AddNet(rightNet) - assert.NoError(t, err) - - return &network{ - wan: wan, - left: leftNet, - right: rightNet, - } -} - -func TestVnet(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - t.Helper() - - onTrack := make(chan struct{}) - connected := make(chan struct{}) - done := make(chan struct{}) - - network := createVirtualNetwork(t) - receiver, err := newPeer( - registerDefaultCodecs(), - setVNet(network.left, []string{"10.0.1.1"}), - onRemoteTrack(func(track *webrtc.TrackRemote) { - close(onTrack) - go func() { - buf := make([]byte, 1500) - for { - select { - case <-done: - return - default: - _, _, err := track.Read(buf) - if errors.Is(err, io.EOF) { - return - } - assert.NoError(t, err) - } - } - }() - }), - registerPacketLogger("receiver"), - registerCCFB(), - ) - assert.NoError(t, err) - - err = receiver.addRemoteTrack() - assert.NoError(t, err) - - sender, err := newPeer( - registerDefaultCodecs(), - onConnected(func() { close(connected) }), - setVNet(network.right, []string{"10.0.2.1"}), - registerPacketLogger("sender"), - registerRTPFB(), - ) - assert.NoError(t, err) - - track, err := sender.addLocalTrack() - assert.NoError(t, err) - - codec := newPerfectCodec(track, 1_000_000) - go func() { - <-connected - codec.start() - }() - - offer, err := sender.createOffer() - assert.NoError(t, err) - - err = receiver.setRemoteDescription(offer) - assert.NoError(t, err) - - answer, err := receiver.createAnswer() - assert.NoError(t, err) - - err = sender.setRemoteDescription(answer) - assert.NoError(t, err) - - synctest.Wait() - select { - case <-onTrack: - case <-time.After(time.Second): - assert.Fail(t, "on track not called") - } - time.Sleep(10 * time.Second) - close(done) - - err = codec.Close() - assert.NoError(t, err) - - err = sender.pc.Close() - assert.NoError(t, err) - - err = receiver.pc.Close() - assert.NoError(t, err) - - err = network.Close() - assert.NoError(t, err) - - synctest.Wait() - }) -}