@@ -7,6 +7,7 @@ package tunnel
77import (
88 "context"
99 "crypto/tls"
10+ "crypto/x509"
1011 "encoding/json"
1112 "errors"
1213 "fmt"
@@ -24,6 +25,20 @@ import (
2425 "github.com/mmatczuk/go-http-tunnel/proto"
2526)
2627
28+ // A set of listeners to manage subscribers
29+ type SubscriptionListener interface {
30+ // Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
31+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
32+ // verified chain, else the presented peer certificate.
33+ CanSubscribe (id id.ID , chain []* x509.Certificate ) bool
34+ // Invoked when the client has been subscribed.
35+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
36+ // verified chain, else the presented peer certificate.
37+ Subscribed (id id.ID , tlsConn * tls.Conn , chain []* x509.Certificate )
38+ // Invoked before the client is unsubscribed.
39+ Unsubscribed (id id.ID )
40+ }
41+
2742// ServerConfig defines configuration for the Server.
2843type ServerConfig struct {
2944 // Addr is TCP address to listen for client connections. If empty ":0"
@@ -41,6 +56,8 @@ type ServerConfig struct {
4156 Logger log.Logger
4257 // Addr is TCP address to listen for TLS SNI connections
4358 SNIAddr string
59+ // Optional listener to manage subscribers
60+ SubscriptionListener SubscriptionListener
4461}
4562
4663// Server is responsible for proxying public connections to the client over a
@@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) {
238255 ok bool
239256
240257 inConnPool bool
258+ certs []* x509.Certificate
241259 )
242260
243261 tlsConn , ok := conn .(* tls.Conn )
@@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) {
262280
263281 logger = logger .With ("identifier" , identifier )
264282
283+ certs = tlsConn .ConnectionState ().PeerCertificates
284+ if tlsConn .ConnectionState ().VerifiedChains != nil && len (tlsConn .ConnectionState ().VerifiedChains ) > 0 {
285+ certs = tlsConn .ConnectionState ().VerifiedChains [0 ]
286+ }
265287 if s .config .AutoSubscribe {
266288 s .Subscribe (identifier )
289+ if s .config .SubscriptionListener != nil {
290+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
291+ }
267292 } else if ! s .IsSubscribed (identifier ) {
268- logger .Log (
269- "level" , 2 ,
270- "msg" , "unknown client" ,
271- )
272- goto reject
293+ if s .config .SubscriptionListener != nil && s .config .SubscriptionListener .CanSubscribe (identifier , certs ) {
294+ s .Subscribe (identifier )
295+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
296+ } else {
297+ logger .Log (
298+ "level" , 2 ,
299+ "msg" , "unknown client" ,
300+ )
301+ goto reject
302+ }
273303 }
274304
275305 if err = conn .SetDeadline (time.Time {}); err != nil {
@@ -486,6 +516,9 @@ rollback:
486516// Unsubscribe removes client from registry, disconnects client if already
487517// connected and returns it's RegistryItem.
488518func (s * Server ) Unsubscribe (identifier id.ID ) * RegistryItem {
519+ if s .config .SubscriptionListener != nil {
520+ s .config .SubscriptionListener .Unsubscribed (identifier )
521+ }
489522 s .connPool .DeleteConn (identifier )
490523 return s .registry .Unsubscribe (identifier )
491524}
@@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
561594 }
562595}
563596
597+ func (s * Server ) Upgrade (identifier id.ID , conn net.Conn , requestBytes []byte ) error {
598+
599+ var err error
600+
601+ msg := & proto.ControlMessage {
602+ Action : proto .ActionProxy ,
603+ ForwardedProto : "https" ,
604+ }
605+
606+ tlsConn , ok := conn .(* tls.Conn )
607+ if ok {
608+ msg .ForwardedHost = tlsConn .ConnectionState ().ServerName
609+ err = keepAlive (tlsConn .NetConn ())
610+
611+ } else {
612+ msg .ForwardedHost = conn .RemoteAddr ().String ()
613+ err = keepAlive (conn )
614+ }
615+
616+ if err != nil {
617+ s .logger .Log (
618+ "level" , 1 ,
619+ "msg" , "TCP keepalive for tunneled connection failed" ,
620+ "identifier" , identifier ,
621+ "ctrlMsg" , msg ,
622+ "err" , err ,
623+ )
624+ }
625+
626+ go func () {
627+ if err := s .proxyConnUpgraded (identifier , conn , msg , requestBytes ); err != nil {
628+ s .logger .Log (
629+ "level" , 0 ,
630+ "msg" , "proxy error" ,
631+ "identifier" , identifier ,
632+ "ctrlMsg" , msg ,
633+ "err" , err ,
634+ )
635+ }
636+ }()
637+
638+ return nil
639+ }
640+
564641// ServeHTTP proxies http connection to the client.
565642func (s * Server ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
566643 resp , err := s .RoundTrip (r )
@@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
639716 return s .proxyHTTP (identifier , outr , msg )
640717}
641718
719+ func (s * Server ) proxyConnUpgraded (identifier id.ID , conn net.Conn , msg * proto.ControlMessage , requestBytes []byte ) error {
720+ s .logger .Log (
721+ "level" , 2 ,
722+ "action" , "proxy conn" ,
723+ "identifier" , identifier ,
724+ "ctrlMsg" , msg ,
725+ )
726+
727+ defer conn .Close ()
728+
729+ pr , pw := io .Pipe ()
730+ defer pr .Close ()
731+ defer pw .Close ()
732+
733+ continueChan := make (chan int )
734+
735+ go func () {
736+ pw .Write (requestBytes )
737+ continueChan <- 1
738+ }()
739+
740+ req , err := s .connectRequest (identifier , msg , pr )
741+ if err != nil {
742+ return err
743+ }
744+
745+ ctx , cancel := context .WithCancel (context .Background ())
746+ req = req .WithContext (ctx )
747+
748+ done := make (chan struct {})
749+ go func () {
750+ <- continueChan
751+ transfer (pw , conn , log .NewContext (s .logger ).With (
752+ "dir" , "user to client" ,
753+ "dst" , identifier ,
754+ "src" , conn .RemoteAddr (),
755+ ))
756+ cancel ()
757+ close (done )
758+ }()
759+
760+ resp , err := s .httpClient .Do (req )
761+ if err != nil {
762+ return fmt .Errorf ("io error: %s" , err )
763+ }
764+ defer resp .Body .Close ()
765+
766+ transfer (conn , resp .Body , log .NewContext (s .logger ).With (
767+ "dir" , "client to user" ,
768+ "dst" , conn .RemoteAddr (),
769+ "src" , identifier ,
770+ ))
771+
772+ select {
773+ case <- done :
774+ case <- time .After (DefaultTimeout ):
775+ }
776+
777+ s .logger .Log (
778+ "level" , 2 ,
779+ "action" , "proxy conn done" ,
780+ "identifier" , identifier ,
781+ "ctrlMsg" , msg ,
782+ )
783+
784+ return nil
785+ }
786+
642787func (s * Server ) proxyConn (identifier id.ID , conn net.Conn , msg * proto.ControlMessage ) error {
643788 s .logger .Log (
644789 "level" , 2 ,
0 commit comments