Skip to content

Commit e2f9ca6

Browse files
authored
Relax required host capabilities in wasip3 tcp/udp bindings (#12085)
Don't require `async | store` when it's not necessary as this'll soon have ramifications on the type of the function being bound in the component model.
1 parent cf48211 commit e2f9ca6

File tree

4 files changed

+98
-104
lines changed

4 files changed

+98
-104
lines changed

crates/wasi/src/p3/bindings.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ mod generated {
8383
"wasi:cli/stdout": store | tracing | trappable,
8484
"wasi:cli/stderr": store | tracing | trappable,
8585
"wasi:filesystem/types.[method]descriptor.read-via-stream": store | tracing | trappable,
86-
"wasi:sockets/types.[method]tcp-socket.bind": async | store | tracing | trappable,
87-
"wasi:sockets/types.[method]tcp-socket.listen": async | store | tracing | trappable,
88-
"wasi:sockets/types.[method]tcp-socket.receive": async | store | tracing | trappable,
89-
"wasi:sockets/types.[method]udp-socket.bind": async | store | tracing | trappable,
90-
"wasi:sockets/types.[method]udp-socket.connect": async | store | tracing | trappable,
86+
"wasi:sockets/types.[method]tcp-socket.bind": async | tracing | trappable,
87+
"wasi:sockets/types.[method]tcp-socket.listen": store | tracing | trappable,
88+
"wasi:sockets/types.[method]tcp-socket.receive": store | tracing | trappable,
89+
"wasi:sockets/types.[method]udp-socket.bind": async | tracing | trappable,
90+
"wasi:sockets/types.[method]udp-socket.connect": async | tracing | trappable,
9191
default: tracing | trappable,
9292
},
9393
exports: { default: async | store },

crates/wasi/src/p3/sockets/host/types/tcp.rs

Lines changed: 57 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::sync::Arc;
1818
use tokio::net::{TcpListener, TcpStream};
1919
use tokio::sync::oneshot;
2020
use wasmtime::component::{
21-
Accessor, Destination, FutureReader, Resource, ResourceTable, Source, StreamConsumer,
21+
Access, Accessor, Destination, FutureReader, Resource, ResourceTable, Source, StreamConsumer,
2222
StreamProducer, StreamReader, StreamResult,
2323
};
2424
use wasmtime::{AsContextMut as _, StoreContextMut};
@@ -236,23 +236,6 @@ impl<D> StreamConsumer<D> for SendStreamConsumer {
236236
}
237237

238238
impl HostTcpSocketWithStore for WasiSockets {
239-
async fn bind<T>(
240-
store: &Accessor<T, Self>,
241-
socket: Resource<TcpSocket>,
242-
local_address: IpSocketAddress,
243-
) -> SocketResult<()> {
244-
let local_address = SocketAddr::from(local_address);
245-
if !is_addr_allowed(store, local_address, SocketAddrUse::TcpBind).await {
246-
return Err(ErrorCode::AccessDenied.into());
247-
}
248-
store.with(|mut store| {
249-
let socket = get_socket_mut(store.get().table, &socket)?;
250-
socket.start_bind(local_address)?;
251-
socket.finish_bind()?;
252-
Ok(())
253-
})
254-
}
255-
256239
async fn connect<T>(
257240
store: &Accessor<T, Self>,
258241
socket: Resource<TcpSocket>,
@@ -278,28 +261,26 @@ impl HostTcpSocketWithStore for WasiSockets {
278261
})
279262
}
280263

281-
async fn listen<T: 'static>(
282-
store: &Accessor<T, Self>,
264+
fn listen<T: 'static>(
265+
mut store: Access<'_, T, Self>,
283266
socket: Resource<TcpSocket>,
284267
) -> SocketResult<StreamReader<Resource<TcpSocket>>> {
285268
let getter = store.getter();
286-
store.with(|mut store| {
287-
let socket = get_socket_mut(store.get().table, &socket)?;
288-
socket.start_listen()?;
289-
socket.finish_listen()?;
290-
let listener = socket.tcp_listener_arc().unwrap().clone();
291-
let family = socket.address_family();
292-
let options = socket.non_inherited_options().clone();
293-
Ok(StreamReader::new(
294-
&mut store,
295-
ListenStreamProducer {
296-
listener,
297-
family,
298-
options,
299-
getter,
300-
},
301-
))
302-
})
269+
let socket = get_socket_mut(store.get().table, &socket)?;
270+
socket.start_listen()?;
271+
socket.finish_listen()?;
272+
let listener = socket.tcp_listener_arc().unwrap().clone();
273+
let family = socket.address_family();
274+
let options = socket.non_inherited_options().clone();
275+
Ok(StreamReader::new(
276+
&mut store,
277+
ListenStreamProducer {
278+
listener,
279+
family,
280+
options,
281+
getter,
282+
},
283+
))
303284
}
304285

