Remove trantor NonCopyable inheritances from core.

This commit is contained in:
Relintai 2022-02-10 10:05:10 +01:00
parent b8e0579a2a
commit 3534817831
12 changed files with 683 additions and 710 deletions

View File

@ -18,6 +18,7 @@
#ifdef __linux__ #ifdef __linux__
#include <sys/prctl.h> #include <sys/prctl.h>
#endif #endif
using namespace trantor; using namespace trantor;
ConcurrentTaskQueue::ConcurrentTaskQueue(size_t threadNum, ConcurrentTaskQueue::ConcurrentTaskQueue(size_t threadNum,
const std::string &name) : const std::string &name) :

View File

@ -17,72 +17,71 @@
#include "core/containers/task_queue.h" #include "core/containers/task_queue.h"
#include <list> #include <list>
#include <memory> #include <memory>
#include <vector>
#include <queue> #include <queue>
#include <string> #include <string>
#include <vector>
namespace trantor namespace trantor {
{
/** /**
* @brief This class implements a task queue running in parallel. Basically this * @brief This class implements a task queue running in parallel. Basically this
* can be called a threads pool. * can be called a threads pool.
* *
*/ */
class ConcurrentTaskQueue : public TaskQueue class ConcurrentTaskQueue : public TaskQueue {
{ public:
public: ConcurrentTaskQueue() {}
/**
* @brief Construct a new concurrent task queue instance.
*
* @param threadNum The number of threads in the queue.
* @param name The name of the queue.
*/
ConcurrentTaskQueue(size_t threadNum, const std::string &name);
/** /**
* @brief Run a task in the queue. * @brief Construct a new concurrent task queue instance.
* *
* @param task * @param threadNum The number of threads in the queue.
*/ * @param name The name of the queue.
virtual void runTaskInQueue(const std::function<void()> &task); */
virtual void runTaskInQueue(std::function<void()> &&task); ConcurrentTaskQueue(size_t threadNum, const std::string &name);
/** /**
* @brief Get the name of the queue. * @brief Run a task in the queue.
* *
* @return std::string * @param task
*/ */
virtual std::string getName() const virtual void runTaskInQueue(const std::function<void()> &task);
{ virtual void runTaskInQueue(std::function<void()> &&task);
return queueName_;
};
/** /**
* @brief Get the number of tasks to be executed in the queue. * @brief Get the name of the queue.
* *
* @return size_t * @return std::string
*/ */
size_t getTaskCount(); virtual std::string getName() const {
return queueName_;
};
/** /**
* @brief Stop all threads in the queue. * @brief Get the number of tasks to be executed in the queue.
* *
*/ * @return size_t
void stop(); */
size_t getTaskCount();
~ConcurrentTaskQueue(); /**
* @brief Stop all threads in the queue.
*
*/
void stop();
private: ~ConcurrentTaskQueue();
size_t queueCount_;
std::string queueName_;
std::queue<std::function<void()>> taskQueue_; private:
std::vector<std::thread> threads_; size_t queueCount_;
std::string queueName_;
std::mutex taskMutex_; std::queue<std::function<void()> > taskQueue_;
std::condition_variable taskCond_; std::vector<std::thread> threads_;
std::atomic_bool stop_;
void queueFunc(int queueNum); std::mutex taskMutex_;
std::condition_variable taskCond_;
std::atomic_bool stop_;
void queueFunc(int queueNum);
}; };
} // namespace trantor } // namespace trantor

View File

@ -14,10 +14,10 @@
#pragma once #pragma once
#include <assert.h> #include <assert.h>
#include <trantor/utils/NonCopyable.h>
#include <atomic> #include <atomic>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
namespace trantor { namespace trantor {
/** /**
* @brief This class template represents a lock-free multiple producers single * @brief This class template represents a lock-free multiple producers single

View File

@ -14,45 +14,48 @@
#pragma once #pragma once
#include "NonCopyable.h"
#include <functional> #include <functional>
#include <future> #include <future>
#include <string> #include <string>
namespace trantor
{ namespace trantor {
/** /**
* @brief This class is a pure virtual class that can be implemented as a * @brief This class is a pure virtual class that can be implemented as a
* SerialTaskQueue or a ConcurrentTaskQueue. * SerialTaskQueue or a ConcurrentTaskQueue.
* *
*/ */
class TaskQueue : public NonCopyable class TaskQueue {
{ protected:
public: TaskQueue(const TaskQueue &) = delete;
virtual void runTaskInQueue(const std::function<void()> &task) = 0; TaskQueue &operator=(const TaskQueue &) = delete;
virtual void runTaskInQueue(std::function<void()> &&task) = 0; // some uncopyable classes maybe support move constructor....
virtual std::string getName() const TaskQueue(TaskQueue &&) noexcept(true) = default;
{ TaskQueue &operator=(TaskQueue &&) noexcept(true) = default;
return "";
};
/** public:
* @brief Run a task in the queue sychronously. This means that the task is TaskQueue() {}
* executed before the method returns. virtual void runTaskInQueue(const std::function<void()> &task) = 0;
* virtual void runTaskInQueue(std::function<void()> &&task) = 0;
* @param task virtual std::string getName() const {
*/ return "";
void syncTaskInQueue(const std::function<void()> &task) };
{
std::promise<int> prom; /**
std::future<int> fut = prom.get_future(); * @brief Run a task in the queue sychronously. This means that the task is
runTaskInQueue([&]() { * executed before the method returns.
task(); *
prom.set_value(1); * @param task
}); */
fut.get(); void syncTaskInQueue(const std::function<void()> &task) {
}; std::promise<int> prom;
virtual ~TaskQueue() std::future<int> fut = prom.get_future();
{ runTaskInQueue([&]() {
} task();
prom.set_value(1);
});
fut.get();
};
virtual ~TaskQueue() {
}
}; };
} // namespace trantor } // namespace trantor

View File

@ -55,8 +55,8 @@ class Channel;
class Socket; class Socket;
class TcpServer; class TcpServer;
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn); void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
class TcpConnectionImpl : public TcpConnection, class TcpConnectionImpl : public TcpConnection,
public NonCopyable,
public std::enable_shared_from_this<TcpConnectionImpl> public std::enable_shared_from_this<TcpConnectionImpl>
{ {
friend class TcpServer; friend class TcpServer;
@ -64,6 +64,13 @@ class TcpConnectionImpl : public TcpConnection,
friend void trantor::removeConnection(EventLoop *loop, friend void trantor::removeConnection(EventLoop *loop,
const TcpConnectionPtr &conn); const TcpConnectionPtr &conn);
protected:
TcpConnectionImpl(const TcpConnectionImpl &) = delete;
TcpConnectionImpl &operator=(const TcpConnectionImpl &) = delete;
// some uncopyable classes maybe support move constructor....
TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default;
TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default;
public: public:
class KickoffEntry class KickoffEntry
{ {

View File

@ -19,70 +19,68 @@
#include <atomic> #include <atomic>
#include <memory> #include <memory>
namespace trantor namespace trantor {
{ class Connector : public std::enable_shared_from_this<Connector> {
class Connector : public NonCopyable, protected:
public std::enable_shared_from_this<Connector> Connector(const Connector &) = delete;
{ Connector &operator=(const Connector &) = delete;
public: // some uncopyable classes maybe support move constructor....
using NewConnectionCallback = std::function<void(int sockfd)>; Connector(Connector &&) noexcept(true) = default;
using ConnectionErrorCallback = std::function<void()>; Connector &operator=(Connector &&) noexcept(true) = default;
Connector(EventLoop *loop, const InetAddress &addr, bool retry = true);
Connector(EventLoop *loop, InetAddress &&addr, bool retry = true);
void setNewConnectionCallback(const NewConnectionCallback &cb)
{
newConnectionCallback_ = cb;
}
void setNewConnectionCallback(NewConnectionCallback &&cb)
{
newConnectionCallback_ = std::move(cb);
}
void setErrorCallback(const ConnectionErrorCallback &cb)
{
errorCallback_ = cb;
}
void setErrorCallback(ConnectionErrorCallback &&cb)
{
errorCallback_ = std::move(cb);
}
const InetAddress &serverAddress() const
{
return serverAddr_;
}
void start();
void restart();
void stop();
private: public:
NewConnectionCallback newConnectionCallback_; using NewConnectionCallback = std::function<void(int sockfd)>;
ConnectionErrorCallback errorCallback_; using ConnectionErrorCallback = std::function<void()>;
enum class Status Connector(EventLoop *loop, const InetAddress &addr, bool retry = true);
{ Connector(EventLoop *loop, InetAddress &&addr, bool retry = true);
Disconnected, void setNewConnectionCallback(const NewConnectionCallback &cb) {
Connecting, newConnectionCallback_ = cb;
Connected }
}; void setNewConnectionCallback(NewConnectionCallback &&cb) {
static constexpr int kMaxRetryDelayMs = 30 * 1000; newConnectionCallback_ = std::move(cb);
static constexpr int kInitRetryDelayMs = 500; }
std::shared_ptr<Channel> channelPtr_; void setErrorCallback(const ConnectionErrorCallback &cb) {
EventLoop *loop_; errorCallback_ = cb;
InetAddress serverAddr_; }
void setErrorCallback(ConnectionErrorCallback &&cb) {
errorCallback_ = std::move(cb);
}
const InetAddress &serverAddress() const {
return serverAddr_;
}
void start();
void restart();
void stop();
std::atomic_bool connect_{false}; private:
std::atomic<Status> status_{Status::Disconnected}; NewConnectionCallback newConnectionCallback_;
ConnectionErrorCallback errorCallback_;
enum class Status {
Disconnected,
Connecting,
Connected
};
static constexpr int kMaxRetryDelayMs = 30 * 1000;
static constexpr int kInitRetryDelayMs = 500;
std::shared_ptr<Channel> channelPtr_;
EventLoop *loop_;
InetAddress serverAddr_;
int retryInterval_{kInitRetryDelayMs}; std::atomic_bool connect_{ false };
int maxRetryInterval_{kMaxRetryDelayMs}; std::atomic<Status> status_{ Status::Disconnected };
bool retry_; int retryInterval_{ kInitRetryDelayMs };
int maxRetryInterval_{ kMaxRetryDelayMs };
void startInLoop(); bool retry_;
void connect();
void connecting(int sockfd); void startInLoop();
int removeAndResetChannel(); void connect();
void handleWrite(); void connecting(int sockfd);
void handleError(); int removeAndResetChannel();
void retry(int sockfd); void handleWrite();
void handleError();
void retry(int sockfd);
}; };
} // namespace trantor } // namespace trantor

View File

@ -6,143 +6,133 @@
// Author: Tao An // Author: Tao An
#pragma once #pragma once
#include "core/net/resolver.h"
#include <trantor/utils/NonCopyable.h>
#include "core/loops/event_loop_thread.h" #include "core/loops/event_loop_thread.h"
#include "core/net/resolver.h"
#include <string.h>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string.h>
extern "C" extern "C" {
{ struct hostent;
struct hostent; struct ares_channeldata;
struct ares_channeldata; using ares_channel = struct ares_channeldata *;
using ares_channel = struct ares_channeldata*;
} }
namespace trantor namespace trantor {
{
class AresResolver : public Resolver, class AresResolver : public Resolver,
public NonCopyable, public std::enable_shared_from_this<AresResolver> {
public std::enable_shared_from_this<AresResolver> protected:
{ AresResolver(const AresResolver &) = delete;
public: AresResolver &operator=(const AresResolver &) = delete;
AresResolver(trantor::EventLoop* loop, size_t timeout); // some uncopyable classes maybe support move constructor....
~AresResolver(); AresResolver(AresResolver &&) noexcept(true) = default;
AresResolver &operator=(AresResolver &&) noexcept(true) = default;
virtual void resolve(const std::string& hostname, public:
const Callback& cb) override AresResolver(trantor::EventLoop *loop, size_t timeout);
{ ~AresResolver();
bool cached = false;
InetAddress inet;
{
std::lock_guard<std::mutex> lock(globalMutex());
auto iter = globalCache().find(hostname);
if (iter != globalCache().end())
{
auto& cachedAddr = iter->second;
if (timeout_ == 0 ||
cachedAddr.second.after(timeout_) > trantor::Date::date())
{
struct sockaddr_in addr;
memset(&addr, 0, sizeof addr);
addr.sin_family = AF_INET;
addr.sin_port = 0;
addr.sin_addr = cachedAddr.first;
inet = InetAddress(addr);
cached = true;
}
}
}
if (cached)
{
cb(inet);
return;
}
if (loop_->isInLoopThread())
{
resolveInLoop(hostname, cb);
}
else
{
loop_->queueInLoop([thisPtr = shared_from_this(), hostname, cb]() {
thisPtr->resolveInLoop(hostname, cb);
});
}
}
private: virtual void resolve(const std::string &hostname,
struct QueryData const Callback &cb) override {
{ bool cached = false;
AresResolver* owner_; InetAddress inet;
Callback callback_; {
std::string hostname_; std::lock_guard<std::mutex> lock(globalMutex());
QueryData(AresResolver* o, auto iter = globalCache().find(hostname);
const Callback& cb, if (iter != globalCache().end()) {
const std::string& hostname) auto &cachedAddr = iter->second;
: owner_(o), callback_(cb), hostname_(hostname) if (timeout_ == 0 ||
{ cachedAddr.second.after(timeout_) > trantor::Date::date()) {
} struct sockaddr_in addr;
}; memset(&addr, 0, sizeof addr);
void resolveInLoop(const std::string& hostname, const Callback& cb); addr.sin_family = AF_INET;
void init(); addr.sin_port = 0;
trantor::EventLoop* loop_; addr.sin_addr = cachedAddr.first;
ares_channel ctx_{nullptr}; inet = InetAddress(addr);
bool timerActive_{false}; cached = true;
using ChannelList = std::map<int, std::unique_ptr<trantor::Channel>>; }
ChannelList channels_; }
static std::unordered_map<std::string, }
std::pair<struct in_addr, trantor::Date>>& if (cached) {
globalCache() cb(inet);
{ return;
static std::unordered_map<std::string, }
std::pair<struct in_addr, trantor::Date>> if (loop_->isInLoopThread()) {
dnsCache; resolveInLoop(hostname, cb);
return dnsCache; } else {
} loop_->queueInLoop([thisPtr = shared_from_this(), hostname, cb]() {
static std::mutex& globalMutex() thisPtr->resolveInLoop(hostname, cb);
{ });
static std::mutex mutex_; }
return mutex_; }
}
static EventLoop* getLoop()
{
static EventLoopThread loopThread;
loopThread.run();
return loopThread.getLoop();
}
const size_t timeout_{60};
void onRead(int sockfd); private:
void onTimer(); struct QueryData {
void onQueryResult(int status, AresResolver *owner_;
struct hostent* result, Callback callback_;
const std::string& hostname, std::string hostname_;
const Callback& callback); QueryData(AresResolver *o,
void onSockCreate(int sockfd, int type); const Callback &cb,
void onSockStateChange(int sockfd, bool read, bool write); const std::string &hostname) :
owner_(o), callback_(cb), hostname_(hostname) {
}
};
void resolveInLoop(const std::string &hostname, const Callback &cb);
void init();
trantor::EventLoop *loop_;
ares_channel ctx_{ nullptr };
bool timerActive_{ false };
using ChannelList = std::map<int, std::unique_ptr<trantor::Channel> >;
ChannelList channels_;
static std::unordered_map<std::string,
std::pair<struct in_addr, trantor::Date> > &
globalCache() {
static std::unordered_map<std::string,
std::pair<struct in_addr, trantor::Date> >
dnsCache;
return dnsCache;
}
static std::mutex &globalMutex() {
static std::mutex mutex_;
return mutex_;
}
static EventLoop *getLoop() {
static EventLoopThread loopThread;
loopThread.run();
return loopThread.getLoop();
}
const size_t timeout_{ 60 };
static void ares_hostcallback_(void* data, void onRead(int sockfd);
int status, void onTimer();
int timeouts, void onQueryResult(int status,
struct hostent* hostent); struct hostent *result,
const std::string &hostname,
const Callback &callback);
void onSockCreate(int sockfd, int type);
void onSockStateChange(int sockfd, bool read, bool write);
static void ares_hostcallback_(void *data,
int status,
int timeouts,
struct hostent *hostent);
#ifdef _WIN32 #ifdef _WIN32
static int ares_sock_createcallback_(SOCKET sockfd, int type, void* data); static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data);
#else #else
static int ares_sock_createcallback_(int sockfd, int type, void* data); static int ares_sock_createcallback_(int sockfd, int type, void *data);
#endif #endif
static void ares_sock_statecallback_(void* data, static void ares_sock_statecallback_(void *data,
#ifdef _WIN32 #ifdef _WIN32
SOCKET sockfd, SOCKET sockfd,
#else #else
int sockfd, int sockfd,
#endif #endif
int read, int read,
int write); int write);
struct LibraryInitializer struct LibraryInitializer {
{ LibraryInitializer();
LibraryInitializer(); ~LibraryInitializer();
~LibraryInitializer(); };
}; static LibraryInitializer libraryInitializer_;
static LibraryInitializer libraryInitializer_;
}; };
} // namespace trantor } // namespace trantor

View File

@ -6,57 +6,53 @@
// Author: Tao An // Author: Tao An
#pragma once #pragma once
#include "core/net/resolver.h"
#include <trantor/utils/NonCopyable.h>
#include "core/containers/concurrent_task_queue.h" #include "core/containers/concurrent_task_queue.h"
#include "core/net/resolver.h"
#include <memory> #include <memory>
#include <vector>
#include <thread> #include <thread>
#include <vector>
namespace trantor namespace trantor {
{ constexpr size_t kResolveBufferLength{ 16 * 1024 };
constexpr size_t kResolveBufferLength{16 * 1024};
class NormalResolver : public Resolver, class NormalResolver : public Resolver,
public NonCopyable, public std::enable_shared_from_this<NormalResolver> {
public std::enable_shared_from_this<NormalResolver> protected:
{ NormalResolver(const NormalResolver &) = delete;
public: NormalResolver &operator=(const NormalResolver &) = delete;
virtual void resolve(const std::string& hostname, // some uncopyable classes maybe support move constructor....
const Callback& callback) override; NormalResolver(NormalResolver &&) noexcept(true) = default;
explicit NormalResolver(size_t timeout) NormalResolver &operator=(NormalResolver &&) noexcept(true) = default;
: timeout_(timeout), resolveBuffer_(kResolveBufferLength)
{
}
virtual ~NormalResolver()
{
}
private: public:
static std::unordered_map<std::string, virtual void resolve(const std::string &hostname,
std::pair<trantor::InetAddress, trantor::Date>>& const Callback &callback) override;
globalCache() explicit NormalResolver(size_t timeout) :
{ timeout_(timeout), resolveBuffer_(kResolveBufferLength) {
static std::unordered_map< }
std::string, virtual ~NormalResolver() {
std::pair<trantor::InetAddress, trantor::Date>> }
dnsCache_;
return dnsCache_; private:
} static std::unordered_map<std::string,
static std::mutex& globalMutex() std::pair<trantor::InetAddress, trantor::Date> > &
{ globalCache() {
static std::mutex mutex_; static std::unordered_map<
return mutex_; std::string,
} std::pair<trantor::InetAddress, trantor::Date> >
static trantor::ConcurrentTaskQueue& concurrentTaskQueue() dnsCache_;
{ return dnsCache_;
static trantor::ConcurrentTaskQueue queue( }
std::thread::hardware_concurrency() < 8 static std::mutex &globalMutex() {
? 8 static std::mutex mutex_;
: std::thread::hardware_concurrency(), return mutex_;
"Dns Queue"); }
return queue; static trantor::ConcurrentTaskQueue &concurrentTaskQueue() {
} static trantor::ConcurrentTaskQueue queue(
const size_t timeout_; std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(),
std::vector<char> resolveBuffer_; "Dns Queue");
return queue;
}
const size_t timeout_;
std::vector<char> resolveBuffer_;
}; };
} // namespace trantor } // namespace trantor

View File

@ -28,7 +28,6 @@ namespace trantor {
class Socket { class Socket {
protected: protected:
// NonCopyable
Socket(const Socket &) = delete; Socket(const Socket &) = delete;
Socket &operator=(const Socket &) = delete; Socket &operator=(const Socket &) = delete;
// some uncopyable classes maybe support move constructor.... // some uncopyable classes maybe support move constructor....

View File

@ -21,13 +21,12 @@
#include "core/loops/event_loop.h" #include "core/loops/event_loop.h"
#include "core/net/inet_address.h" #include "core/net/inet_address.h"
#include "tcp_connection.h" #include "tcp_connection.h"
#include <signal.h>
#include <atomic>
#include <functional> #include <functional>
#include <thread> #include <thread>
#include <atomic>
#include <signal.h>
namespace trantor namespace trantor {
{
class Connector; class Connector;
using ConnectorPtr = std::shared_ptr<Connector>; using ConnectorPtr = std::shared_ptr<Connector>;
class SSLContext; class SSLContext;
@ -35,214 +34,204 @@ class SSLContext;
* @brief This class represents a TCP client. * @brief This class represents a TCP client.
* *
*/ */
class TcpClient : NonCopyable class TcpClient {
{ protected:
public: TcpClient(const TcpClient &) = delete;
/** TcpClient &operator=(const TcpClient &) = delete;
* @brief Construct a new TCP client instance. // some uncopyable classes maybe support move constructor....
* TcpClient(TcpClient &&) noexcept(true) = default;
* @param loop The event loop in which the client runs. TcpClient &operator=(TcpClient &&) noexcept(true) = default;
* @param serverAddr The address of the server.
* @param nameArg The name of the client.
*/
TcpClient(EventLoop *loop,
const InetAddress &serverAddr,
const std::string &nameArg);
~TcpClient();
/** public:
* @brief Connect to the server. /**
* * @brief Construct a new TCP client instance.
*/ *
void connect(); * @param loop The event loop in which the client runs.
* @param serverAddr The address of the server.
* @param nameArg The name of the client.
*/
TcpClient(EventLoop *loop,
const InetAddress &serverAddr,
const std::string &nameArg);
~TcpClient();
/** /**
* @brief Disconnect from the server. * @brief Connect to the server.
* *
*/ */
void disconnect(); void connect();
/** /**
* @brief Stop connecting to the server. * @brief Disconnect from the server.
* *
*/ */
void stop(); void disconnect();
/** /**
* @brief Get the TCP connection to the server. * @brief Stop connecting to the server.
* *
* @return TcpConnectionPtr */
*/ void stop();
TcpConnectionPtr connection() const
{
std::lock_guard<std::mutex> lock(mutex_);
return connection_;
}
/** /**
* @brief Get the event loop. * @brief Get the TCP connection to the server.
* *
* @return EventLoop* * @return TcpConnectionPtr
*/ */
EventLoop *getLoop() const TcpConnectionPtr connection() const {
{ std::lock_guard<std::mutex> lock(mutex_);
return loop_; return connection_;
} }
/** /**
* @brief Check whether the client re-connect to the server. * @brief Get the event loop.
* *
* @return true * @return EventLoop*
* @return false */
*/ EventLoop *getLoop() const {
bool retry() const return loop_;
{ }
return retry_;
}
/** /**
* @brief Enable retrying. * @brief Check whether the client re-connect to the server.
* *
*/ * @return true
void enableRetry() * @return false
{ */
retry_ = true; bool retry() const {
} return retry_;
}
/** /**
* @brief Get the name of the client. * @brief Enable retrying.
* *
* @return const std::string& */
*/ void enableRetry() {
const std::string &name() const retry_ = true;
{ }
return name_;
}
/** /**
* @brief Set the connection callback. * @brief Get the name of the client.
* *
* @param cb The callback is called when the connection to the server is * @return const std::string&
* established or closed. */
*/ const std::string &name() const {
void setConnectionCallback(const ConnectionCallback &cb) return name_;
{ }
connectionCallback_ = cb;
}
void setConnectionCallback(ConnectionCallback &&cb)
{
connectionCallback_ = std::move(cb);
}
/** /**
* @brief Set the connection error callback. * @brief Set the connection callback.
* *
* @param cb The callback is called when an error occurs during connecting * @param cb The callback is called when the connection to the server is
* to the server. * established or closed.
*/ */
void setConnectionErrorCallback(const ConnectionErrorCallback &cb) void setConnectionCallback(const ConnectionCallback &cb) {
{ connectionCallback_ = cb;
connectionErrorCallback_ = cb; }
} void setConnectionCallback(ConnectionCallback &&cb) {
connectionCallback_ = std::move(cb);
}
/** /**
* @brief Set the message callback. * @brief Set the connection error callback.
* *
* @param cb The callback is called when some data is received from the * @param cb The callback is called when an error occurs during connecting
* server. * to the server.
*/ */
void setMessageCallback(const RecvMessageCallback &cb) void setConnectionErrorCallback(const ConnectionErrorCallback &cb) {
{ connectionErrorCallback_ = cb;
messageCallback_ = cb; }
}
void setMessageCallback(RecvMessageCallback &&cb)
{
messageCallback_ = std::move(cb);
}
/// Set write complete callback.
/// Not thread safe.
/** /**
* @brief Set the write complete callback. * @brief Set the message callback.
* *
* @param cb The callback is called when data to send is written to the * @param cb The callback is called when some data is received from the
* socket. * server.
*/ */
void setWriteCompleteCallback(const WriteCompleteCallback &cb) void setMessageCallback(const RecvMessageCallback &cb) {
{ messageCallback_ = cb;
writeCompleteCallback_ = cb; }
} void setMessageCallback(RecvMessageCallback &&cb) {
void setWriteCompleteCallback(WriteCompleteCallback &&cb) messageCallback_ = std::move(cb);
{ }
writeCompleteCallback_ = std::move(cb); /// Set write complete callback.
} /// Not thread safe.
/** /**
* @brief Set the callback for errors of SSL * @brief Set the write complete callback.
* @param cb The callback is called when an SSL error occurs. *
*/ * @param cb The callback is called when data to send is written to the
void setSSLErrorCallback(const SSLErrorCallback &cb) * socket.
{ */
sslErrorCallback_ = cb; void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
} writeCompleteCallback_ = cb;
void setSSLErrorCallback(SSLErrorCallback &&cb) }
{ void setWriteCompleteCallback(WriteCompleteCallback &&cb) {
sslErrorCallback_ = std::move(cb); writeCompleteCallback_ = std::move(cb);
} }
/** /**
* @brief Enable SSL encryption. * @brief Set the callback for errors of SSL
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the * @param cb The callback is called when an SSL error occurs.
* client. */
* @param validateCert If true, we try to validate if the peer's SSL cert void setSSLErrorCallback(const SSLErrorCallback &cb) {
* is valid. sslErrorCallback_ = cb;
* @param hostname The server hostname for SNI. If it is empty, the SNI is }
* not used. void setSSLErrorCallback(SSLErrorCallback &&cb) {
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in sslErrorCallback_ = std::move(cb);
* OpenSSL. }
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
* 2020. And it's a good practice to only use TLS 1.2 and above.
*/
void enableSSL(bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string>>
&sslConfCmds = {});
private: /**
/// Not thread safe, but in loop * @brief Enable SSL encryption.
void newConnection(int sockfd); * @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
/// Not thread safe, but in loop * client.
void removeConnection(const TcpConnectionPtr &conn); * @param validateCert If true, we try to validate if the peer's SSL cert
* is valid.
* @param hostname The server hostname for SNI. If it is empty, the SNI is
* not used.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
* 2020. And it's a good practice to only use TLS 1.2 and above.
*/
void enableSSL(bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string> >
&sslConfCmds = {});
EventLoop *loop_; private:
ConnectorPtr connector_; // avoid revealing Connector /// Not thread safe, but in loop
const std::string name_; void newConnection(int sockfd);
ConnectionCallback connectionCallback_; /// Not thread safe, but in loop
ConnectionErrorCallback connectionErrorCallback_; void removeConnection(const TcpConnectionPtr &conn);
RecvMessageCallback messageCallback_;
WriteCompleteCallback writeCompleteCallback_; EventLoop *loop_;
SSLErrorCallback sslErrorCallback_; ConnectorPtr connector_; // avoid revealing Connector
std::atomic_bool retry_; // atomic const std::string name_;
std::atomic_bool connect_; // atomic ConnectionCallback connectionCallback_;
// always in loop thread ConnectionErrorCallback connectionErrorCallback_;
mutable std::mutex mutex_; RecvMessageCallback messageCallback_;
TcpConnectionPtr connection_; // @GuardedBy mutex_ WriteCompleteCallback writeCompleteCallback_;
std::shared_ptr<SSLContext> sslCtxPtr_; SSLErrorCallback sslErrorCallback_;
bool validateCert_{false}; std::atomic_bool retry_; // atomic
std::string SSLHostName_; std::atomic_bool connect_; // atomic
// always in loop thread
mutable std::mutex mutex_;
TcpConnectionPtr connection_; // @GuardedBy mutex_
std::shared_ptr<SSLContext> sslCtxPtr_;
bool validateCert_{ false };
std::string SSLHostName_;
#ifndef _WIN32 #ifndef _WIN32
class IgnoreSigPipe class IgnoreSigPipe {
{ public:
public: IgnoreSigPipe() {
IgnoreSigPipe() ::signal(SIGPIPE, SIG_IGN);
{ }
::signal(SIGPIPE, SIG_IGN); };
}
};
static IgnoreSigPipe initObj; static IgnoreSigPipe initObj;
#endif #endif
}; };
} // namespace trantor } // namespace trantor

View File

@ -18,7 +18,6 @@
#include "core/loops/callbacks.h" #include "core/loops/callbacks.h"
#include "core/loops/event_loop.h" #include "core/loops/event_loop.h"
#include "core/net/inet_address.h" #include "core/net/inet_address.h"
#include <trantor/utils/NonCopyable.h>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>

View File

@ -13,237 +13,229 @@
*/ */
#pragma once #pragma once
#include "core/loops/callbacks.h"
#include <trantor/utils/NonCopyable.h>
#include "core/log/logger.h" #include "core/log/logger.h"
#include "core/loops/callbacks.h"
#include "core/loops/event_loop_thread_pool.h" #include "core/loops/event_loop_thread_pool.h"
#include "core/loops/timing_wheel.h"
#include "core/net/inet_address.h" #include "core/net/inet_address.h"
#include "core/net/tcp_connection.h" #include "core/net/tcp_connection.h"
#include "core/loops/timing_wheel.h" #include <signal.h>
#include <string>
#include <memory> #include <memory>
#include <set> #include <set>
#include <signal.h> #include <string>
namespace trantor
{ namespace trantor {
class Acceptor; class Acceptor;
class SSLContext; class SSLContext;
/** /**
* @brief This class represents a TCP server. * @brief This class represents a TCP server.
* *
*/ */
class TcpServer : NonCopyable class TcpServer {
{ protected:
public: TcpServer(const TcpServer &) = delete;
/** TcpServer &operator=(const TcpServer &) = delete;
* @brief Construct a new TCP server instance. // some uncopyable classes maybe support move constructor....
* TcpServer(TcpServer &&) noexcept(true) = default;
* @param loop The event loop in which the acceptor of the server is TcpServer &operator=(TcpServer &&) noexcept(true) = default;
* handled.
* @param address The address of the server.
* @param name The name of the server.
* @param reUseAddr The SO_REUSEADDR option.
* @param reUsePort The SO_REUSEPORT option.
*/
TcpServer(EventLoop *loop,
const InetAddress &address,
const std::string &name,
bool reUseAddr = true,
bool reUsePort = true);
~TcpServer();
/** public:
* @brief Start the server. /**
* * @brief Construct a new TCP server instance.
*/ *
void start(); * @param loop The event loop in which the acceptor of the server is
* handled.
* @param address The address of the server.
* @param name The name of the server.
* @param reUseAddr The SO_REUSEADDR option.
* @param reUsePort The SO_REUSEPORT option.
*/
TcpServer(EventLoop *loop,
const InetAddress &address,
const std::string &name,
bool reUseAddr = true,
bool reUsePort = true);
~TcpServer();
/** /**
* @brief Stop the server. * @brief Start the server.
* *
*/ */
void stop(); void start();
/** /**
* @brief Set the number of event loops in which the I/O of connections to * @brief Stop the server.
* the server is handled. *
* */
* @param num void stop();
*/
void setIoLoopNum(size_t num)
{
assert(!started_);
loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num);
loopPoolPtr_->start();
}
/** /**
* @brief Set the event loops pool in which the I/O of connections to * @brief Set the number of event loops in which the I/O of connections to
* the server is handled. * the server is handled.
* *
* @param pool * @param num
*/ */
void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) void setIoLoopNum(size_t num) {
{ assert(!started_);
assert(pool->size() > 0); loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num);
assert(!started_); loopPoolPtr_->start();
loopPoolPtr_ = pool; }
loopPoolPtr_->start();
}
/** /**
* @brief Set the message callback. * @brief Set the event loops pool in which the I/O of connections to
* * the server is handled.
* @param cb The callback is called when some data is received on a *
* connection to the server. * @param pool
*/ */
void setRecvMessageCallback(const RecvMessageCallback &cb) void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) {
{ assert(pool->size() > 0);
recvMessageCallback_ = cb; assert(!started_);
} loopPoolPtr_ = pool;
void setRecvMessageCallback(RecvMessageCallback &&cb) loopPoolPtr_->start();
{ }
recvMessageCallback_ = std::move(cb);
}
/** /**
* @brief Set the connection callback. * @brief Set the message callback.
* *
* @param cb The callback is called when a connection is established or * @param cb The callback is called when some data is received on a
* closed. * connection to the server.
*/ */
void setConnectionCallback(const ConnectionCallback &cb) void setRecvMessageCallback(const RecvMessageCallback &cb) {
{ recvMessageCallback_ = cb;
connectionCallback_ = cb; }
} void setRecvMessageCallback(RecvMessageCallback &&cb) {
void setConnectionCallback(ConnectionCallback &&cb) recvMessageCallback_ = std::move(cb);
{ }
connectionCallback_ = std::move(cb);
}
/** /**
* @brief Set the write complete callback. * @brief Set the connection callback.
* *
* @param cb The callback is called when data to send is written to the * @param cb The callback is called when a connection is established or
* socket of a connection. * closed.
*/ */
void setWriteCompleteCallback(const WriteCompleteCallback &cb) void setConnectionCallback(const ConnectionCallback &cb) {
{ connectionCallback_ = cb;
writeCompleteCallback_ = cb; }
} void setConnectionCallback(ConnectionCallback &&cb) {
void setWriteCompleteCallback(WriteCompleteCallback &&cb) connectionCallback_ = std::move(cb);
{ }
writeCompleteCallback_ = std::move(cb);
}
/** /**
* @brief Get the name of the server. * @brief Set the write complete callback.
* *
* @return const std::string& * @param cb The callback is called when data to send is written to the
*/ * socket of a connection.
const std::string &name() const */
{ void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
return serverName_; writeCompleteCallback_ = cb;
} }
void setWriteCompleteCallback(WriteCompleteCallback &&cb) {
writeCompleteCallback_ = std::move(cb);
}
/** /**
* @brief Get the IP and port string of the server. * @brief Get the name of the server.
* *
* @return const std::string * @return const std::string&
*/ */
const std::string ipPort() const; const std::string &name() const {
return serverName_;
}
/** /**
* @brief Get the address of the server. * @brief Get the IP and port string of the server.
* *
* @return const trantor::InetAddress& * @return const std::string
*/ */
const trantor::InetAddress &address() const; const std::string ipPort() const;
/** /**
* @brief Get the event loop of the server. * @brief Get the address of the server.
* *
* @return EventLoop* * @return const trantor::InetAddress&
*/ */
EventLoop *getLoop() const const trantor::InetAddress &address() const;
{
return loop_;
}
/** /**
* @brief Get the I/O event loops of the server. * @brief Get the event loop of the server.
* *
* @return std::vector<EventLoop *> * @return EventLoop*
*/ */
std::vector<EventLoop *> getIoLoops() const EventLoop *getLoop() const {
{ return loop_;
return loopPoolPtr_->getLoops(); }
}
/** /**
* @brief An idle connection is a connection that has no read or write, kick * @brief Get the I/O event loops of the server.
* off it after timeout seconds. *
* * @return std::vector<EventLoop *>
* @param timeout */
*/ std::vector<EventLoop *> getIoLoops() const {
void kickoffIdleConnections(size_t timeout) return loopPoolPtr_->getLoops();
{ }
loop_->runInLoop([this, timeout]() {
assert(!started_);
idleTimeout_ = timeout;
});
}
/** /**
* @brief Enable SSL encryption. * @brief An idle connection is a connection that has no read or write, kick
* * off it after timeout seconds.
* @param certPath The path of the certificate file. *
* @param keyPath The path of the private key file. * @param timeout
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the */
* server. void kickoffIdleConnections(size_t timeout) {
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in loop_->runInLoop([this, timeout]() {
* OpenSSL. assert(!started_);
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in idleTimeout_ = timeout;
* 2020. And it's a good practice to only use TLS 1.2 and above. });
*/ }
void enableSSL(const std::string &certPath,
const std::string &keyPath,
bool useOldTLS = false,
const std::vector<std::pair<std::string, std::string>>
&sslConfCmds = {});
private: /**
EventLoop *loop_; * @brief Enable SSL encryption.
std::unique_ptr<Acceptor> acceptorPtr_; *
void newConnection(int fd, const InetAddress &peer); * @param certPath The path of the certificate file.
std::string serverName_; * @param keyPath The path of the private key file.
std::set<TcpConnectionPtr> connSet_; * @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
* server.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
* 2020. And it's a good practice to only use TLS 1.2 and above.
*/
void enableSSL(const std::string &certPath,
const std::string &keyPath,
bool useOldTLS = false,
const std::vector<std::pair<std::string, std::string> >
&sslConfCmds = {});
RecvMessageCallback recvMessageCallback_; private:
ConnectionCallback connectionCallback_; EventLoop *loop_;
WriteCompleteCallback writeCompleteCallback_; std::unique_ptr<Acceptor> acceptorPtr_;
void newConnection(int fd, const InetAddress &peer);
std::string serverName_;
std::set<TcpConnectionPtr> connSet_;
size_t idleTimeout_{0}; RecvMessageCallback recvMessageCallback_;
std::map<EventLoop *, std::shared_ptr<TimingWheel>> timingWheelMap_; ConnectionCallback connectionCallback_;
void connectionClosed(const TcpConnectionPtr &connectionPtr); WriteCompleteCallback writeCompleteCallback_;
std::shared_ptr<EventLoopThreadPool> loopPoolPtr_;
size_t idleTimeout_{ 0 };
std::map<EventLoop *, std::shared_ptr<TimingWheel> > timingWheelMap_;
void connectionClosed(const TcpConnectionPtr &connectionPtr);
std::shared_ptr<EventLoopThreadPool> loopPoolPtr_;
#ifndef _WIN32 #ifndef _WIN32
class IgnoreSigPipe class IgnoreSigPipe {
{ public:
public: IgnoreSigPipe() {
IgnoreSigPipe() ::signal(SIGPIPE, SIG_IGN);
{ LOG_TRACE << "Ignore SIGPIPE";
::signal(SIGPIPE, SIG_IGN); }
LOG_TRACE << "Ignore SIGPIPE"; };
}
};
IgnoreSigPipe initObj; IgnoreSigPipe initObj;
#endif #endif
bool started_{false}; bool started_{ false };
// OpenSSL SSL context Object; // OpenSSL SSL context Object;
std::shared_ptr<SSLContext> sslCtxPtr_; std::shared_ptr<SSLContext> sslCtxPtr_;
}; };
} // namespace trantor } // namespace trantor