@@ -373,6 +373,194 @@ mod test {
373373 ( blocking_client, async_client)
374374 }
375375
376+ #[ cfg( feature = "async-ohttp" ) ]
377+ fn find_free_port ( ) -> u16 {
378+ let listener = std:: net:: TcpListener :: bind ( "0.0.0.0:0" ) . unwrap ( ) ;
379+ listener. local_addr ( ) . unwrap ( ) . port ( )
380+ }
381+
382+ #[ cfg( feature = "async-ohttp" ) ]
383+ async fn start_ohttp_relay (
384+ gateway_url : ohttp_relay:: GatewayUri ,
385+ ) -> (
386+ u16 ,
387+ tokio:: task:: JoinHandle < Result < ( ) , Box < dyn std:: error:: Error + std:: marker:: Send + Sync > > > ,
388+ ) {
389+ let port = find_free_port ( ) ;
390+ let relay = ohttp_relay:: listen_tcp ( port, gateway_url) . await . unwrap ( ) ;
391+
392+ ( port, relay)
393+ }
394+
395+ #[ cfg( feature = "async-ohttp" ) ]
396+ async fn start_ohttp_gateway ( ) -> ( u16 , tokio:: task:: JoinHandle < ( ) > ) {
397+ use http_body_util:: Full ;
398+ use hyper:: body:: Incoming ;
399+ use hyper:: service:: service_fn;
400+ use hyper:: Response ;
401+ use hyper:: { Method , Request } ;
402+ use hyper_util:: rt:: TokioIo ;
403+ use tokio:: net:: TcpListener ;
404+
405+ let port = find_free_port ( ) ;
406+ let listener = TcpListener :: bind ( format ! ( "0.0.0.0:{}" , port) )
407+ . await
408+ . unwrap ( ) ;
409+
410+ let handle = tokio:: spawn ( async move {
411+ let key_config = bitcoin_ohttp:: KeyConfig :: new (
412+ 0 ,
413+ bitcoin_ohttp:: hpke:: Kem :: K256Sha256 ,
414+ vec ! [ bitcoin_ohttp:: SymmetricSuite :: new(
415+ bitcoin_ohttp:: hpke:: Kdf :: HkdfSha256 ,
416+ bitcoin_ohttp:: hpke:: Aead :: ChaCha20Poly1305 ,
417+ ) ] ,
418+ )
419+ . expect ( "valid key config" ) ;
420+ let server = bitcoin_ohttp:: Server :: new ( key_config) . expect ( "valid server" ) ;
421+ let server = std:: sync:: Arc :: new ( server) ;
422+ loop {
423+ match listener. accept ( ) . await {
424+ Ok ( ( stream, _) ) => {
425+ let io = TokioIo :: new ( stream) ;
426+ let server = server. clone ( ) ;
427+ let service = service_fn ( move |req : Request < Incoming > | {
428+ let server = server. clone ( ) ;
429+ async move {
430+ let path = req. uri ( ) . path ( ) ;
431+ if path == "/.well-known/ohttp-gateway"
432+ && req. method ( ) == Method :: GET
433+ {
434+ let key_config = server. config ( ) . encode ( ) . unwrap ( ) ;
435+ Ok :: < _ , hyper:: Error > (
436+ Response :: builder ( )
437+ . status ( 200 )
438+ . header ( "content-type" , "application/ohttp-keys" )
439+ . body ( Full :: new ( hyper:: body:: Bytes :: from ( key_config) ) )
440+ . unwrap ( ) ,
441+ )
442+ } else if path == "/.well-known/ohttp-gateway"
443+ && req. method ( ) == Method :: POST
444+ {
445+ use http_body_util:: BodyExt ;
446+
447+ // Assert that the content-type header is set to
448+ // "message/ohttp-req".
449+ let content_type_header = req
450+ . headers ( )
451+ . get ( "content-type" )
452+ . expect ( "content-type header should be set by the client" ) ;
453+ assert_eq ! ( content_type_header, "message/ohttp-req" ) ;
454+
455+ let bytes = req. collect ( ) . await ?. to_bytes ( ) ;
456+ let ( bhttp_body, response_ctx) =
457+ server. decapsulate ( bytes. iter ( ) . as_slice ( ) ) . unwrap ( ) ;
458+ // Reconstruct the inner HTTP message from the bhttp message.
459+ let mut r = std:: io:: Cursor :: new ( bhttp_body) ;
460+ let m: bhttp:: Message = bhttp:: Message :: read_bhttp ( & mut r)
461+ . expect ( "Should be valid bhttp message" ) ;
462+ let base_url = format ! (
463+ "http://{}" ,
464+ ELECTRSD . esplora_url. as_ref( ) . unwrap( )
465+ ) ;
466+ let path =
467+ String :: from_utf8 ( m. control ( ) . path ( ) . unwrap ( ) . to_vec ( ) )
468+ . unwrap ( ) ;
469+ let _ =
470+ Method :: from_bytes ( m. control ( ) . method ( ) . unwrap ( ) ) . unwrap ( ) ;
471+ // TODO: Use the actual method from the bhttp message
472+ // This will be refactored out to use bitreq
473+ let req = reqwest:: Request :: new (
474+ Method :: GET ,
475+ url:: Url :: parse ( & ( base_url + & path) ) . unwrap ( ) ,
476+ ) ;
477+ let mut req_builder = reqwest:: RequestBuilder :: from_parts (
478+ reqwest:: Client :: new ( ) ,
479+ req,
480+ ) ;
481+ for field in m. header ( ) . iter ( ) {
482+ req_builder =
483+ req_builder. header ( field. name ( ) , field. value ( ) ) ;
484+ }
485+
486+ let res = req_builder. send ( ) . await . unwrap ( ) ;
487+ // Convert HTTP response to bhttp response
488+ let mut m: bhttp:: Message = bhttp:: Message :: response (
489+ res. status ( ) . as_u16 ( ) . try_into ( ) . unwrap ( ) ,
490+ ) ;
491+ m. write_content ( res. bytes ( ) . await . unwrap ( ) ) ;
492+ let mut bhttp_res = vec ! [ ] ;
493+ m. write_bhttp ( bhttp:: Mode :: IndeterminateLength , & mut bhttp_res)
494+ . unwrap ( ) ;
495+ // Now we need to encapsulate the response
496+ let encapsulated_response =
497+ response_ctx. encapsulate ( & bhttp_res) . unwrap ( ) ;
498+
499+ Ok :: < _ , hyper:: Error > (
500+ Response :: builder ( )
501+ . status ( 200 )
502+ . header ( "content-type" , "message/ohttp-res" )
503+ . body ( Full :: new ( hyper:: body:: Bytes :: copy_from_slice (
504+ & encapsulated_response,
505+ ) ) )
506+ . unwrap ( ) ,
507+ )
508+ } else {
509+ Ok :: < _ , hyper:: Error > (
510+ Response :: builder ( )
511+ . status ( 404 )
512+ . body ( Full :: new ( hyper:: body:: Bytes :: from ( "Not Found" ) ) )
513+ . unwrap ( ) ,
514+ )
515+ }
516+ }
517+ } ) ;
518+
519+ tokio:: spawn ( async move {
520+ if let Err ( err) = hyper:: server:: conn:: http1:: Builder :: new ( )
521+ . serve_connection ( io, service)
522+ . await
523+ {
524+ eprintln ! ( "Error serving connection: {:?}" , err) ;
525+ }
526+ } ) ;
527+ }
528+ Err ( e) => {
529+ eprintln ! ( "Error accepting connection: {:?}" , e) ;
530+ break ;
531+ }
532+ }
533+ }
534+ } ) ;
535+ println ! ( "OHTTP gateway started on port {}" , port) ;
536+
537+ ( port, handle)
538+ }
539+ #[ cfg( feature = "async-ohttp" ) ]
540+ #[ tokio:: test]
541+ async fn test_ohttp_e2e ( ) {
542+ let ( _, async_client) = setup_clients ( ) . await ;
543+ let block_hash = async_client. get_block_hash ( 1 ) . await . unwrap ( ) ;
544+ let esplora_url = ELECTRSD . esplora_url . as_ref ( ) . unwrap ( ) ;
545+ let ( gateway_port, _) = start_ohttp_gateway ( ) . await ;
546+ let gateway_origin = format ! ( "http://localhost:{gateway_port}" ) ;
547+ let ( relay_port, _) =
548+ start_ohttp_relay ( gateway_origin. parse :: < ohttp_relay:: GatewayUri > ( ) . unwrap ( ) ) . await ;
549+ let gateway_url = format ! (
550+ "http://localhost:{}/.well-known/ohttp-gateway" ,
551+ gateway_port
552+ ) ;
553+ let relay_url = format ! ( "http://localhost:{}" , relay_port) ;
554+
555+ let ohttp_client = Builder :: new ( & format ! ( "http://{}" , esplora_url) )
556+ . build_async_with_ohttp ( & relay_url, & gateway_url)
557+ . await
558+ . unwrap ( ) ;
559+
560+ let res = ohttp_client. get_block_hash ( 1 ) . await . unwrap ( ) ;
561+ assert_eq ! ( res, block_hash) ;
562+ }
563+
376564 #[ cfg( all( feature = "blocking" , feature = "async" ) ) ]
377565 fn generate_blocks_and_wait ( num : usize ) {
378566 let cur_height = BITCOIND . client . get_block_count ( ) . unwrap ( ) . 0 ;
0 commit comments