Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client
// domain, avoid pounding manager or storage thousands of times simultaneously. We use a similar sync
// strategy for obtaining certificate during handshake.
certLoadWaitChansMu.Lock()
wait, ok := certLoadWaitChans[name]
waiter, ok := certLoadWaitChans[name]
if ok {
// another goroutine is already loading the cert; just wait and we'll get it from the in-memory cache
// another goroutine is already loading the cert; just wait
certLoadWaitChansMu.Unlock()

timeout := time.NewTimer(2 * time.Minute)
Expand All @@ -310,33 +310,44 @@ func (cfg *Config) getCertDuringHandshake(ctx context.Context, hello *tls.Client
case <-ctx.Done():
timeout.Stop()
return Certificate{}, ctx.Err()
case <-wait:
case <-waiter.done:
timeout.Stop()
}

return cfg.getCertDuringHandshake(ctx, hello, false)
} else {
// no other goroutine is currently trying to load this cert
wait = make(chan struct{})
certLoadWaitChans[name] = wait
certLoadWaitChansMu.Unlock()
// If the leader got a result from an external cert manager, use it
// directly — these certs are not added to the cache, so a recursive
// cache lookup would miss. For cached certs (on-demand, managed),
// the waiter result will be empty and we fall through to the
// original recursive lookup.
if !waiter.cert.Empty() || waiter.err != nil {
return waiter.cert, waiter.err
}

// unblock others and clean up when we're done
defer func() {
certLoadWaitChansMu.Lock()
close(wait)
delete(certLoadWaitChans, name)
certLoadWaitChansMu.Unlock()
}()
return cfg.getCertDuringHandshake(ctx, hello, false)
}

// no other goroutine is currently trying to load this cert
waiter = &certLoadWaiter{done: make(chan struct{})}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we pool these maybe to avoid some allocations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I understand, the channel gets closed to wake all waiters at once and can't really be reused and needs to be freshly allocated every time, so pooling would only save the small struct around it.
Putting it back in the pool is also not straightforward I think, because the waiters still hold references after the leader is done, so we'd probably need some kind of reference counting to know when it's safe.

It feels like a lot of complexity for one allocation per contended name, but happy to reconsider if you see it differently.

certLoadWaitChans[name] = waiter
certLoadWaitChansMu.Unlock()

// unblock others and clean up when we're done
defer func() {
certLoadWaitChansMu.Lock()
close(waiter.done)
delete(certLoadWaitChans, name)
certLoadWaitChansMu.Unlock()
}()

// If an external Manager is configured, try to get it from them.
// Only continue to use our own logic if it returns empty+nil.
externalCert, err := cfg.getCertFromAnyCertManager(ctx, hello, logger)
if err != nil {
waiter.err = err
return Certificate{}, err
}
if !externalCert.Empty() {
waiter.cert = externalCert
return externalCert, nil
}

Expand Down Expand Up @@ -946,9 +957,19 @@ var (
obtainCertWaitChansMu sync.Mutex
)

// certLoadWaiter coordinates concurrent certificate loading for the same name.
// The leader populates the result and closes the channel; waiters read the result
// after the channel is closed. This allows externally-managed certificates (which
// are not cached) to be shared directly with waiting goroutines.
//
// Waiters must not read cert or err before done is closed. When both are left
// at their zero values, waiters fall back to performing their own lookup.
type certLoadWaiter struct {
// done is closed by the leader's deferred cleanup when it finishes its
// load attempt, which wakes every goroutine waiting on this name.
done chan struct{}
// cert is set by the leader when an external certificate manager
// returned a non-empty certificate.
cert Certificate
// err is set by the leader when the external certificate manager
// lookup returned an error.
err error
}

// certLoadWaitChans holds one entry per name whose certificate is currently
// being loaded during a handshake. The goroutine that starts the load inserts
// the entry, and removes it again in its deferred cleanup; other goroutines
// that find an entry wait on it instead of starting a second load.
// certLoadWaitChansMu is held around reads and writes of the map.
//
// TODO: this lockset should probably be per-cache
var (
certLoadWaitChans = make(map[string]chan struct{})
certLoadWaitChans = make(map[string]*certLoadWaiter)
certLoadWaitChansMu sync.Mutex
)

Expand Down