Skip to content

Commit 497dbe8

Browse files
authored
feat: Allow blocking until primary rate limit is reset (#3117)
Fixes: #3114.
1 parent e21b500 commit 497dbe8

File tree

3 files changed

+212
-0
lines changed

3 files changed

+212
-0
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,12 @@ if _, ok := err.(*github.AbuseRateLimitError); ok {
180180
}
181181
```
182182

183+
Alternatively, you can block until the rate limit is reset by using the `context.WithValue` method:
184+
185+
````go
186+
repos, _, err := client.Repositories.List(context.WithValue(ctx, github.SleepUntilPrimaryRateLimitResetWhenRateLimited, true), "", nil)
187+
```
188+
183189
You can use [go-github-ratelimit](https://github.com/gofri/go-github-ratelimit) to handle
184190
secondary rate limit sleep-and-retry for you.
185191

github/github.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ type requestContext uint8
804804

805805
const (
806806
bypassRateLimitCheck requestContext = iota
807+
SleepUntilPrimaryRateLimitResetWhenRateLimited
807808
)
808809

809810
// BareDo sends an API request and lets you handle the api response. If an error
@@ -889,6 +890,15 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro
889890
err = aerr
890891
}
891892

893+
rateLimitError, ok := err.(*RateLimitError)
894+
if ok && req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil {
895+
if err := sleepUntilResetWithBuffer(req.Context(), rateLimitError.Rate.Reset.Time); err != nil {
896+
return response, err
897+
}
898+
// retry the request once when the rate limit has reset
899+
return c.BareDo(context.WithValue(req.Context(), SleepUntilPrimaryRateLimitResetWhenRateLimited, nil), req)
900+
}
901+
892902
// Update the secondary rate limit if we hit it.
893903
rerr, ok := err.(*AbuseRateLimitError)
894904
if ok && rerr.RetryAfter != nil {
@@ -950,6 +960,18 @@ func (c *Client) checkRateLimitBeforeDo(req *http.Request, rateLimitCategory Rat
950960
Header: make(http.Header),
951961
Body: io.NopCloser(strings.NewReader("")),
952962
}
963+
964+
if req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil {
965+
if err := sleepUntilResetWithBuffer(req.Context(), rate.Reset.Time); err == nil {
966+
return nil
967+
}
968+
return &RateLimitError{
969+
Rate: rate,
970+
Response: resp,
971+
Message: fmt.Sprintf("Context cancelled while waiting for rate limit to reset until %v, not making remote request.", rate.Reset.Time),
972+
}
973+
}
974+
953975
return &RateLimitError{
954976
Rate: rate,
955977
Response: resp,
@@ -1514,6 +1536,20 @@ func formatRateReset(d time.Duration) string {
15141536
return fmt.Sprintf("[rate reset in %v]", timeString)
15151537
}
15161538

1539+
func sleepUntilResetWithBuffer(ctx context.Context, reset time.Time) error {
1540+
buffer := time.Second
1541+
timer := time.NewTimer(time.Until(reset) + buffer)
1542+
select {
1543+
case <-ctx.Done():
1544+
if !timer.Stop() {
1545+
<-timer.C
1546+
}
1547+
return ctx.Err()
1548+
case <-timer.C:
1549+
}
1550+
return nil
1551+
}
1552+
15171553
// When using roundTripWithOptionalFollowRedirect, note that it
15181554
// is the responsibility of the caller to close the response body.
15191555
func (c *Client) roundTripWithOptionalFollowRedirect(ctx context.Context, u string, maxRedirects int, opts ...RequestOption) (*http.Response, error) {

github/github_test.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,6 +1381,176 @@ func TestDo_rateLimit_ignoredFromCache(t *testing.T) {
13811381
}
13821382
}
13831383

1384+
// Ensure sleeps until the rate limit is reset when the client is rate limited.
1385+
func TestDo_rateLimit_sleepUntilResponseResetLimit(t *testing.T) {
1386+
client, mux, _, teardown := setup()
1387+
defer teardown()
1388+
1389+
reset := time.Now().UTC().Add(time.Second)
1390+
1391+
var firstRequest = true
1392+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
1393+
if firstRequest {
1394+
firstRequest = false
1395+
w.Header().Set(headerRateLimit, "60")
1396+
w.Header().Set(headerRateRemaining, "0")
1397+
w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix()))
1398+
w.Header().Set("Content-Type", "application/json; charset=utf-8")
1399+
w.WriteHeader(http.StatusForbidden)
1400+
fmt.Fprintln(w, `{
1401+
"message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)",
1402+
"documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"
1403+
}`)
1404+
return
1405+
}
1406+
w.Header().Set(headerRateLimit, "5000")
1407+
w.Header().Set(headerRateRemaining, "5000")
1408+
w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix()))
1409+
w.Header().Set("Content-Type", "application/json; charset=utf-8")
1410+
w.WriteHeader(http.StatusOK)
1411+
fmt.Fprintln(w, `{}`)
1412+
})
1413+
1414+
req, _ := client.NewRequest("GET", ".", nil)
1415+
ctx := context.Background()
1416+
resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
1417+
if err != nil {
1418+
t.Errorf("Do returned unexpected error: %v", err)
1419+
}
1420+
if got, want := resp.StatusCode, http.StatusOK; got != want {
1421+
t.Errorf("Response status code = %v, want %v", got, want)
1422+
}
1423+
}
1424+
1425+
// Ensure tries to sleep until the rate limit is reset when the client is rate limited, but only once.
1426+
func TestDo_rateLimit_sleepUntilResponseResetLimitRetryOnce(t *testing.T) {
1427+
client, mux, _, teardown := setup()
1428+
defer teardown()
1429+
1430+
reset := time.Now().UTC().Add(time.Second)
1431+
1432+
requestCount := 0
1433+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
1434+
requestCount++
1435+
w.Header().Set(headerRateLimit, "60")
1436+
w.Header().Set(headerRateRemaining, "0")
1437+
w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix()))
1438+
w.Header().Set("Content-Type", "application/json; charset=utf-8")
1439+
w.WriteHeader(http.StatusForbidden)
1440+
fmt.Fprintln(w, `{
1441+
"message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)",
1442+
"documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"
1443+
}`)
1444+
})
1445+
1446+
req, _ := client.NewRequest("GET", ".", nil)
1447+
ctx := context.Background()
1448+
_, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
1449+
if err == nil {
1450+
t.Error("Expected error to be returned.")
1451+
}
1452+
if got, want := requestCount, 2; got != want {
1453+
t.Errorf("Expected 2 requests, got %d", got)
1454+
}
1455+
}
1456+
1457+
// Ensure a network call is not made when it's known that API rate limit is still exceeded.
1458+
func TestDo_rateLimit_sleepUntilClientResetLimit(t *testing.T) {
1459+
client, mux, _, teardown := setup()
1460+
defer teardown()
1461+
1462+
reset := time.Now().UTC().Add(time.Second)
1463+
client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}}
1464+
requestCount := 0
1465+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
1466+
requestCount++
1467+
w.Header().Set(headerRateLimit, "5000")
1468+
w.Header().Set(headerRateRemaining, "5000")
1469+
w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix()))
1470+
w.Header().Set("Content-Type", "application/json; charset=utf-8")
1471+
w.WriteHeader(http.StatusOK)
1472+
fmt.Fprintln(w, `{}`)
1473+
})
1474+
req, _ := client.NewRequest("GET", ".", nil)
1475+
ctx := context.Background()
1476+
resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
1477+
if err != nil {
1478+
t.Errorf("Do returned unexpected error: %v", err)
1479+
}
1480+
if got, want := resp.StatusCode, http.StatusOK; got != want {
1481+
t.Errorf("Response status code = %v, want %v", got, want)
1482+
}
1483+
if got, want := requestCount, 1; got != want {
1484+
t.Errorf("Expected 1 request, got %d", got)
1485+
}
1486+
}
1487+
1488+
// Ensure sleep is aborted when the context is cancelled.
1489+
func TestDo_rateLimit_abortSleepContextCancelled(t *testing.T) {
1490+
client, mux, _, teardown := setup()
1491+
defer teardown()
1492+
1493+
// We use a 1 minute reset time to ensure the sleep is not completed.
1494+
reset := time.Now().UTC().Add(time.Minute)
1495+
requestCount := 0
1496+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
1497+
requestCount++
1498+
w.Header().Set(headerRateLimit, "60")
1499+
w.Header().Set(headerRateRemaining, "0")
1500+
w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix()))
1501+
w.Header().Set("Content-Type", "application/json; charset=utf-8")
1502+
w.WriteHeader(http.StatusForbidden)
1503+
fmt.Fprintln(w, `{
1504+
"message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)",
1505+
"documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"
1506+
}`)
1507+
})
1508+
1509+
req, _ := client.NewRequest("GET", ".", nil)
1510+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
1511+
defer cancel()
1512+
_, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
1513+
if !errors.Is(err, context.DeadlineExceeded) {
1514+
t.Error("Expected context deadline exceeded error.")
1515+
}
1516+
if got, want := requestCount, 1; got != want {
1517+
t.Errorf("Expected 1 requests, got %d", got)
1518+
}
1519+
}
1520+
1521+
// Ensure sleep is aborted when the context is cancelled on initial request.
1522+
func TestDo_rateLimit_abortSleepContextCancelledClientLimit(t *testing.T) {
1523+
client, mux, _, teardown := setup()
1524+
defer teardown()
1525+
1526+
reset := time.Now().UTC().Add(time.Minute)
1527+
client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}}
1528+
requestCount := 0
1529+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
1530+
requestCount++
1531+
w.Header().Set(headerRateLimit, "5000")
1532+
w.Header().Set(headerRateRemaining, "5000")
1533+
w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix()))
1534+
w.Header().Set("Content-Type", "application/json; charset=utf-8")
1535+
w.WriteHeader(http.StatusOK)
1536+
fmt.Fprintln(w, `{}`)
1537+
})
1538+
req, _ := client.NewRequest("GET", ".", nil)
1539+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
1540+
defer cancel()
1541+
_, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
1542+
rateLimitError, ok := err.(*RateLimitError)
1543+
if !ok {
1544+
t.Fatalf("Expected a *rateLimitError error; got %#v.", err)
1545+
}
1546+
if got, wantSuffix := rateLimitError.Message, "Context cancelled while waiting for rate limit to reset until"; !strings.HasPrefix(got, wantSuffix) {
1547+
t.Errorf("Expected request to be prevented because context cancellation, got: %v.", got)
1548+
}
1549+
if got, want := requestCount, 0; got != want {
1550+
t.Errorf("Expected 1 requests, got %d", got)
1551+
}
1552+
}
1553+
13841554
// Ensure *AbuseRateLimitError is returned when the response indicates that
13851555
// the client has triggered an abuse detection mechanism.
13861556
func TestDo_rateLimit_abuseRateLimitError(t *testing.T) {

0 commit comments

Comments
 (0)