Fix listening socket epoll_wait not waken by connect

This commit is contained in:
Hui, Chunyang 2022-08-09 08:54:02 +00:00 committed by volcano
parent f87ee7c7a4
commit 71c4937b45
5 changed files with 157 additions and 20 deletions

@ -1,4 +1,5 @@
use super::endpoint::Endpoint;
use super::endpoint::RelayNotifier;
use super::stream::Listener;
use super::*;
use std::collections::btree_map::BTreeMap;
@ -49,13 +50,22 @@ impl AddressSpace {
}
}
pub fn add_listener(&self, addr: &Addr, capacity: usize, nonblocking: bool) -> Result<()> {
pub(super) fn add_listener(
&self,
addr: &Addr,
capacity: usize,
nonblocking: bool,
notifier: Arc<RelayNotifier>,
) -> Result<()> {
let key = Self::get_key(addr).ok_or_else(|| errno!(EINVAL, "the socket is not bound"))?;
let mut space = self.get_space(addr);
if let Some(option) = space.get(&key) {
if option.is_none() {
space.insert(key, Some(Arc::new(Listener::new(capacity, nonblocking)?)));
space.insert(
key,
Some(Arc::new(Listener::new(capacity, nonblocking, notifier)?)),
);
Ok(())
} else {
return_errno!(EINVAL, "the socket is already listened");

@ -1,3 +1,4 @@
use super::address_space::ADDRESS_SPACE;
use super::stream::Status;
use super::*;
use fs::{AccessMode, File, FileRef, IoEvents, IoNotifier, IoctlCmd, StatusFlags};
@ -100,12 +101,15 @@ impl File for Stream {
// linux return value
Status::Idle(info) => IoEvents::OUT | IoEvents::HUP,
Status::Connected(endpoint) => endpoint.poll(),
Status::Listening(_) => {
warn!("poll is not fully implemented for the listener socket");
Status::Listening(addr) => {
if let Some(listener) = ADDRESS_SPACE.get_listener_ref(addr) {
listener.poll_new()
} else {
IoEvents::empty()
}
}
}
}
fn notifier(&self) -> Option<&IoNotifier> {
Some(&self.notifier.notifier())

@ -121,7 +121,12 @@ impl Stream {
Status::Idle(info) => {
if let Some(addr) = info.addr() {
warn!("addr = {:?}", addr);
ADDRESS_SPACE.add_listener(addr, capacity, info.nonblocking())?;
ADDRESS_SPACE.add_listener(
addr,
capacity,
info.nonblocking(),
self.notifier.clone(),
)?;
*inner = Status::Listening(addr.clone());
} else {
return_errno!(EINVAL, "the socket is not bound");
@ -165,6 +170,11 @@ impl Stream {
self.notifier.observe_endpoint(&end_self);
// Notify listener for this event
if let Some(listener) = ADDRESS_SPACE.get_listener_ref(addr) {
listener.notifier.notifier().broadcast(&IoEvents::IN);
}
*inner = Status::Connected(end_self);
Ok(())
}
@ -322,16 +332,22 @@ impl Info {
/// ECONNREFUSED rather than block when the channel is full.
pub struct Listener {
channel: RwLock<Channel<Endpoint>>,
notifier: Arc<RelayNotifier>,
}
impl Listener {
pub fn new(capacity: usize, nonblocking: bool) -> Result<Self> {
pub(super) fn new(
capacity: usize,
nonblocking: bool,
notifier: Arc<RelayNotifier>,
) -> Result<Self> {
let channel = Channel::new(capacity)?;
channel.producer().set_nonblocking(true);
channel.consumer().set_nonblocking(nonblocking);
Ok(Self {
channel: RwLock::new(channel),
notifier,
})
}
@ -391,4 +407,14 @@ impl Listener {
let channel = self.channel.read().unwrap();
channel.shutdown();
}
pub fn poll_new(&self) -> IoEvents {
let mut events = IoEvents::empty();
let channel = self.channel.read().unwrap();
let item_num = channel.items_to_consume();
if item_num > 0 {
events |= IoEvents::IN;
}
events
}
}

@ -1,5 +1,5 @@
include ../test_common.mk
EXTRA_C_FLAGS := -Wno-incompatible-pointer-types-discards-qualifiers
EXTRA_LINK_FLAGS :=
EXTRA_LINK_FLAGS := -lpthread
BIN_ARGS :=

@ -9,6 +9,8 @@
#include <stdio.h>
#include <spawn.h>
#include <string.h>
#include <sys/epoll.h>
#include <pthread.h>
#include "test.h"
@ -279,7 +281,9 @@ int test_poll() {
};
int ret = poll(polls, 2, 5000);
if (ret <= 0) { THROW_ERROR("poll error"); }
if (ret <= 0) {
THROW_ERROR("poll error");
}
if (((polls[0].revents & POLLOUT) && (polls[1].revents & POLLIN)) == 0) {
printf("%d %d\n", polls[0].revents, polls[1].revents);
THROW_ERROR("wrong return events");
@ -369,6 +373,98 @@ int test_ioctl_fionread() {
return 0;
}
void *client_routine(void *arg) {
// Sleep a while before connect
sleep(3);
printf("sleep is done\n");
char *sock_path = "/tmp/test.sock";
struct sockaddr_un addr;
memset(&addr, 0, sizeof(struct sockaddr_un)); // Clear structure
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, sock_path);
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1;
int client_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (client_fd == -1) {
printf("failed to create a unix socket\n");
return NULL;
}
if (connect(client_fd, (struct sockaddr *)&addr, addr_len) == -1) {
close(client_fd);
printf("failed to connect\n");
return NULL;
}
printf("connect success\n");
return NULL;
}
int test_epoll_wait() {
char *sock_path = "/tmp/test.sock";
int ret;
struct epoll_event event;
uint32_t interest_events = EPOLLIN | EPOLLOUT;
struct epoll_event polled_events;
pthread_t client_tid;
struct sockaddr_un addr;
int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (listen_fd == -1) {
THROW_ERROR("failed to create a unix socket");
}
memset(&addr, 0, sizeof(struct sockaddr_un)); // Clear structure
addr.sun_family = AF_UNIX;
strcpy(addr.sun_path, sock_path);
socklen_t addr_len = strlen(addr.sun_path) + sizeof(addr.sun_family) + 1;
unlink(addr.sun_path);
if (bind(listen_fd, (struct sockaddr *)&addr, addr_len) == -1) {
close(listen_fd);
THROW_ERROR("failed to bind");
}
if (listen(listen_fd, 5) != 0) {
THROW_ERROR("server_routine, error in listen");
}
int ep_fd = epoll_create1(0);
if (ep_fd < 0) {
THROW_ERROR("failed to create an epoll");
}
event.events = interest_events;
event.data.u32 = listen_fd;
ret = epoll_ctl(ep_fd, EPOLL_CTL_ADD, listen_fd, &event);
if (ret < 0) {
THROW_ERROR("failed to do epoll ctl");
}
if (pthread_create(&client_tid, NULL, client_routine, NULL)) {
THROW_ERROR("Failure creating client thread");
}
// wait infinitely
ret = epoll_wait(ep_fd, &polled_events, 1, -1);
if (ret != 1) {
THROW_ERROR("failed to do epoll wait");
}
if (polled_events.data.u32 != listen_fd) {
THROW_ERROR("epoll wait returned wrong fd");
}
int newsock = accept(listen_fd, (struct sockaddr *)&addr, &addr_len);
if (newsock == -1) {
THROW_ERROR("server_routine, error in accept");
}
printf("accept done, new socket = %d\n", newsock);
pthread_join(client_tid, NULL);
return 0;
}
static test_case_t test_cases[] = {
TEST_CASE(test_unix_socket_inter_process),
TEST_CASE(test_socketpair_inter_process),
@ -377,6 +473,7 @@ static test_case_t test_cases[] = {
TEST_CASE(test_getname),
TEST_CASE(test_ioctl_fionread),
TEST_CASE(test_unix_socket_rename),
TEST_CASE(test_epoll_wait),
};
int main(int argc, const char *argv[]) {