diff --git a/examples/test_tcp_server.cpp b/examples/test_tcp_server.cpp new file mode 100644 index 0000000..be6879f --- /dev/null +++ b/examples/test_tcp_server.cpp @@ -0,0 +1,52 @@ +#include + +#include +#include +#include +#include +#include + +int main() +{ + using namespace ouc_server::server; + using namespace ouc_server::ouc_socket; + + TCPServer server; + if (!server.start("127.0.0.1", 8080)) + { + std::cerr << "Failed to start server\n"; + return 1; + } + + puts("Start Listening!"); + + server.on_connection( + [](TCPSocket &client) + { std::cout << "New client connected, fd=" << client.get_fd() << "\n"; }); + + server.on_message( + [&](TCPSocket &client, const std::string &msg) + { + if (!memcmp("exit", msg.c_str(), 4)) + { + server.remove_fd(client); + return; + } + std::cout << "Received: " << msg; + + std::string ret_str("Echo: " + msg); + size_t count = client.send(ret_str.c_str()); + while (count < ret_str.size()) + count += client.send(ret_str.c_str() + count); + }); + + server.on_close( + [](TCPSocket &client) + { std::cout << "Client disconnected, fd=" << client.get_fd() << "\n"; }); + + while (true) + { + server.loop(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } +} diff --git a/include/ouc_server/server/tcp_server.cpp b/include/ouc_server/server/tcp_server.cpp new file mode 100644 index 0000000..f8d603c --- /dev/null +++ b/include/ouc_server/server/tcp_server.cpp @@ -0,0 +1,138 @@ +#include + +#include + +namespace ouc_server +{ + namespace server + { + TCPServer::TCPServer() + : server_socket(ouc_server::ouc_socket::TCPSocket::create()), + tasks(8) + { + } + + TCPServer::~TCPServer() + { + for (auto &[k, v] : clients) + v.close(); + server_socket.close(); + } + + bool TCPServer::start(const std::string &ip, uint16_t port) + { + if (server_socket.get_fd() < 0) + return false; + if (!server_socket.bind(ip, port)) + return false; + if (!server_socket.listen()) + return false; + + epoll_loop.add_fd( + server_socket.get_fd(), + EPOLLIN, + [this](int) + { + handle_new_connection(); + }); + return true; + } + + bool TCPServer::add_fd(int fd, ouc_server::ouc_socket::TCPSocket &&tcp_socket) + { + if (clients.count(fd) || tcp_socket.get_fd() < 0) + return false; + + clients.emplace(fd, std::move(tcp_socket)); + + if (!epoll_loop.add_fd( + fd, + EPOLLIN, + [this](int fd) + { + this->handle_client_event(fd); + })) + return false; + + if (on_connection_callback) + on_connection_callback(clients.at(fd)); + + return true; + } + + bool TCPServer::add_fd(int fd) + { + if (clients.count(fd) || fd < 0) + return false; + return add_fd(fd, ouc_server::ouc_socket::TCPSocket(fd)); + } + + bool TCPServer::remove_fd(ouc_server::ouc_socket::TCPSocket &client) + { + int fd = client.get_fd(); + if ((!clients.count(fd)) || fd < 0) + return false; + + if (!epoll_loop.remove_fd(client.get_fd())) + return false; + + if (on_close_callback) + on_close_callback(client); + + clients.erase(client.get_fd()); + return client.close(); + } + + bool TCPServer::remove_fd(int fd) + { + if ((!clients.count(fd)) || fd < 0) + return false; + + auto &client = clients.at(fd); + + return remove_fd(client); + } + + void TCPServer::handle_new_connection() + { + while (true) + { + auto client = server_socket.accept(); + while (client.get_fd() < 0) + client = server_socket.accept(); + + add_fd(client.get_fd(), std::move(client)); + } + } + + void TCPServer::handle_client_event(int fd) + { + auto &client = clients[fd]; + char buf[4096]; + while (true) + { + ssize_t n = client.recv(buf, sizeof(buf)); + if (n > 0) + { + std::string data(buf, n); + if (on_message_callback) + tasks.sumbit(on_message_callback, std::ref(client), data); + } + else if (n == 0) + { + remove_fd(fd); + return; + } + else + { + if (errno == EAGAIN || errno == EWOULDBLOCK) + continue; + if (errno == EINTR) + continue; + + remove_fd(fd); + } + } + } + } +} \ No newline at end of file diff --git a/include/ouc_server/server/tcp_server.hpp b/include/ouc_server/server/tcp_server.hpp new file mode 100644 index 0000000..fef2f06 --- /dev/null +++ b/include/ouc_server/server/tcp_server.hpp @@ -0,0 +1,59 @@ +#ifndef INCLUDE_OUC_SERVER_TCP_SERVER +#define INCLUDE_OUC_SERVER_TCP_SERVER + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace ouc_server +{ + namespace server + { + class TCPServer + { + public: + template + using Callback = std::function; + + private: + ouc_server::ouc_socket::TCPSocket server_socket; + ouc_server::epoll::EpollLoop epoll_loop; + ouc_server::utils::ThreadPool tasks; + std::map clients; + + Callback<> on_connection_callback; + Callback on_message_callback; + Callback<> on_close_callback; + + public: + TCPServer(); + ~TCPServer(); + + public: + void on_connection(Callback<> &&callback) { on_connection_callback = std::move(callback); } + void on_message(Callback &&callback) { on_message_callback = std::move(callback); } + void on_close(Callback<> &&callback) { on_close_callback = std::move(callback); } + + public: + bool start(const std::string &, uint16_t); + void loop() { epoll_loop.poll(); } + + bool add_fd(int, ouc_server::ouc_socket::TCPSocket &&); + bool add_fd(int); + bool remove_fd(ouc_server::ouc_socket::TCPSocket &); + bool remove_fd(int); + + private: + void handle_new_connection(); + void handle_client_event(int); + }; + } +} + +#endif // INCLUDE_OUC_SERVER_TCP_SERVER \ No newline at end of file