305286
async fn send<T: 'static>(
@@ -328,39 +309,52 @@ impl HostTcpSocketWithStore for WasiSockets {
328309
Ok(())
329310
}
330311

331-
async fn receive<T: 'static>(
332-
store: &Accessor<T, Self>,
312+
fn receive<T: 'static>(
313+
mut store: Access<T, Self>,
333314
socket: Resource<TcpSocket>,
334315
) -> wasmtime::Result<(StreamReader<u8>, FutureReader<Result<(), ErrorCode>>)> {
335-
store.with(|mut store| {
336-
let socket = get_socket_mut(store.get().table, &socket)?;
337-
match socket.start_receive() {
338-
Some(stream) => {
339-
let stream = Arc::clone(stream);
340-
let (result_tx, result_rx) = oneshot::channel();
341-
Ok((
342-
StreamReader::new(
343-
&mut store,
344-
ReceiveStreamProducer {
345-
stream,
346-
result: Some(result_tx),
347-
},
348-
),
349-
FutureReader::new(&mut store, result_rx),
350-
))
351-
}
352-
None => Ok((
353-
StreamReader::new(&mut store, iter::empty()),
354-
FutureReader::new(&mut store, async {
355-
anyhow::Ok(Err(ErrorCode::InvalidState))
356-
}),
357-
)),
316+
let socket = get_socket_mut(store.get().table, &socket)?;
317+
match socket.start_receive() {
318+
Some(stream) => {
319+
let stream = Arc::clone(stream);
320+
let (result_tx, result_rx) = oneshot::channel();
321+
Ok((
322+
StreamReader::new(
323+
&mut store,
324+
ReceiveStreamProducer {
325+
stream,
326+
result: Some(result_tx),
327+
},
328+
),
329+
FutureReader::new(&mut store, result_rx),
330+
))
358331
}
359-
})
332+
None => Ok((
333+
StreamReader::new(&mut store, iter::empty()),
334+
FutureReader::new(&mut store, async {
335+
anyhow::Ok(Err(ErrorCode::InvalidState))
336+
}),
337+
)),
338+
}
360339
}
361340
}
362341

363342
impl HostTcpSocket for WasiSocketsCtxView<'_> {
343+
async fn bind(
344+
&mut self,
345+
socket: Resource<TcpSocket>,
346+
local_address: IpSocketAddress,
347+
) -> SocketResult<()> {
348+
let local_address = SocketAddr::from(local_address);
349+
if !(self.ctx.socket_addr_check)(local_address, SocketAddrUse::TcpBind).await {
350+
return Err(ErrorCode::AccessDenied.into());
351+
}
352+
let socket = get_socket_mut(self.table, &socket)?;
353+
socket.start_bind(local_address)?;
354+
socket.finish_bind()?;
355+
Ok(())
356+
}
357+
364358
fn create(&mut self, address_family: IpAddressFamily) -> SocketResult<Resource<TcpSocket>> {
365359
let family = address_family.into();
366360
let socket = TcpSocket::new(self.ctx, family)?;

crates/wasi/src/p3/sockets/host/types/udp.rs

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,39 +30,6 @@ fn get_socket_mut<'a>(
3030
}
3131

