diff --git a/components/net/sal/impl/af_inet_at.c b/components/net/sal/impl/af_inet_at.c index 540e240c62c..636b60e7d7c 100644 --- a/components/net/sal/impl/af_inet_at.c +++ b/components/net/sal/impl/af_inet_at.c @@ -38,6 +38,7 @@ static int at_poll(struct dfs_file *file, struct rt_pollreq *req) } sock = at_get_socket((int)sal_sock->user_data); + sal_socket_put(sal_sock); if (sock != NULL) { rt_base_t level; diff --git a/components/net/sal/impl/af_inet_lwip.c b/components/net/sal/impl/af_inet_lwip.c index 876dfba535f..30c488ff5f0 100644 --- a/components/net/sal/impl/af_inet_lwip.c +++ b/components/net/sal/impl/af_inet_lwip.c @@ -258,6 +258,7 @@ static int inet_poll(struct dfs_file *file, struct rt_pollreq *req) } sock = lwip_tryget_socket((int)(size_t)sal_sock->user_data); + sal_socket_put(sal_sock); if (sock != NULL) { rt_base_t level; diff --git a/components/net/sal/impl/proto_mbedtls.c b/components/net/sal/impl/proto_mbedtls.c index 4ae9f263281..5e61726d89a 100644 --- a/components/net/sal/impl/proto_mbedtls.c +++ b/components/net/sal/impl/proto_mbedtls.c @@ -96,6 +96,7 @@ int mbedtls_net_send_cb(void *ctx, const unsigned char *buf, size_t len) /* Register scoket sendto option to TLS send data callback */ ret = pf->skt_ops->sendto((int) sock->user_data, (void *)buf, len, 0, RT_NULL, RT_NULL); + sal_socket_put(sock); if (ret < 0) { #ifdef RT_USING_DFS @@ -133,6 +134,7 @@ int mbedtls_net_recv_cb( void *ctx, unsigned char *buf, size_t len) /* Register scoket recvfrom option to TLS recv data callback */ ret = pf->skt_ops->recvfrom((int) sock->user_data, (void *)buf, len, 0, RT_NULL, RT_NULL); + sal_socket_put(sock); if (ret < 0) { #ifdef RT_USING_DFS @@ -199,24 +201,13 @@ static int mbedtls_connect(void *sock) static int mbedtls_closesocket(void *sock) { - struct sal_socket *ssock; - int socket; - if (sock == RT_NULL) { return 0; } - socket = ((MbedTLSSession *) sock)->server_fd.fd; - ssock = sal_get_socket(socket); - if (ssock == RT_NULL) - { - return -1; - } - /* Close TLS client session, and clean user-data in SAL socket */ mbedtls_client_close((MbedTLSSession *) sock); - ssock->user_data_tls = RT_NULL; return 0; } diff --git a/components/net/sal/include/sal_low_lvl.h b/components/net/sal/include/sal_low_lvl.h index c19c10ffc14..bf5d6e3a0ea 100644 --- a/components/net/sal/include/sal_low_lvl.h +++ b/components/net/sal/include/sal_low_lvl.h @@ -14,6 +14,7 @@ #define SAL_LOW_LEVEL_H__ #include +#include #ifdef SAL_USING_POSIX #include @@ -30,6 +31,14 @@ typedef uint32_t socklen_t; /* SAL socket magic word */ #define SAL_SOCKET_MAGIC 0x5A10 +enum sal_socket_state +{ + SAL_SOCKET_STATE_INIT = 0, + SAL_SOCKET_STATE_OPEN, /* visible in socket table */ + SAL_SOCKET_STATE_CLOSING, /* detached, waiting refcnt drop */ + SAL_SOCKET_STATE_CLOSED, /* ready for cache reuse */ +}; + /* The maximum number of sockets structure */ #ifndef SAL_SOCKETS_NUM #define SAL_SOCKETS_NUM DFS_FD_MAX @@ -63,6 +72,10 @@ struct sal_socket #ifdef SAL_USING_TLS void *user_data_tls; /* user-specific TLS data */ #endif + rt_atomic_t refcnt; /* in-flight SAL references */ + rt_uint8_t state; /* socket lifecycle state */ + struct rt_completion close_completion; /* wake close waiter */ + struct sal_socket *next_free; /* internal cache link */ }; /* network interface socket opreations */ @@ -109,8 +122,10 @@ struct sal_proto_family /* SAL(Socket Abstraction Layer) initialize */ int sal_init(void); -/* Get SAL socket object by socket descriptor */ +/* Get socket object and hold a temporary reference. */ struct sal_socket *sal_get_socket(int sock); +/* Release reference returned by sal_get_socket(). */ +void sal_socket_put(struct sal_socket *sock); /* check SAL socket netweork interface device internet status */ int sal_check_netdev_internet_up(struct netdev *netdev); diff --git a/components/net/sal/src/sal_socket.c b/components/net/sal/src/sal_socket.c index fa9eb9cc746..e0465676674 100644 --- a/components/net/sal/src/sal_socket.c +++ b/components/net/sal/src/sal_socket.c @@ -87,6 +87,7 @@ static struct sal_proto_tls *proto_tls; /* The global socket table */ static struct sal_socket_table socket_table; static struct rt_mutex sal_core_lock; +static struct sal_socket *sal_socket_cache; static rt_bool_t init_ok = RT_FALSE; static struct sal_netdev_res_table sal_dev_res_tbl[SAL_SOCKETS_NUM]; @@ -113,6 +114,12 @@ static struct sal_netdev_res_table sal_dev_res_tbl[SAL_SOCKETS_NUM]; } \ } while (0) +#define SAL_SOCKET_OBJ_PUT(sock) \ + do \ + { \ + sal_socket_put(sock); \ + } while (0) + #define SAL_NETDEV_IS_UP(netdev) \ do \ { \ @@ -142,6 +149,21 @@ static struct sal_netdev_res_table sal_dev_res_tbl[SAL_SOCKETS_NUM]; ((pf) = (struct sal_proto_family *)(netdev)->sal_user_data) != RT_NULL && \ (pf)->netdb_ops->ops) +/* + * Lifetime helper declarations. + * + * The SAL socket table is the publication point for descriptors, while the + * reference count protects already-running operations after a descriptor has + * been removed from that table. These helpers keep the state transitions small + * and explicit so the close path can be audited without searching every caller. + */ +static struct sal_socket *__sal_socket_lookup_locked(int socket); +static void __sal_socket_cache_free(struct sal_socket *sock); +static void __sal_socket_wait_refs(struct sal_socket *sock); +static int __sal_netdev_is_up(struct netdev *netdev); +static void sal_lock(void); +static void sal_unlock(void); + /** * SAL (Socket Abstraction Layer) initialize. * @@ -357,15 +379,9 @@ int sal_proto_tls_register(const struct sal_proto_tls *pt) } #endif -/** - * This function will get sal socket object by sal socket descriptor. - * - * @param socket sal socket index - * - * @return sal socket object of the current sal socket index - */ -struct sal_socket *sal_get_socket(int socket) +static struct sal_socket *__sal_socket_lookup_locked(int socket) { + struct sal_socket *sock; struct sal_socket_table *st = &socket_table; socket = socket - SAL_SOCKET_OFFSET; @@ -375,10 +391,82 @@ struct sal_socket *sal_get_socket(int socket) return RT_NULL; } - /* check socket structure valid or not */ - RT_ASSERT(st->sockets[socket]->magic == SAL_SOCKET_MAGIC); + sock = st->sockets[socket]; + if (sock == RT_NULL || sock->magic != SAL_SOCKET_MAGIC) + { + return RT_NULL; + } + + return sock; +} + +struct sal_socket *sal_get_socket(int socket) +{ + struct sal_socket *sock; + + sal_lock(); + sock = __sal_socket_lookup_locked(socket); + /* Only published sockets can accept new users. */ + if (sock != RT_NULL && sock->state == SAL_SOCKET_STATE_OPEN) + { + rt_atomic_add(&sock->refcnt, 1); + } + else + { + sock = RT_NULL; + } + sal_unlock(); + + return sock; +} + +void sal_socket_put(struct sal_socket *sock) +{ + rt_atomic_t old; + + if (sock == RT_NULL) + { + return; + } + + old = rt_atomic_sub(&sock->refcnt, 1); + /* Wake waiters when the last in-flight user leaves. */ + if (old == 1) + { + rt_completion_done(&sock->close_completion); + } +} - return st->sockets[socket]; +static void __sal_socket_wait_refs(struct sal_socket *sock) +{ + /* New lookups are blocked before this wait starts. */ + while (rt_atomic_load(&sock->refcnt) != 0) + { + rt_completion_wait(&sock->close_completion, RT_WAITING_FOREVER); + } +} + +static void __sal_socket_cache_free(struct sal_socket *sock) +{ + if (sock == RT_NULL) + { + return; + } + + sock->magic = 0; + sock->state = SAL_SOCKET_STATE_CLOSED; + sock->user_data = RT_NULL; +#ifdef SAL_USING_TLS + sock->user_data_tls = RT_NULL; +#endif + sock->netdev = RT_NULL; + sock->next_free = sal_socket_cache; + sal_socket_cache = sock; +} + +static int __sal_netdev_is_up(struct netdev *netdev) +{ + return netdev_is_up(netdev) ? 0 : -1; } /** @@ -527,6 +615,7 @@ static int socket_init(int family, int type, int protocol, struct sal_socket **r static int socket_alloc(struct sal_socket_table *st, int f_socket) { + struct sal_socket *sock; int idx; /* find an empty socket entry */ @@ -565,7 +654,29 @@ static int socket_alloc(struct sal_socket_table *st, int f_socket) /* allocate 'struct sal_socket' */ if (idx < (int)st->max_socket && st->sockets[idx] == RT_NULL) { - st->sockets[idx] = rt_calloc(1, sizeof(struct sal_socket)); + if (sal_socket_cache != RT_NULL) + { + /* Reuse a detached object from local cache. */ + sock = sal_socket_cache; + sal_socket_cache = sal_socket_cache->next_free; + rt_memset(sock, 0x00, sizeof(struct sal_socket)); + rt_completion_init(&sock->close_completion); + rt_atomic_store(&sock->refcnt, 0); + sock->state = SAL_SOCKET_STATE_INIT; + st->sockets[idx] = sock; + } + else + { + sock = rt_calloc(1, sizeof(struct sal_socket)); + if (sock != RT_NULL) + { + rt_completion_init(&sock->close_completion); + rt_atomic_store(&sock->refcnt, 0); + sock->state = SAL_SOCKET_STATE_INIT; + } + st->sockets[idx] = sock; + } + if (st->sockets[idx] == RT_NULL) { idx = st->max_socket; @@ -576,15 +687,6 @@ static int socket_alloc(struct sal_socket_table *st, int f_socket) return idx; } -static void socket_free(struct sal_socket_table *st, int idx) -{ - struct sal_socket *sock; - - sock = st->sockets[idx]; - st->sockets[idx] = RT_NULL; - rt_free(sock); -} - static int socket_new(void) { struct sal_socket *sock; @@ -604,6 +706,7 @@ static int socket_new(void) } sock = st->sockets[idx]; + /* Publish the slot after runtime fields are reset. */ sock->socket = idx + SAL_SOCKET_OFFSET; sock->magic = SAL_SOCKET_MAGIC; sock->netdev = RT_NULL; @@ -611,6 +714,9 @@ static int socket_new(void) #ifdef SAL_USING_TLS sock->user_data_tls = RT_NULL; #endif + rt_atomic_store(&sock->refcnt, 0); + rt_completion_init(&sock->close_completion); + sock->state = SAL_SOCKET_STATE_OPEN; __result: sal_unlock(); @@ -629,28 +735,44 @@ static void socket_delete(int socket) return; } sal_lock(); - sock = sal_get_socket(socket); - RT_ASSERT(sock != RT_NULL); - sock->magic = 0; - sock->netdev = RT_NULL; - socket_free(st, idx); + sock = __sal_socket_lookup_locked(socket); + if (sock == RT_NULL) + { + sal_unlock(); + return; + } + /* Stop new lookups before waiting current users to exit. */ + sock->state = SAL_SOCKET_STATE_CLOSING; + st->sockets[idx] = RT_NULL; + sal_unlock(); + + rt_completion_init(&sock->close_completion); + __sal_socket_wait_refs(sock); + sal_lock(); + __sal_socket_cache_free(sock); sal_unlock(); } int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen) { int new_socket; + int ret = -1; struct sal_socket *sock; struct sal_proto_family *pf; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); - /* check the network interface is up status */ - SAL_NETDEV_IS_UP(sock->netdev); + if (__sal_netdev_is_up(sock->netdev) != 0) + { + goto __exit; + } - /* check the network interface socket operations */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, accept); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->accept == RT_NULL) + { + goto __exit; + } new_socket = pf->skt_ops->accept((int)(size_t)sock->user_data, addr, addrlen); if (new_socket != -1) @@ -665,29 +787,30 @@ int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen) if (new_sock == RT_NULL) { pf->skt_ops->closesocket(new_socket); - return -1; + goto __exit; } retval = socket_init(sock->domain, sock->type, sock->protocol, &new_sock); if (retval < 0) { pf->skt_ops->closesocket(new_socket); - rt_memset(new_sock, 0x00, sizeof(struct sal_socket)); + sal_socket_put(new_sock); /* socket init failed, delete socket */ socket_delete(new_sal_socket); LOG_E("New socket registered failed, return error %d.", retval); - return -1; + goto __exit; } /* new socket create by accept should have the same netdev with server*/ new_sock->netdev = sock->netdev; /* socket structure user_data used to store the acquired new socket */ new_sock->user_data = (void *)(size_t)new_socket; - - return new_sal_socket; + sal_socket_put(new_sock); + ret = new_sal_socket; } - - return -1; +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } static void sal_sockaddr_to_ipaddr(const struct sockaddr *name, ip_addr_t *local_ipaddr) @@ -708,6 +831,7 @@ int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; ip_addr_t input_ipaddr; RT_ASSERT(name); @@ -730,34 +854,54 @@ int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen) new_netdev = netdev_get_by_ipaddr(&input_ipaddr); if (new_netdev == RT_NULL) { - return -1; + goto __exit; } /* get input and local ip address proto_family */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, local_pf, bind); - SAL_NETDEV_SOCKETOPS_VALID(new_netdev, input_pf, bind); + local_pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + input_pf = (struct sal_proto_family *)new_netdev->sal_user_data; + if (local_pf == RT_NULL || local_pf->skt_ops == RT_NULL || + local_pf->skt_ops->bind == RT_NULL || local_pf->skt_ops->closesocket == RT_NULL || + input_pf == RT_NULL || input_pf->skt_ops == RT_NULL || + input_pf->skt_ops->bind == RT_NULL || input_pf->skt_ops->socket == RT_NULL) + { + goto __exit; + } /* check the network interface protocol family type */ if (input_pf->family != local_pf->family) { + int old_socket; int new_socket = -1; - /* protocol family is different, close old socket and create new socket by input ip address */ - local_pf->skt_ops->closesocket(socket); - + old_socket = (int)(size_t)sock->user_data; new_socket = input_pf->skt_ops->socket(input_pf->family, sock->type, sock->protocol); if (new_socket < 0) { - return -1; + goto __exit; } + + if (local_pf->skt_ops->closesocket(old_socket) < 0) + { + input_pf->skt_ops->closesocket(new_socket); + goto __exit; + } + sock->netdev = new_netdev; sock->user_data = (void *)(size_t)new_socket; } } } /* check and get protocol families by the network interface device */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, bind); - return pf->skt_ops->bind((int)(size_t)sock->user_data, name, namelen); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->bind == RT_NULL) + { + goto __exit; + } + ret = pf->skt_ops->bind((int)(size_t)sock->user_data, name, namelen); +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_shutdown(int socket, int how) @@ -771,7 +915,12 @@ int sal_shutdown(int socket, int how) /* shutdown operation not need to check network interface status */ /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, shutdown); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->shutdown == RT_NULL) + { + error = -1; + goto __exit; + } if (pf->skt_ops->shutdown((int)(size_t)sock->user_data, how) == 0) { @@ -780,8 +929,10 @@ int sal_shutdown(int socket, int how) { if (proto_tls->ops->closesocket(sock->user_data_tls) < 0) { - return -1; + error = -1; + goto __exit; } + sock->user_data_tls = RT_NULL; } #endif error = 0; @@ -790,8 +941,8 @@ int sal_shutdown(int socket, int how) { error = -1; } - - +__exit: + SAL_SOCKET_OBJ_PUT(sock); return error; } @@ -799,54 +950,74 @@ int sal_getpeername(int socket, struct sockaddr *name, socklen_t *namelen) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getpeername); - - return pf->skt_ops->getpeername((int)(size_t)sock->user_data, name, namelen); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf != RT_NULL && pf->skt_ops != RT_NULL && pf->skt_ops->getpeername != RT_NULL) + { + ret = pf->skt_ops->getpeername((int)(size_t)sock->user_data, name, namelen); + } + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_getsockname(int socket, struct sockaddr *name, socklen_t *namelen) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockname); - - return pf->skt_ops->getsockname((int)(size_t)sock->user_data, name, namelen); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf != RT_NULL && pf->skt_ops != RT_NULL && pf->skt_ops->getsockname != RT_NULL) + { + ret = pf->skt_ops->getsockname((int)(size_t)sock->user_data, name, namelen); + } + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_getsockopt(int socket, int level, int optname, void *optval, socklen_t *optlen) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, getsockopt); - - return pf->skt_ops->getsockopt((int)(size_t)sock->user_data, level, optname, optval, optlen); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf != RT_NULL && pf->skt_ops != RT_NULL && pf->skt_ops->getsockopt != RT_NULL) + { + ret = pf->skt_ops->getsockopt((int)(size_t)sock->user_data, level, optname, optval, optlen); + } + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_setsockopt(int socket, int level, int optname, const void *optval, socklen_t optlen) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, setsockopt); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->setsockopt == RT_NULL) + { + goto __exit; + } #ifdef SAL_USING_TLS if (level == SOL_TLS) @@ -854,34 +1025,40 @@ int sal_setsockopt(int socket, int level, int optname, const void *optval, sockl switch (optname) { case TLS_CRET_LIST: - SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_cret_list, optval, optlen); + ret = (SAL_SOCKOPS_PROTO_TLS_VALID(sock, set_cret_list)) ? + proto_tls->ops->set_cret_list(sock->user_data_tls, optval, optlen) : -1; break; case TLS_CIPHERSUITE_LIST: - SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_ciphersurite, optval, optlen); + ret = (SAL_SOCKOPS_PROTO_TLS_VALID(sock, set_ciphersurite)) ? + proto_tls->ops->set_ciphersurite(sock->user_data_tls, optval, optlen) : -1; break; case TLS_PEER_VERIFY: - SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_peer_verify, optval, optlen); + ret = (SAL_SOCKOPS_PROTO_TLS_VALID(sock, set_peer_verify)) ? + proto_tls->ops->set_peer_verify(sock->user_data_tls, optval, optlen) : -1; break; case TLS_DTLS_ROLE: - SAL_SOCKOPT_PROTO_TLS_EXEC(sock, set_dtls_role, optval, optlen); + ret = (SAL_SOCKOPS_PROTO_TLS_VALID(sock, set_dtls_role)) ? + proto_tls->ops->set_dtls_role(sock->user_data_tls, optval, optlen) : -1; break; default: - return -1; + goto __exit; } - - return 0; + ret = (ret == 0) ? 0 : -1; } else { - return pf->skt_ops->setsockopt((int)sock->user_data, level, optname, optval, optlen); + ret = pf->skt_ops->setsockopt((int)sock->user_data, level, optname, optval, optlen); } #else - return pf->skt_ops->setsockopt((int)(size_t)sock->user_data, level, optname, optval, optlen); + ret = pf->skt_ops->setsockopt((int)(size_t)sock->user_data, level, optname, optval, optlen); #endif /* SAL_USING_TLS */ +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen) @@ -894,9 +1071,18 @@ int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen) SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface is up status */ - SAL_NETDEV_IS_UP(sock->netdev); + if (__sal_netdev_is_up(sock->netdev) != 0) + { + ret = -1; + goto __exit; + } /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, connect); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->connect == RT_NULL) + { + ret = -1; + goto __exit; + } ret = pf->skt_ops->connect((int)(size_t)sock->user_data, name, namelen); #ifdef SAL_USING_TLS @@ -904,13 +1090,13 @@ int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen) { if (proto_tls->ops->connect(sock->user_data_tls) < 0) { - return -1; + ret = -1; + goto __exit; } - - return ret; } #endif - +__exit: + SAL_SOCKET_OBJ_PUT(sock); return ret; } @@ -918,80 +1104,103 @@ int sal_listen(int socket, int backlog) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, listen); - - return pf->skt_ops->listen((int)(size_t)sock->user_data, backlog); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf != RT_NULL && pf->skt_ops != RT_NULL && pf->skt_ops->listen != RT_NULL) + { + ret = pf->skt_ops->listen((int)(size_t)sock->user_data, backlog); + } + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_sendmsg(int socket, const struct msghdr *message, int flags) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface is up status */ - SAL_NETDEV_IS_UP(sock->netdev); + if (__sal_netdev_is_up(sock->netdev) != 0) + { + goto __exit; + } /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, sendmsg); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->sendmsg == RT_NULL) + { + goto __exit; + } #ifdef SAL_USING_TLS if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, send)) { - int ret; - if ((ret = proto_tls->ops->send(sock->user_data_tls, message, flags)) < 0) { - return -1; + ret = -1; + goto __exit; } - return ret; } else { - return pf->skt_ops->sendmsg((int)(size_t)sock->user_data, message, flags); + ret = pf->skt_ops->sendmsg((int)(size_t)sock->user_data, message, flags); } #else - return pf->skt_ops->sendmsg((int)(size_t)sock->user_data, message, flags); + ret = pf->skt_ops->sendmsg((int)(size_t)sock->user_data, message, flags); #endif +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_recvmsg(int socket, struct msghdr *message, int flags) { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface is up status */ - SAL_NETDEV_IS_UP(sock->netdev); + if (__sal_netdev_is_up(sock->netdev) != 0) + { + goto __exit; + } /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, recvmsg); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->recvmsg == RT_NULL) + { + goto __exit; + } #ifdef SAL_USING_TLS if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, recv)) { - int ret; - if ((ret = proto_tls->ops->recv(sock->user_data_tls, message, flags)) < 0) { - return -1; + ret = -1; + goto __exit; } - return ret; } else { - return pf->skt_ops->recvmsg((int)(size_t)sock->user_data, message, flags); + ret = pf->skt_ops->recvmsg((int)(size_t)sock->user_data, message, flags); } #else - return pf->skt_ops->recvmsg((int)(size_t)sock->user_data, message, flags); + ret = pf->skt_ops->recvmsg((int)(size_t)sock->user_data, message, flags); #endif +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_recvfrom(int socket, void *mem, size_t len, int flags, @@ -999,33 +1208,42 @@ int sal_recvfrom(int socket, void *mem, size_t len, int flags, { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface is up status */ - SAL_NETDEV_IS_UP(sock->netdev); + if (__sal_netdev_is_up(sock->netdev) != 0) + { + goto __exit; + } /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, recvfrom); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->recvfrom == RT_NULL) + { + goto __exit; + } #ifdef SAL_USING_TLS if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, recv)) { - int ret; - if ((ret = proto_tls->ops->recv(sock->user_data_tls, mem, len)) < 0) { - return -1; + ret = -1; + goto __exit; } - return ret; } else { - return pf->skt_ops->recvfrom((int)(size_t)sock->user_data, mem, len, flags, from, fromlen); + ret = pf->skt_ops->recvfrom((int)(size_t)sock->user_data, mem, len, flags, from, fromlen); } #else - return pf->skt_ops->recvfrom((int)(size_t)sock->user_data, mem, len, flags, from, fromlen); + ret = pf->skt_ops->recvfrom((int)(size_t)sock->user_data, mem, len, flags, from, fromlen); #endif +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_sendto(int socket, const void *dataptr, size_t size, int flags, @@ -1033,33 +1251,42 @@ int sal_sendto(int socket, const void *dataptr, size_t size, int flags, { struct sal_socket *sock; struct sal_proto_family *pf; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface is up status */ - SAL_NETDEV_IS_UP(sock->netdev); + if (__sal_netdev_is_up(sock->netdev) != 0) + { + goto __exit; + } /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, sendto); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->sendto == RT_NULL) + { + goto __exit; + } #ifdef SAL_USING_TLS if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, send)) { - int ret; - if ((ret = proto_tls->ops->send(sock->user_data_tls, dataptr, size)) < 0) { - return -1; + ret = -1; + goto __exit; } - return ret; } else { - return pf->skt_ops->sendto((int)sock->user_data, dataptr, size, flags, to, tolen); + ret = pf->skt_ops->sendto((int)sock->user_data, dataptr, size, flags, to, tolen); } #else - return pf->skt_ops->sendto((int)(size_t)sock->user_data, dataptr, size, flags, to, tolen); + ret = pf->skt_ops->sendto((int)(size_t)sock->user_data, dataptr, size, flags, to, tolen); #endif +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } int sal_socket(int domain, int type, int protocol) @@ -1089,12 +1316,19 @@ int sal_socket(int domain, int type, int protocol) if (retval < 0) { LOG_E("SAL socket protocol family input failed, return error %d.", retval); + sal_socket_put(sock); socket_delete(socket); return retval; } /* valid the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, socket); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->socket == RT_NULL) + { + sal_socket_put(sock); + socket_delete(socket); + return -1; + } proto_socket = pf->skt_ops->socket(domain, type, protocol); if (proto_socket >= 0) @@ -1105,14 +1339,18 @@ int sal_socket(int domain, int type, int protocol) sock->user_data_tls = proto_tls->ops->socket(socket); if (sock->user_data_tls == RT_NULL) { + sal_socket_put(sock); socket_delete(socket); return -1; } } #endif sock->user_data = (void *)(size_t)proto_socket; - return sock->socket; + retval = sock->socket; + sal_socket_put(sock); + return retval; } + sal_socket_put(sock); socket_delete(socket); return -1; } @@ -1120,31 +1358,45 @@ int sal_socket(int domain, int type, int protocol) int sal_socketpair(int domain, int type, int protocol, int *fds) { int unix_fd[2]; - struct sal_socket *socka; - struct sal_socket *sockb; + struct sal_socket *socka = RT_NULL; + struct sal_socket *sockb = RT_NULL; struct sal_proto_family *pf; + int ret = -1; if (domain == AF_UNIX) { - /* get the socket object by socket descriptor */ - SAL_SOCKET_OBJ_GET(socka, fds[0]); - SAL_SOCKET_OBJ_GET(sockb, fds[1]); + socka = sal_get_socket(fds[0]); + if (socka == RT_NULL) + { + return -1; + } + + sockb = sal_get_socket(fds[1]); + if (sockb == RT_NULL) + { + goto __exit; + } /* valid the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(socka->netdev, pf, socket); + pf = (struct sal_proto_family *)socka->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->socketpair == RT_NULL) + { + goto __exit; + } unix_fd[0] = (int)(size_t)socka->user_data; unix_fd[1] = (int)(size_t)sockb->user_data; - - if (pf->skt_ops->socketpair) - { - return pf->skt_ops->socketpair(domain, type, protocol, unix_fd); - } + ret = pf->skt_ops->socketpair(domain, type, protocol, unix_fd); + goto __exit; } - rt_set_errno(EINVAL); - - return -1; +__exit: + if (domain == AF_UNIX) + { + SAL_SOCKET_OBJ_PUT(sockb); + SAL_SOCKET_OBJ_PUT(socka); + } + return ret; } int sal_closesocket(int socket) @@ -1152,13 +1404,30 @@ int sal_closesocket(int socket) struct sal_socket *sock; struct sal_proto_family *pf; int error = 0; + int idx; - /* get the socket object by socket descriptor */ - SAL_SOCKET_OBJ_GET(sock, socket); + idx = socket - SAL_SOCKET_OFFSET; + sal_lock(); + sock = __sal_socket_lookup_locked(socket); + if (sock == RT_NULL || sock->state != SAL_SOCKET_STATE_OPEN) + { + sal_unlock(); + return -1; + } + /* Remove from table first, backend close later. */ + sock->state = SAL_SOCKET_STATE_CLOSING; + socket_table.sockets[idx] = RT_NULL; + sal_unlock(); - /* clsoesocket operation not need to vaild network interface status */ - /* valid the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, closesocket); + rt_completion_init(&sock->close_completion); + __sal_socket_wait_refs(sock); + + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->closesocket == RT_NULL) + { + error = -1; + goto __exit; + } if (pf->skt_ops->closesocket((int)(size_t)sock->user_data) == 0) { @@ -1167,7 +1436,8 @@ int sal_closesocket(int socket) { if (proto_tls->ops->closesocket(sock->user_data_tls) < 0) { - return -1; + error = -1; + goto __exit; } } #endif @@ -1177,10 +1447,10 @@ int sal_closesocket(int socket) { error = -1; } - - /* delete socket */ - socket_delete(socket); - +__exit: + sal_lock(); + __sal_socket_cache_free(sock); + sal_unlock(); return error; } @@ -1190,19 +1460,15 @@ int sal_closesocket(int socket) #define IFF_RUNNING 0x40 #define IFF_NOARP 0x80 -int sal_ioctlsocket(int socket, long cmd, void *arg) +static int __sal_ioctlsocket(struct sal_socket *sock, long cmd, void *arg) { rt_slist_t *node = RT_NULL; struct netdev *netdev = RT_NULL; struct netdev *cur_netdev_list = netdev_list; - struct sal_socket *sock; struct sal_proto_family *pf; struct sockaddr_in *addr_in = RT_NULL; struct sockaddr *addr = RT_NULL; ip_addr_t input_ipaddr; - /* get the socket object by socket descriptor */ - SAL_SOCKET_OBJ_GET(sock, socket); - struct sal_ifreq *ifr = (struct sal_ifreq *)arg; if (ifr != RT_NULL) @@ -1519,27 +1785,54 @@ int sal_ioctlsocket(int socket, long cmd, void *arg) } /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, ioctlsocket); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->ioctlsocket == RT_NULL) + { + return -1; + } return pf->skt_ops->ioctlsocket((int)(size_t)sock->user_data, cmd, arg); } +int sal_ioctlsocket(int socket, long cmd, void *arg) +{ + int ret; + struct sal_socket *sock; + + SAL_SOCKET_OBJ_GET(sock, socket); + ret = __sal_ioctlsocket(sock, cmd, arg); + SAL_SOCKET_OBJ_PUT(sock); + + return ret; +} + #ifdef SAL_USING_POSIX int sal_poll(struct dfs_file *file, struct rt_pollreq *req) { struct sal_socket *sock; struct sal_proto_family *pf; int socket = (int)(size_t)file->vnode->data; + int ret = -1; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface is up status */ - SAL_NETDEV_IS_UP(sock->netdev); + if (__sal_netdev_is_up(sock->netdev) != 0) + { + goto __exit; + } /* check the network interface socket opreation */ - SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, poll); + pf = (struct sal_proto_family *)sock->netdev->sal_user_data; + if (pf == RT_NULL || pf->skt_ops == RT_NULL || pf->skt_ops->poll == RT_NULL) + { + goto __exit; + } - return pf->skt_ops->poll(file, req); + ret = pf->skt_ops->poll(file, req); +__exit: + SAL_SOCKET_OBJ_PUT(sock); + return ret; } #endif diff --git a/components/net/utest/tc_sal_socket.c b/components/net/utest/tc_sal_socket.c index db6acd7b726..7798e12463e 100644 --- a/components/net/utest/tc_sal_socket.c +++ b/components/net/utest/tc_sal_socket.c @@ -67,6 +67,42 @@ static char local_ip[16] = "127.0.0.1"; /* Thread synchronization structures */ static struct rt_event test_event; static volatile int server_ready = 0; +static volatile int close_race_stop = 0; +static volatile int close_race_done = 0; +static volatile int close_call_done = 0; +static volatile int close_call_ret = -1; + +static void init_loopback_addr(struct sockaddr_in *addr, int port) +{ + memset(addr, 0, sizeof(*addr)); + addr->sin_family = AF_INET; + addr->sin_addr.s_addr = inet_addr(local_ip); + addr->sin_port = htons(port); +} + +static void close_race_thread(void *parameter) +{ + int sock = (int)(rt_size_t)parameter; + struct sockaddr_in addr; + socklen_t addr_len; + + while (!close_race_stop) + { + addr_len = sizeof(addr); + sal_getsockname(sock, (struct sockaddr *)&addr, &addr_len); + rt_thread_mdelay(1); + } + + close_race_done = 1; +} + +static void close_call_thread(void *parameter) +{ + int sock = (int)(rt_size_t)parameter; + + close_call_ret = sal_closesocket(sock); + close_call_done = 1; +} /* Test helper functions */ static int get_available_port(int base_port) @@ -164,6 +200,79 @@ static void close_test_socket(int sock) } } +static void wait_test_flag(volatile int *flag, int timeout_ms) +{ + rt_tick_t start = rt_tick_get(); + rt_tick_t timeout_tick = rt_tick_from_millisecond(timeout_ms); + + while (!*flag) + { + if ((rt_tick_get() - start) > timeout_tick) + { + break; + } + rt_thread_mdelay(1); + } +} + +static void verify_closed_stream_socket_ops(int sock) +{ + int ret; + int opt = 0; + int mode = 0; + struct sockaddr_in addr; + socklen_t addr_len = sizeof(addr); + + init_loopback_addr(&addr, TEST_SERVER_BASE_PORT); + + ret = sal_getsockname(sock, (struct sockaddr *)&addr, &addr_len); + uassert_int_equal(ret, -1); + + addr_len = sizeof(addr); + ret = sal_getpeername(sock, (struct sockaddr *)&addr, &addr_len); + uassert_int_equal(ret, -1); + + ret = sal_shutdown(sock, SHUT_RDWR); + uassert_int_equal(ret, -1); + + ret = sal_setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + uassert_int_equal(ret, -1); + + ret = sal_ioctlsocket(sock, FIONBIO, &mode); + uassert_int_equal(ret, -1); +} + +static void verify_closed_datagram_socket_ops(int sock) +{ + int ret; + struct sockaddr_in addr; + char data[] = "sal"; + + init_loopback_addr(&addr, TEST_SERVER_BASE_PORT + TEST_CLIENT_PORT_OFFSET); + ret = sal_sendto(sock, data, sizeof(data), 0, (struct sockaddr *)&addr, sizeof(addr)); + uassert_int_equal(ret, -1); +} + +static void exercise_socket_cache_reuse(int rounds) +{ + int i; + + for (i = 0; i < rounds; i++) + { + int tcp_sock = create_test_socket(AF_INET, SOCK_STREAM, 0); + int udp_sock = create_test_socket(AF_INET, SOCK_DGRAM, 0); + + uassert_true(tcp_sock >= 0); + uassert_true(udp_sock >= 0); + + uassert_int_equal(sal_closesocket(tcp_sock), 0); + uassert_int_equal(sal_closesocket(udp_sock), 0); + + verify_closed_stream_socket_ops(tcp_sock); + verify_closed_datagram_socket_ops(udp_sock); + } +} + /* Server thread function */ static void server_thread_entry(void *parameter) { @@ -922,6 +1031,10 @@ static void TC_sal_socket_udp_communication(void) static void TC_sal_socket_close(void) { int sock = -1; + int udp_sock = -1; + int ret; + struct sockaddr_in addr; + socklen_t addr_len = sizeof(addr); LOG_I("Starting TC_sal_socket_close tests..."); @@ -941,10 +1054,92 @@ static void TC_sal_socket_close(void) sock = create_test_socket(AF_INET, SOCK_STREAM, 0); sal_closesocket(sock); LOG_I("Double closing socket %d (should be safe)", sock); + uassert_int_equal(sal_closesocket(sock), -1); + + /* Closed socket must fail safely on follow-up use */ + ret = sal_getsockname(sock, (struct sockaddr *)&addr, &addr_len); + uassert_int_equal(ret, -1); + verify_closed_stream_socket_ops(sock); + + udp_sock = create_test_socket(AF_INET, SOCK_DGRAM, 0); + uassert_true(udp_sock >= 0); + uassert_int_equal(sal_closesocket(udp_sock), 0); + verify_closed_datagram_socket_ops(udp_sock); + + /* Race close against lookup/use on another thread */ + sock = create_test_socket(AF_INET, SOCK_STREAM, 0); + if (sock >= 0) + { + rt_thread_t worker; + + close_race_stop = 0; + close_race_done = 0; + worker = rt_thread_create("salrace", close_race_thread, (void *)(rt_size_t)sock, + 2048, RT_THREAD_PRIORITY_MAX / 2, 10); + uassert_true(worker != RT_NULL); + if (worker != RT_NULL) + { + rt_thread_startup(worker); + rt_thread_mdelay(20); + uassert_int_equal(sal_closesocket(sock), 0); + close_race_stop = 1; + wait_test_flag(&close_race_done, THREAD_WAIT_TIMEOUT); + uassert_true(close_race_done != 0); + } + } LOG_I("TC_sal_socket_close tests completed"); } +static void TC_sal_socketpair_invalid_fd(void) +{ + int sock = -1; + int fds[2]; + rt_thread_t worker = RT_NULL; + + LOG_I("Starting TC_sal_socketpair_invalid_fd tests..."); + + sock = create_test_socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) + { + LOG_W("Skip socketpair invalid fd test because socket create failed"); + return; + } + + fds[0] = sock; + fds[1] = -1; + uassert_int_equal(sal_socketpair(AF_UNIX, SOCK_STREAM, 0, fds), -1); + + close_call_done = 0; + close_call_ret = -1; + worker = rt_thread_create("salclose", close_call_thread, (void *)(rt_size_t)sock, + 2048, RT_THREAD_PRIORITY_MAX / 2, 10); + uassert_true(worker != RT_NULL); + if (worker != RT_NULL) + { + rt_thread_startup(worker); + wait_test_flag(&close_call_done, THREAD_WAIT_TIMEOUT); + uassert_true(close_call_done != 0); + uassert_int_equal(close_call_ret, 0); + } + + LOG_I("TC_sal_socketpair_invalid_fd tests completed"); +} + +static void TC_sal_socket_reuse_stress(void) +{ + int i; + + LOG_I("Starting TC_sal_socket_reuse_stress tests..."); + + for (i = 0; i < TEST_MAX_RETRY_ATTEMPTS; i++) + { + exercise_socket_cache_reuse(4); + } + + LOG_I("TC_sal_socket_reuse_stress tests completed"); +} + static void TC_sal_socket_getpeername_getsockname(void) { int server_sock = -1; @@ -1022,6 +1217,8 @@ static void utest_do_tc(void) UTEST_UNIT_RUN(TC_sal_socket_udp_communication); UTEST_UNIT_RUN(TC_sal_socket_getpeername_getsockname); UTEST_UNIT_RUN(TC_sal_socket_close); + UTEST_UNIT_RUN(TC_sal_socketpair_invalid_fd); + UTEST_UNIT_RUN(TC_sal_socket_reuse_stress); LOG_I("==========================================="); LOG_I("SAL Socket Basic API Tests Completed");