3232
impl HostUdpSocketWithStore for WasiSockets {
33-
async fn bind<T>(
34-
store: &Accessor<T, Self>,
35-
socket: Resource<UdpSocket>,
36-
local_address: IpSocketAddress,
37-
) -> SocketResult<()> {
38-
let local_address = SocketAddr::from(local_address);
39-
if !is_addr_allowed(store, local_address, SocketAddrUse::UdpBind).await {
40-
return Err(ErrorCode::AccessDenied.into());
41-
}
42-
store.with(|mut view| {
43-
let socket = get_socket_mut(view.get().table, &socket)?;
44-
socket.bind(local_address)?;
45-
socket.finish_bind()?;
46-
Ok(())
47-
})
48-
}
49-
50-
async fn connect<T>(
51-
store: &Accessor<T, Self>,
52-
socket: Resource<UdpSocket>,
53-
remote_address: IpSocketAddress,
54-
) -> SocketResult<()> {
55-
let remote_address = SocketAddr::from(remote_address);
56-
if !is_addr_allowed(store, remote_address, SocketAddrUse::UdpConnect).await {
57-
return Err(ErrorCode::AccessDenied.into());
58-
}
59-
store.with(|mut view| {
60-
let socket = get_socket_mut(view.get().table, &socket)?;
61-
socket.connect(remote_address)?;
62-
Ok(())
63-
})
64-
}
65-
6633
async fn send<T>(
6734
store: &Accessor<T, Self>,
6835
socket: Resource<UdpSocket>,
@@ -103,6 +70,35 @@ impl HostUdpSocketWithStore for WasiSockets {
10370
}
10471

10572
impl HostUdpSocket for WasiSocketsCtxView<'_> {
73+
async fn bind(
74+
&mut self,
75+
socket: Resource<UdpSocket>,
76+
local_address: IpSocketAddress,
77+
) -> SocketResult<()> {
78+
let local_address = SocketAddr::from(local_address);
79+
if !(self.ctx.socket_addr_check)(local_address, SocketAddrUse::UdpBind).await {
80+
return Err(ErrorCode::AccessDenied.into());
81+
}
82+
let socket = get_socket_mut(self.table, &socket)?;
83+
socket.bind(local_address)?;
84+
socket.finish_bind()?;
85+
Ok(())
86+
}
87+
88+
async fn connect(
89+
&mut self,
90+
socket: Resource<UdpSocket>,
91+
remote_address: IpSocketAddress,
92+
) -> SocketResult<()> {
93+
let remote_address = SocketAddr::from(remote_address);
94+
if !(self.ctx.socket_addr_check)(remote_address, SocketAddrUse::UdpConnect).await {
95+
return Err(ErrorCode::AccessDenied.into());
96+
}
97+
let socket = get_socket_mut(self.table, &socket)?;
98+
socket.connect(remote_address)?;
99+
Ok(())
100+
}
101+
106102
fn create(&mut self, address_family: IpAddressFamily) -> SocketResult<Resource<UdpSocket>> {
107103
let socket = UdpSocket::new(self.ctx, address_family.into())?;
108104
self.table

crates/wit-bindgen/src/lib.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2508,7 +2508,11 @@ impl<'a> InterfaceGenerator<'a> {
25082508
} else {
25092509
uwriteln!(
25102510
self.src,
2511-
"let host = {wt}::component::Access::new(caller, host_getter);"
2511+
"let access_cx = {wt}::AsContextMut::as_context_mut(&mut caller);"
2512+
);
2513+
uwriteln!(
2514+
self.src,
2515+
"let host = {wt}::component::Access::new(access_cx, host_getter);"
25122516
);
25132517
}
25142518
} else {
@@ -2581,9 +2585,9 @@ impl<'a> InterfaceGenerator<'a> {
25812585
let convert = format!("{}::convert_{}", convert_trait, err_name.to_snake_case());
25822586
let convert = if flags.contains(FunctionFlags::STORE) {
25832587
if flags.contains(FunctionFlags::ASYNC) {
2584-
format!("host.with(|mut host| {convert}(&mut host.get(), e))?")
2588+
format!("caller.with(|mut host| {convert}(&mut host_getter(host.get()), e))?")
25852589
} else {
2586-
format!("{convert}(&mut host.get(), e)?")
2590+
format!("{convert}(&mut host_getter(caller.data_mut()), e)?")
25872591
}
25882592
} else {
25892593
format!("{convert}(host, e)?")

0 commit comments

Comments
 (0)