initial codestyle cleanup in the net folder.

This commit is contained in:
Relintai 2022-02-10 13:41:02 +01:00
parent 8992cf49b3
commit 38a28ee9ce
18 changed files with 662 additions and 1128 deletions

File diff suppressed because it is too large Load Diff

View File

@ -30,54 +30,45 @@
#pragma once #pragma once
#include "core/net/tcp_connection.h"
#include "core/loops/timing_wheel.h" #include "core/loops/timing_wheel.h"
#include "core/net/tcp_connection.h"
#include <list> #include <list>
#include <mutex> #include <mutex>
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
#endif #endif
#include <thread>
#include <array> #include <array>
#include <thread>
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
enum class SSLStatus
{ enum class SSLStatus {
Handshaking, Handshaking,
Connecting, Connecting,
Connected, Connected,
DisConnecting, DisConnecting,
DisConnected DisConnected
}; };
class SSLContext; class SSLContext;
class SSLConn; class SSLConn;
std::shared_ptr<SSLContext> newSSLContext( std::shared_ptr<SSLContext> newSSLContext(bool useOldTLS, bool validateCert, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
bool useOldTLS, std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
bool validateCert, // void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx, const std::string &certPath, const std::string &keyPath);
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
std::shared_ptr<SSLContext> newSSLServerContext(
const std::string &certPath,
const std::string &keyPath,
bool useOldTLS,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
// const std::string &certPath,
// const std::string &keyPath);
#endif #endif
class Channel; class Channel;
class Socket; class Socket;
class TcpServer; class TcpServer;
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn); void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
class TcpConnectionImpl : public TcpConnection, class TcpConnectionImpl : public TcpConnection, public std::enable_shared_from_this<TcpConnectionImpl> {
public std::enable_shared_from_this<TcpConnectionImpl> friend class TcpServer;
{ friend class TcpClient;
friend class TcpServer;
friend class TcpClient; friend void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
friend void removeConnection(EventLoop *loop,
const TcpConnectionPtr &conn);
protected: protected:
TcpConnectionImpl(const TcpConnectionImpl &) = delete; TcpConnectionImpl(const TcpConnectionImpl &) = delete;
@ -86,278 +77,247 @@ protected:
TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default; TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default;
TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default; TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default;
public: public:
class KickoffEntry class KickoffEntry {
{ public:
public: explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn) :
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn) conn_(conn) {
: conn_(conn) }
{
}
void reset()
{
conn_.reset();
}
~KickoffEntry()
{
auto conn = conn_.lock();
if (conn)
{
conn->forceClose();
}
}
private: void reset() {
std::weak_ptr<TcpConnection> conn_; conn_.reset();
}; }
TcpConnectionImpl(EventLoop *loop, ~KickoffEntry() {
int socketfd, auto conn = conn_.lock();
const InetAddress &localAddr, if (conn) {
const InetAddress &peerAddr); conn->forceClose();
}
}
private:
std::weak_ptr<TcpConnection> conn_;
};
TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr);
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
TcpConnectionImpl(EventLoop *loop, TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, const std::shared_ptr<SSLContext> &ctxPtr, bool isServer = true, bool validateCert = true, const std::string &hostname = "");
int socketfd,
const InetAddress &localAddr,
const InetAddress &peerAddr,
const std::shared_ptr<SSLContext> &ctxPtr,
bool isServer = true,
bool validateCert = true,
const std::string &hostname = "");
#endif #endif
virtual ~TcpConnectionImpl();
virtual void send(const char *msg, size_t len) override;
virtual void send(const void *msg, size_t len) override;
virtual void send(const std::string &msg) override;
virtual void send(std::string &&msg) override;
virtual void send(const MsgBuffer &buffer) override;
virtual void send(MsgBuffer &&buffer) override;
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
virtual void sendFile(const char *fileName,
size_t offset = 0,
size_t length = 0) override;
virtual const InetAddress &localAddr() const override virtual ~TcpConnectionImpl();
{
return localAddr_;
}
virtual const InetAddress &peerAddr() const override
{
return peerAddr_;
}
virtual bool connected() const override virtual void send(const char *msg, size_t len) override;
{ virtual void send(const void *msg, size_t len) override;
return status_ == ConnStatus::Connected; virtual void send(const std::string &msg) override;
} virtual void send(std::string &&msg) override;
virtual bool disconnected() const override virtual void send(const MsgBuffer &buffer) override;
{ virtual void send(MsgBuffer &&buffer) override;
return status_ == ConnStatus::Disconnected; virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
} virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) override;
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;} virtual const InetAddress &localAddr() const override {
virtual MsgBuffer *getRecvBuffer() override return localAddr_;
{ }
return &readBuffer_;
}
// set callbacks
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
size_t markLen) override
{
highWaterMarkCallback_ = cb;
highWaterMarkLen_ = markLen;
}
virtual void keepAlive() override virtual const InetAddress &peerAddr() const override {
{ return peerAddr_;
idleTimeout_ = 0; }
auto entry = kickoffEntry_.lock();
if (entry)
{
entry->reset();
}
}
virtual bool isKeepAlive() override
{
return idleTimeout_ == 0;
}
virtual void setTcpNoDelay(bool on) override;
virtual void shutdown() override;
virtual void forceClose() override;
virtual EventLoop *getLoop() override
{
return loop_;
}
virtual size_t bytesSent() const override virtual bool connected() const override {
{ return status_ == ConnStatus::Connected;
return bytesSent_; }
}
virtual size_t bytesReceived() const override
{
return bytesReceived_;
}
virtual void startClientEncryption(
std::function<void()> callback,
bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
{}) override;
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> callback) override;
virtual bool isSSLConnection() const override
{
return isEncrypted_;
}
private: virtual bool disconnected() const override {
/// Internal use only. return status_ == ConnStatus::Disconnected;
}
std::weak_ptr<KickoffEntry> kickoffEntry_; // virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
std::weak_ptr<TimingWheel> timingWheelWeakPtr_; virtual MsgBuffer *getRecvBuffer() override {
size_t idleTimeout_{0}; return &readBuffer_;
Date lastTimingWheelUpdateTime_; }
void enableKickingOff(size_t timeout, // set callbacks
const std::shared_ptr<TimingWheel> &timingWheel) virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
{ size_t markLen) override {
assert(timingWheel); highWaterMarkCallback_ = cb;
assert(timingWheel->getLoop() == loop_); highWaterMarkLen_ = markLen;
assert(timeout > 0); }
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
kickoffEntry_ = entry; virtual void keepAlive() override {
timingWheelWeakPtr_ = timingWheel; idleTimeout_ = 0;
idleTimeout_ = timeout; auto entry = kickoffEntry_.lock();
timingWheel->insertEntry(timeout, entry); if (entry) {
} entry->reset();
void extendLife(); }
}
virtual bool isKeepAlive() override {
return idleTimeout_ == 0;
}
virtual void setTcpNoDelay(bool on) override;
virtual void shutdown() override;
virtual void forceClose() override;
virtual EventLoop *getLoop() override {
return loop_;
}
virtual size_t bytesSent() const override {
return bytesSent_;
}
virtual size_t bytesReceived() const override {
return bytesReceived_;
}
virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) override;
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) override;
virtual bool isSSLConnection() const override {
return isEncrypted_;
}
private:
/// Internal use only.
std::weak_ptr<KickoffEntry> kickoffEntry_;
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
size_t idleTimeout_{ 0 };
Date lastTimingWheelUpdateTime_;
void enableKickingOff(size_t timeout,
const std::shared_ptr<TimingWheel> &timingWheel) {
assert(timingWheel);
assert(timingWheel->getLoop() == loop_);
assert(timeout > 0);
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
kickoffEntry_ = entry;
timingWheelWeakPtr_ = timingWheel;
idleTimeout_ = timeout;
timingWheel->insertEntry(timeout, entry);
}
void extendLife();
#ifndef _WIN32 #ifndef _WIN32
void sendFile(int sfd, size_t offset = 0, size_t length = 0); void sendFile(int sfd, size_t offset = 0, size_t length = 0);
#else #else
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0); void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
#endif #endif
void setRecvMsgCallback(const RecvMessageCallback &cb) void setRecvMsgCallback(const RecvMessageCallback &cb) {
{ recvMsgCallback_ = cb;
recvMsgCallback_ = cb; }
} void setConnectionCallback(const ConnectionCallback &cb) {
void setConnectionCallback(const ConnectionCallback &cb) connectionCallback_ = cb;
{ }
connectionCallback_ = cb; void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
} writeCompleteCallback_ = cb;
void setWriteCompleteCallback(const WriteCompleteCallback &cb) }
{ void setCloseCallback(const CloseCallback &cb) {
writeCompleteCallback_ = cb; closeCallback_ = cb;
} }
void setCloseCallback(const CloseCallback &cb) void setSSLErrorCallback(const SSLErrorCallback &cb) {
{ sslErrorCallback_ = cb;
closeCallback_ = cb; }
}
void setSSLErrorCallback(const SSLErrorCallback &cb)
{
sslErrorCallback_ = cb;
}
void connectDestroyed(); void connectDestroyed();
virtual void connectEstablished(); virtual void connectEstablished();
protected: protected:
struct BufferNode struct BufferNode {
{
#ifndef _WIN32 #ifndef _WIN32
int sendFd_{-1}; int sendFd_{ -1 };
off_t offset_; off_t offset_;
#else #else
FILE *sendFp_{nullptr}; FILE *sendFp_{ nullptr };
long long offset_; long long offset_;
#endif #endif
ssize_t fileBytesToSend_; ssize_t fileBytesToSend_;
std::shared_ptr<MsgBuffer> msgBuffer_; std::shared_ptr<MsgBuffer> msgBuffer_;
~BufferNode()
{ ~BufferNode() {
#ifndef _WIN32 #ifndef _WIN32
if (sendFd_ >= 0) if (sendFd_ >= 0)
close(sendFd_); close(sendFd_);
#else #else
if (sendFp_) if (sendFp_)
fclose(sendFp_); fclose(sendFp_);
#endif #endif
} }
}; };
using BufferNodePtr = std::shared_ptr<BufferNode>;
enum class ConnStatus
{
Disconnected,
Connecting,
Connected,
Disconnecting
};
bool isEncrypted_{false};
EventLoop *loop_;
std::unique_ptr<Channel> ioChannelPtr_;
std::unique_ptr<Socket> socketPtr_;
MsgBuffer readBuffer_;
std::list<BufferNodePtr> writeBufferList_;
void readCallback();
void writeCallback();
InetAddress localAddr_, peerAddr_;
ConnStatus status_{ConnStatus::Connecting};
// callbacks
RecvMessageCallback recvMsgCallback_;
ConnectionCallback connectionCallback_;
CloseCallback closeCallback_;
WriteCompleteCallback writeCompleteCallback_;
HighWaterMarkCallback highWaterMarkCallback_;
SSLErrorCallback sslErrorCallback_;
void handleClose();
void handleError();
// virtual void sendInLoop(const std::string &msg);
void sendFileInLoop(const BufferNodePtr &file); using BufferNodePtr = std::shared_ptr<BufferNode>;
enum class ConnStatus {
Disconnected,
Connecting,
Connected,
Disconnecting
};
bool isEncrypted_{ false };
EventLoop *loop_;
std::unique_ptr<Channel> ioChannelPtr_;
std::unique_ptr<Socket> socketPtr_;
MsgBuffer readBuffer_;
std::list<BufferNodePtr> writeBufferList_;
void readCallback();
void writeCallback();
InetAddress localAddr_, peerAddr_;
ConnStatus status_{ ConnStatus::Connecting };
// callbacks
RecvMessageCallback recvMsgCallback_;
ConnectionCallback connectionCallback_;
CloseCallback closeCallback_;
WriteCompleteCallback writeCompleteCallback_;
HighWaterMarkCallback highWaterMarkCallback_;
SSLErrorCallback sslErrorCallback_;
void handleClose();
void handleError();
// virtual void sendInLoop(const std::string &msg);
void sendFileInLoop(const BufferNodePtr &file);
#ifndef _WIN32 #ifndef _WIN32
void sendInLoop(const void *buffer, size_t length); void sendInLoop(const void *buffer, size_t length);
ssize_t writeInLoop(const void *buffer, size_t length); ssize_t writeInLoop(const void *buffer, size_t length);
#else #else
void sendInLoop(const char *buffer, size_t length); void sendInLoop(const char *buffer, size_t length);
ssize_t writeInLoop(const char *buffer, size_t length); ssize_t writeInLoop(const char *buffer, size_t length);
#endif #endif
size_t highWaterMarkLen_; size_t highWaterMarkLen_;
std::string name_; std::string name_;
uint64_t sendNum_{0}; uint64_t sendNum_{ 0 };
std::mutex sendNumMutex_; std::mutex sendNumMutex_;
size_t bytesSent_{0}; size_t bytesSent_{ 0 };
size_t bytesReceived_{0}; size_t bytesReceived_{ 0 };
std::unique_ptr<std::vector<char>> fileBufferPtr_; std::unique_ptr<std::vector<char> > fileBufferPtr_;
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
private: private:
void doHandshaking(); void doHandshaking();
bool validatePeerCertificate(); bool validatePeerCertificate();
struct SSLEncryption
{ struct SSLEncryption {
SSLStatus statusOfSSL_ = SSLStatus::Handshaking; SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
// OpenSSL // OpenSSL
std::shared_ptr<SSLContext> sslCtxPtr_; std::shared_ptr<SSLContext> sslCtxPtr_;
std::unique_ptr<SSLConn> sslPtr_; std::unique_ptr<SSLConn> sslPtr_;
std::unique_ptr<std::array<char, 8192>> sendBufferPtr_; std::unique_ptr<std::array<char, 8192> > sendBufferPtr_;
bool isServer_{false}; bool isServer_{ false };
bool isUpgrade_{false}; bool isUpgrade_{ false };
std::function<void()> upgradeCallback_; std::function<void()> upgradeCallback_;
std::string hostname_; std::string hostname_;
}; };
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
void startClientEncryptionInLoop( std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
std::function<void()> &&callback,
bool useOldTLS, void startClientEncryptionInLoop(std::function<void()> &&callback, bool useOldTLS, bool validateCert, const std::string &hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
bool validateCert, void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx, std::function<void()> &&callback);
const std::string &hostname,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> &&callback);
#endif #endif
}; };

View File

@ -35,6 +35,7 @@
Connector::Connector(EventLoop *loop, const InetAddress &addr, bool retry) : Connector::Connector(EventLoop *loop, const InetAddress &addr, bool retry) :
loop_(loop), serverAddr_(addr), retry_(retry) { loop_(loop), serverAddr_(addr), retry_(retry) {
} }
Connector::Connector(EventLoop *loop, InetAddress &&addr, bool retry) : Connector::Connector(EventLoop *loop, InetAddress &&addr, bool retry) :
loop_(loop), serverAddr_(std::move(addr)), retry_(retry) { loop_(loop), serverAddr_(std::move(addr)), retry_(retry) {
} }
@ -43,8 +44,10 @@ void Connector::start() {
connect_ = true; connect_ = true;
loop_->runInLoop([this]() { startInLoop(); }); loop_->runInLoop([this]() { startInLoop(); });
} }
void Connector::restart() { void Connector::restart() {
} }
void Connector::stop() { void Connector::stop() {
} }
@ -57,6 +60,7 @@ void Connector::startInLoop() {
LOG_DEBUG << "do not connect"; LOG_DEBUG << "do not connect";
} }
} }
void Connector::connect() { void Connector::connect() {
int sockfd = Socket::createNonblockingSocketOrDie(serverAddr_.family()); int sockfd = Socket::createNonblockingSocketOrDie(serverAddr_.family());
errno = 0; errno = 0;

View File

@ -46,23 +46,30 @@ protected:
public: public:
using NewConnectionCallback = std::function<void(int sockfd)>; using NewConnectionCallback = std::function<void(int sockfd)>;
using ConnectionErrorCallback = std::function<void()>; using ConnectionErrorCallback = std::function<void()>;
Connector(EventLoop *loop, const InetAddress &addr, bool retry = true); Connector(EventLoop *loop, const InetAddress &addr, bool retry = true);
Connector(EventLoop *loop, InetAddress &&addr, bool retry = true); Connector(EventLoop *loop, InetAddress &&addr, bool retry = true);
void setNewConnectionCallback(const NewConnectionCallback &cb) { void setNewConnectionCallback(const NewConnectionCallback &cb) {
newConnectionCallback_ = cb; newConnectionCallback_ = cb;
} }
void setNewConnectionCallback(NewConnectionCallback &&cb) { void setNewConnectionCallback(NewConnectionCallback &&cb) {
newConnectionCallback_ = std::move(cb); newConnectionCallback_ = std::move(cb);
} }
void setErrorCallback(const ConnectionErrorCallback &cb) { void setErrorCallback(const ConnectionErrorCallback &cb) {
errorCallback_ = cb; errorCallback_ = cb;
} }
void setErrorCallback(ConnectionErrorCallback &&cb) { void setErrorCallback(ConnectionErrorCallback &&cb) {
errorCallback_ = std::move(cb); errorCallback_ = std::move(cb);
} }
const InetAddress &serverAddress() const { const InetAddress &serverAddress() const {
return serverAddr_; return serverAddr_;
} }
void start(); void start();
void restart(); void restart();
void stop(); void stop();
@ -70,11 +77,13 @@ public:
private: private:
NewConnectionCallback newConnectionCallback_; NewConnectionCallback newConnectionCallback_;
ConnectionErrorCallback errorCallback_; ConnectionErrorCallback errorCallback_;
enum class Status { enum class Status {
Disconnected, Disconnected,
Connecting, Connecting,
Connected Connected
}; };
static constexpr int kMaxRetryDelayMs = 30 * 1000; static constexpr int kMaxRetryDelayMs = 30 * 1000;
static constexpr int kInitRetryDelayMs = 500; static constexpr int kInitRetryDelayMs = 500;
std::shared_ptr<Channel> channelPtr_; std::shared_ptr<Channel> channelPtr_;

View File

@ -98,8 +98,7 @@ InetAddress::InetAddress(uint16_t port, bool loopbackOnly, bool ipv6) :
isUnspecified_ = false; isUnspecified_ = false;
} }
InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) : InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) : isIpV6_(ipv6) {
isIpV6_(ipv6) {
if (ipv6) { if (ipv6) {
memset(&addr6_, 0, sizeof(addr6_)); memset(&addr6_, 0, sizeof(addr6_));
addr6_.sin6_family = AF_INET6; addr6_.sin6_family = AF_INET6;

View File

@ -38,195 +38,76 @@ using sa_family_t = unsigned short;
using in_addr_t = uint32_t; using in_addr_t = uint32_t;
using uint16_t = unsigned short; using uint16_t = unsigned short;
#else #else
#include <netinet/in.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h> #include <sys/socket.h>
#endif #endif
#include <mutex>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <mutex>
/** class InetAddress {
* @brief Wrapper of sockaddr_in. This is an POD interface class. public:
* InetAddress(uint16_t port = 0, bool loopbackOnly = false, bool ipv6 = false);
*/
class InetAddress
{
public:
/**
* @brief Constructs an endpoint with given port number. Mostly used in
* TcpServer listening.
*
* @param port
* @param loopbackOnly
* @param ipv6
*/
InetAddress(uint16_t port = 0,
bool loopbackOnly = false,
bool ipv6 = false);
/** InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
* @brief Constructs an endpoint with given ip and port.
*
* @param ip A IPv4 or IPv6 address.
* @param port
* @param ipv6
*/
InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
/** explicit InetAddress(const struct sockaddr_in &addr) :
* @brief Constructs an endpoint with given struct `sockaddr_in`. Mostly addr_(addr), isUnspecified_(false) {
* used when accepting new connections }
*
* @param addr
*/
explicit InetAddress(const struct sockaddr_in &addr)
: addr_(addr), isUnspecified_(false)
{
}
/** explicit InetAddress(const struct sockaddr_in6 &addr) :
* @brief Constructs an IPv6 endpoint with given struct `sockaddr_in6`. addr6_(addr), isIpV6_(true), isUnspecified_(false) {
* Mostly used when accepting new connections }
*
* @param addr
*/
explicit InetAddress(const struct sockaddr_in6 &addr)
: addr6_(addr), isIpV6_(true), isUnspecified_(false)
{
}
/** sa_family_t family() const {
* @brief Return the sin_family of the endpoint. return addr_.sin_family;
* }
* @return sa_family_t
*/
sa_family_t family() const
{
return addr_.sin_family;
}
/** std::string toIp() const;
* @brief Return the IP string of the endpoint. std::string toIpPort() const;
* uint16_t toPort() const;
* @return std::string
*/
std::string toIp() const;
/** bool isIpV6() const {
* @brief Return the IP and port string of the endpoint. return isIpV6_;
* }
* @return std::string
*/
std::string toIpPort() const;
/** bool isIntranetIp() const;
* @brief Return the port number of the endpoint. bool isLoopbackIp() const;
*
* @return uint16_t
*/
uint16_t toPort() const;
/** const struct sockaddr *getSockAddr() const {
* @brief Check if the endpoint is IPv4 or IPv6. return static_cast<const struct sockaddr *>((void *)(&addr6_));
* }
* @return true
* @return false
*/
bool isIpV6() const
{
return isIpV6_;
}
/** void setSockAddrInet6(const struct sockaddr_in6 &addr6) {
* @brief Return true if the endpoint is an intranet endpoint. addr6_ = addr6;
* isIpV6_ = (addr6_.sin6_family == AF_INET6);
* @return true isUnspecified_ = false;
* @return false }
*/
bool isIntranetIp() const;
/** uint32_t ipNetEndian() const;
* @brief Return true if the endpoint is a loopback endpoint. const uint32_t *ip6NetEndian() const;
*
* @return true
* @return false
*/
bool isLoopbackIp() const;
/** uint16_t portNetEndian() const {
* @brief Get the pointer to the sockaddr struct. return addr_.sin_port;
* }
* @return const struct sockaddr*
*/
const struct sockaddr *getSockAddr() const
{
return static_cast<const struct sockaddr *>((void *)(&addr6_));
}
/** void setPortNetEndian(uint16_t port) {
* @brief Set the sockaddr_in6 struct in the endpoint. addr_.sin_port = port;
* }
* @param addr6
*/
void setSockAddrInet6(const struct sockaddr_in6 &addr6)
{
addr6_ = addr6;
isIpV6_ = (addr6_.sin6_family == AF_INET6);
isUnspecified_ = false;
}
/** inline bool isUnspecified() const {
* @brief Return the integer value of the IP(v4) in net endian byte order. return isUnspecified_;
* }
* @return uint32_t
*/
uint32_t ipNetEndian() const;
/** private:
* @brief Return the pointer to the integer value of the IP(v6) in net union {
* endian byte order. struct sockaddr_in addr_;
* struct sockaddr_in6 addr6_;
* @return const uint32_t* };
*/
const uint32_t *ip6NetEndian() const;
/** bool isIpV6_{ false };
* @brief Return the port number in net endian byte order. bool isUnspecified_{ true };
*
* @return uint16_t
*/
uint16_t portNetEndian() const
{
return addr_.sin_port;
}
/**
* @brief Set the port number in net endian byte order.
*
* @param port
*/
void setPortNetEndian(uint16_t port)
{
addr_.sin_port = port;
}
/**
* @brief Return true if the address is not initalized.
*/
inline bool isUnspecified() const
{
return isUnspecified_;
}
private:
union
{
struct sockaddr_in addr_;
struct sockaddr_in6 addr6_;
};
bool isIpV6_{false};
bool isUnspecified_{true};
}; };
#endif // MUDUO_NET_INETADDRESS_H #endif // MUDUO_NET_INETADDRESS_H

View File

@ -34,44 +34,20 @@
#include "core/loops/event_loop.h" #include "core/loops/event_loop.h"
#include "core/net/inet_address.h" #include "core/net/inet_address.h"
/** //make it a reference
* @brief This class represents an asynchronous DNS resolver.
* @note Although the c-ares library is not essential, it is recommended to
* install it for higher performance
*/
class Resolver class Resolver
{ {
public: public:
using Callback = std::function<void(const InetAddress&)>; using Callback = std::function<void(const InetAddress&)>;
/** static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr, size_t timeout = 60);
* @brief Create a new DNS resolver.
*
* @param loop The event loop in which the DNS resolver runs.
* @param timeout The timeout in seconds for DNS.
* @return std::shared_ptr<Resolver>
*/
static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr,
size_t timeout = 60);
/** virtual void resolve(const std::string& hostname, const Callback& callback) = 0;
* @brief Resolve an address asynchronously.
*
* @param hostname
* @param callback
*/
virtual void resolve(const std::string& hostname,
const Callback& callback) = 0;
virtual ~Resolver() virtual ~Resolver()
{ {
} }
/**
* @brief Check whether the c-ares library is used.
*
* @return true
* @return false
*/
static bool isCAresUsed(); static bool isCAresUsed();
}; };

View File

@ -29,8 +29,8 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ares_resolver.h" #include "ares_resolver.h"
#include <ares.h>
#include "core/loops/channel.h" #include "core/loops/channel.h"
#include <ares.h>
#ifdef _WIN32 #ifdef _WIN32
#include <winsock2.h> #include <winsock2.h>
#else #else
@ -67,14 +67,14 @@ bool Resolver::isCAresUsed() {
AresResolver::LibraryInitializer::LibraryInitializer() { AresResolver::LibraryInitializer::LibraryInitializer() {
ares_library_init(ARES_LIB_INIT_ALL); ares_library_init(ARES_LIB_INIT_ALL);
} }
AresResolver::LibraryInitializer::~LibraryInitializer() { AresResolver::LibraryInitializer::~LibraryInitializer() {
ares_library_cleanup(); ares_library_cleanup();
} }
AresResolver::LibraryInitializer AresResolver::libraryInitializer_; AresResolver::LibraryInitializer AresResolver::libraryInitializer_;
std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop, std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop, size_t timeout) {
size_t timeout) {
return std::make_shared<AresResolver>(loop, timeout); return std::make_shared<AresResolver>(loop, timeout);
} }
@ -84,6 +84,7 @@ AresResolver::AresResolver(EventLoop *loop, size_t timeout) :
loop_ = getLoop(); loop_ = getLoop();
} }
} }
void AresResolver::init() { void AresResolver::init() {
if (!ctx_) { if (!ctx_) {
struct ares_options options; struct ares_options options;
@ -108,14 +109,16 @@ void AresResolver::init() {
this); this);
} }
} }
AresResolver::~AresResolver() { AresResolver::~AresResolver() {
if (ctx_) if (ctx_)
ares_destroy(ctx_); ares_destroy(ctx_);
} }
void AresResolver::resolveInLoop(const std::string &hostname, void AresResolver::resolveInLoop(const std::string &hostname, const Callback &cb) {
const Callback &cb) {
loop_->assertInLoopThread(); loop_->assertInLoopThread();
#ifdef _WIN32 #ifdef _WIN32
if (hostname == "localhost") { if (hostname == "localhost") {
const static InetAddress localhost_{ "127.0.0.1", 0 }; const static InetAddress localhost_{ "127.0.0.1", 0 };
@ -123,20 +126,17 @@ void AresResolver::resolveInLoop(const std::string &hostname,
return; return;
} }
#endif #endif
init(); init();
QueryData *queryData = new QueryData(this, cb, hostname); QueryData *queryData = new QueryData(this, cb, hostname);
ares_gethostbyname(ctx_, ares_gethostbyname(ctx_, hostname.c_str(), AF_INET, &AresResolver::ares_hostcallback_, queryData);
hostname.c_str(),
AF_INET,
&AresResolver::ares_hostcallback_,
queryData);
struct timeval tv; struct timeval tv;
struct timeval *tvp = ares_timeout(ctx_, NULL, &tv); struct timeval *tvp = ares_timeout(ctx_, NULL, &tv);
double timeout = getSeconds(tvp); double timeout = getSeconds(tvp);
// LOG_DEBUG << "timeout " << timeout << " active " << timerActive_; // LOG_DEBUG << "timeout " << timeout << " active " << timerActive_;
if (!timerActive_ && timeout >= 0.0) { if (!timerActive_ && timeout >= 0.0) {
loop_->runAfter(timeout, loop_->runAfter(timeout, std::bind(&AresResolver::onTimer, shared_from_this()));
std::bind(&AresResolver::onTimer, shared_from_this()));
timerActive_ = true; timerActive_ = true;
} }
return; return;
@ -161,18 +161,18 @@ void AresResolver::onTimer() {
} }
} }
void AresResolver::onQueryResult(int status, void AresResolver::onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback) {
struct hostent *result,
const std::string &hostname,
const Callback &callback) {
LOG_TRACE << "onQueryResult " << status; LOG_TRACE << "onQueryResult " << status;
struct sockaddr_in addr; struct sockaddr_in addr;
memset(&addr, 0, sizeof addr); memset(&addr, 0, sizeof addr);
addr.sin_family = AF_INET; addr.sin_family = AF_INET;
addr.sin_port = 0; addr.sin_port = 0;
if (result) { if (result) {
addr.sin_addr = *reinterpret_cast<in_addr *>(result->h_addr); addr.sin_addr = *reinterpret_cast<in_addr *>(result->h_addr);
} }
InetAddress inet(addr); InetAddress inet(addr);
{ {
std::lock_guard<std::mutex> lock(globalMutex()); std::lock_guard<std::mutex> lock(globalMutex());
@ -180,6 +180,7 @@ void AresResolver::onQueryResult(int status,
addrItem.first = addr.sin_addr; addrItem.first = addr.sin_addr;
addrItem.second = Date::date(); addrItem.second = Date::date();
} }
callback(inet); callback(inet);
} }
@ -198,6 +199,7 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
loop_->assertInLoopThread(); loop_->assertInLoopThread();
ChannelList::iterator it = channels_.find(sockfd); ChannelList::iterator it = channels_.find(sockfd);
assert(it != channels_.end()); assert(it != channels_.end());
if (read) { if (read) {
// update // update
// if (write) { } else { } // if (write) { } else { }
@ -209,17 +211,12 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
} }
} }
void AresResolver::ares_hostcallback_(void *data, void AresResolver::ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent) {
int status,
int timeouts,
struct hostent *hostent) {
(void)timeouts; (void)timeouts;
QueryData *query = static_cast<QueryData *>(data); QueryData *query = static_cast<QueryData *>(data);
query->owner_->onQueryResult(status, query->owner_->onQueryResult(status, hostent, query->hostname_, query->callback_);
hostent,
query->hostname_,
query->callback_);
delete query; delete query;
} }
@ -242,6 +239,7 @@ void AresResolver::ares_sock_statecallback_(void *data,
#endif #endif
int read, int read,
int write) { int write) {
LOG_TRACE << "sockfd=" << sockfd << " read=" << read << " write=" << write; LOG_TRACE << "sockfd=" << sockfd << " read=" << read << " write=" << write;
static_cast<AresResolver *>(data)->onSockStateChange(sockfd, read, write); static_cast<AresResolver *>(data)->onSockStateChange(sockfd, read, write);
} }

View File

@ -36,14 +36,15 @@
#include <map> #include <map>
#include <memory> #include <memory>
// Resolver will be a ref
extern "C" { extern "C" {
struct hostent; struct hostent;
struct ares_channeldata; struct ares_channeldata;
using ares_channel = struct ares_channeldata *; using ares_channel = struct ares_channeldata *;
} }
class AresResolver : public Resolver, class AresResolver : public Resolver, public std::enable_shared_from_this<AresResolver> {
public std::enable_shared_from_this<AresResolver> {
protected: protected:
AresResolver(const AresResolver &) = delete; AresResolver(const AresResolver &) = delete;
AresResolver &operator=(const AresResolver &) = delete; AresResolver &operator=(const AresResolver &) = delete;
@ -55,9 +56,9 @@ public:
AresResolver(EventLoop *loop, size_t timeout); AresResolver(EventLoop *loop, size_t timeout);
~AresResolver(); ~AresResolver();
virtual void resolve(const std::string &hostname, virtual void resolve(const std::string &hostname, const Callback &cb) override {
const Callback &cb) override {
bool cached = false; bool cached = false;
InetAddress inet; InetAddress inet;
{ {
std::lock_guard<std::mutex> lock(globalMutex()); std::lock_guard<std::mutex> lock(globalMutex());
@ -76,10 +77,12 @@ public:
} }
} }
} }
if (cached) { if (cached) {
cb(inet); cb(inet);
return; return;
} }
if (loop_->isInLoopThread()) { if (loop_->isInLoopThread()) {
resolveInLoop(hostname, cb); resolveInLoop(hostname, cb);
} else { } else {
@ -94,14 +97,15 @@ private:
AresResolver *owner_; AresResolver *owner_;
Callback callback_; Callback callback_;
std::string hostname_; std::string hostname_;
QueryData(AresResolver *o,
const Callback &cb, QueryData(AresResolver *o, const Callback &cb, const std::string &hostname) :
const std::string &hostname) :
owner_(o), callback_(cb), hostname_(hostname) { owner_(o), callback_(cb), hostname_(hostname) {
} }
}; };
void resolveInLoop(const std::string &hostname, const Callback &cb); void resolveInLoop(const std::string &hostname, const Callback &cb);
void init(); void init();
EventLoop *loop_; EventLoop *loop_;
ares_channel ctx_{ nullptr }; ares_channel ctx_{ nullptr };
bool timerActive_{ false }; bool timerActive_{ false };
@ -109,36 +113,33 @@ private:
ChannelList channels_; ChannelList channels_;
static std::unordered_map<std::string, static std::unordered_map<std::string,
std::pair<struct in_addr, Date> > & std::pair<struct in_addr, Date> > &
globalCache() { globalCache() {
static std::unordered_map<std::string, static std::unordered_map<std::string, std::pair<struct in_addr, Date> > dnsCache;
std::pair<struct in_addr, Date> >
dnsCache;
return dnsCache; return dnsCache;
} }
static std::mutex &globalMutex() { static std::mutex &globalMutex() {
static std::mutex mutex_; static std::mutex mutex_;
return mutex_; return mutex_;
} }
static EventLoop *getLoop() { static EventLoop *getLoop() {
static EventLoopThread loopThread; static EventLoopThread loopThread;
loopThread.run(); loopThread.run();
return loopThread.getLoop(); return loopThread.getLoop();
} }
const size_t timeout_{ 60 }; const size_t timeout_{ 60 };
void onRead(int sockfd); void onRead(int sockfd);
void onTimer(); void onTimer();
void onQueryResult(int status, void onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback);
struct hostent *result,
const std::string &hostname,
const Callback &callback);
void onSockCreate(int sockfd, int type); void onSockCreate(int sockfd, int type);
void onSockStateChange(int sockfd, bool read, bool write); void onSockStateChange(int sockfd, bool read, bool write);
static void ares_hostcallback_(void *data, static void ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent);
int status,
int timeouts,
struct hostent *hostent);
#ifdef _WIN32 #ifdef _WIN32
static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data); static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data);
#else #else
@ -152,9 +153,11 @@ private:
#endif #endif
int read, int read,
int write); int write);
struct LibraryInitializer { struct LibraryInitializer {
LibraryInitializer(); LibraryInitializer();
~LibraryInitializer(); ~LibraryInitializer();
}; };
static LibraryInitializer libraryInitializer_; static LibraryInitializer libraryInitializer_;
}; };

View File

@ -44,9 +44,11 @@ std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *,
size_t timeout) { size_t timeout) {
return std::make_shared<NormalResolver>(timeout); return std::make_shared<NormalResolver>(timeout);
} }
bool Resolver::isCAresUsed() { bool Resolver::isCAresUsed() {
return false; return false;
} }
void NormalResolver::resolve(const std::string &hostname, void NormalResolver::resolve(const std::string &hostname,
const Callback &callback) { const Callback &callback) {
{ {

View File

@ -35,10 +35,12 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
//Resolver will be a ref
constexpr size_t kResolveBufferLength{ 16 * 1024 }; constexpr size_t kResolveBufferLength{ 16 * 1024 };
class NormalResolver : public Resolver, class NormalResolver : public Resolver, public std::enable_shared_from_this<NormalResolver> {
public std::enable_shared_from_this<NormalResolver> {
protected: protected:
NormalResolver(const NormalResolver &) = delete; NormalResolver(const NormalResolver &) = delete;
NormalResolver &operator=(const NormalResolver &) = delete; NormalResolver &operator=(const NormalResolver &) = delete;
@ -56,25 +58,23 @@ public:
} }
private: private:
static std::unordered_map<std::string, static std::unordered_map<std::string, std::pair<InetAddress, Date> > &globalCache() {
std::pair<InetAddress, Date> > & static std::unordered_map<std::string, std::pair<InetAddress, Date> > dnsCache_;
globalCache() {
static std::unordered_map<
std::string,
std::pair<InetAddress, Date> >
dnsCache_;
return dnsCache_; return dnsCache_;
} }
static std::mutex &globalMutex() { static std::mutex &globalMutex() {
static std::mutex mutex_; static std::mutex mutex_;
return mutex_; return mutex_;
} }
static ConcurrentTaskQueue &concurrentTaskQueue() { static ConcurrentTaskQueue &concurrentTaskQueue() {
static ConcurrentTaskQueue queue( static ConcurrentTaskQueue queue(
std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(), std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(),
"Dns Queue"); "Dns Queue");
return queue; return queue;
} }
const size_t timeout_; const size_t timeout_;
std::vector<char> resolveBuffer_; std::vector<char> resolveBuffer_;
}; };

View File

@ -29,9 +29,9 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "socket.h" #include "socket.h"
#include "core/log/logger.h"
#include <assert.h> #include <assert.h>
#include <sys/types.h> #include <sys/types.h>
#include "core/log/logger.h"
#ifdef _WIN32 #ifdef _WIN32
#include <ws2tcpip.h> #include <ws2tcpip.h>
#else #else
@ -39,7 +39,6 @@
#include <sys/socket.h> #include <sys/socket.h>
#endif #endif
bool Socket::isSelfConnect(int sockfd) { bool Socket::isSelfConnect(int sockfd) {
struct sockaddr_in6 localaddr = getLocalAddr(sockfd); struct sockaddr_in6 localaddr = getLocalAddr(sockfd);
struct sockaddr_in6 peeraddr = getPeerAddr(sockfd); struct sockaddr_in6 peeraddr = getPeerAddr(sockfd);
@ -75,6 +74,7 @@ void Socket::bindAddress(const InetAddress &localaddr) {
exit(1); exit(1);
} }
} }
void Socket::listen() { void Socket::listen() {
assert(sockFd_ > 0); assert(sockFd_ > 0);
int ret = ::listen(sockFd_, SOMAXCONN); int ret = ::listen(sockFd_, SOMAXCONN);
@ -83,6 +83,7 @@ void Socket::listen() {
exit(1); exit(1);
} }
} }
int Socket::accept(InetAddress *peeraddr) { int Socket::accept(InetAddress *peeraddr) {
struct sockaddr_in6 addr6; struct sockaddr_in6 addr6;
memset(&addr6, 0, sizeof(addr6)); memset(&addr6, 0, sizeof(addr6));
@ -102,6 +103,7 @@ int Socket::accept(InetAddress *peeraddr) {
} }
return connfd; return connfd;
} }
void Socket::closeWrite() { void Socket::closeWrite() {
#ifndef _WIN32 #ifndef _WIN32
if (::shutdown(sockFd_, SHUT_WR) < 0) if (::shutdown(sockFd_, SHUT_WR) < 0)
@ -112,6 +114,7 @@ void Socket::closeWrite() {
LOG_SYSERR << "sockets::shutdownWrite"; LOG_SYSERR << "sockets::shutdownWrite";
} }
} }
int Socket::read(char *buffer, uint64_t len) { int Socket::read(char *buffer, uint64_t len) {
#ifndef _WIN32 #ifndef _WIN32
return ::read(sockFd_, buffer, len); return ::read(sockFd_, buffer, len);

View File

@ -32,8 +32,8 @@
#pragma once #pragma once
#include "core/net/inet_address.h"
#include "core/log/logger.h" #include "core/log/logger.h"
#include "core/net/inet_address.h"
#include <string> #include <string>
#ifndef _WIN32 #ifndef _WIN32
#include <unistd.h> #include <unistd.h>
@ -41,7 +41,6 @@
#include <fcntl.h> #include <fcntl.h>
class Socket { class Socket {
protected: protected:
Socket(const Socket &) = delete; Socket(const Socket &) = delete;
Socket &operator=(const Socket &) = delete; Socket &operator=(const Socket &) = delete;
@ -63,6 +62,7 @@ public:
LOG_SYSERR << "sockets::createNonblockingOrDie"; LOG_SYSERR << "sockets::createNonblockingOrDie";
exit(1); exit(1);
} }
LOG_TRACE << "sock=" << sock; LOG_TRACE << "sock=" << sock;
return sock; return sock;
} }
@ -85,15 +85,9 @@ public:
static int connect(int sockfd, const InetAddress &addr) { static int connect(int sockfd, const InetAddress &addr) {
if (addr.isIpV6()) if (addr.isIpV6())
return ::connect(sockfd, return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in6)));
addr.getSockAddr(),
static_cast<socklen_t>(
sizeof(struct sockaddr_in6)));
else else
return ::connect(sockfd, return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in)));
addr.getSockAddr(),
static_cast<socklen_t>(
sizeof(struct sockaddr_in)));
} }
static bool isSelfConnect(int sockfd); static bool isSelfConnect(int sockfd);
@ -101,7 +95,9 @@ public:
explicit Socket(int sockfd) : explicit Socket(int sockfd) :
sockFd_(sockfd) { sockFd_(sockfd) {
} }
~Socket(); ~Socket();
/// abort if address in use /// abort if address in use
void bindAddress(const InetAddress &localaddr); void bindAddress(const InetAddress &localaddr);
/// abort if address in use /// abort if address in use
@ -109,9 +105,11 @@ public:
int accept(InetAddress *peeraddr); int accept(InetAddress *peeraddr);
void closeWrite(); void closeWrite();
int read(char *buffer, uint64_t len); int read(char *buffer, uint64_t len);
int fd() { int fd() {
return sockFd_; return sockFd_;
} }
static struct sockaddr_in6 getLocalAddr(int sockfd); static struct sockaddr_in6 getLocalAddr(int sockfd);
static struct sockaddr_in6 getPeerAddr(int sockfd); static struct sockaddr_in6 getPeerAddr(int sockfd);

View File

@ -54,9 +54,7 @@ TcpClient::IgnoreSigPipe TcpClient::initObj;
#endif #endif
static void defaultConnectionCallback(const TcpConnectionPtr &conn) { static void defaultConnectionCallback(const TcpConnectionPtr &conn) {
LOG_TRACE << conn->localAddr().toIpPort() << " -> " LOG_TRACE << conn->localAddr().toIpPort() << " -> " << conn->peerAddr().toIpPort() << " is " << (conn->connected() ? "UP" : "DOWN");
<< conn->peerAddr().toIpPort() << " is "
<< (conn->connected() ? "UP" : "DOWN");
// do not call conn->forceClose(), because some users want to register // do not call conn->forceClose(), because some users want to register
// message callback only. // message callback only.
} }
@ -65,9 +63,7 @@ static void defaultMessageCallback(const TcpConnectionPtr &, MsgBuffer *buf) {
buf->retrieveAll(); buf->retrieveAll();
} }
TcpClient::TcpClient(EventLoop *loop, TcpClient::TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg) :
const InetAddress &serverAddr,
const std::string &nameArg) :
loop_(loop), loop_(loop),
connector_(new Connector(loop, serverAddr, false)), connector_(new Connector(loop, serverAddr, false)),
name_(nameArg), name_(nameArg),
@ -75,23 +71,27 @@ TcpClient::TcpClient(EventLoop *loop,
messageCallback_(defaultMessageCallback), messageCallback_(defaultMessageCallback),
retry_(false), retry_(false),
connect_(true) { connect_(true) {
connector_->setNewConnectionCallback(
std::bind(&TcpClient::newConnection, this, _1)); connector_->setNewConnectionCallback(std::bind(&TcpClient::newConnection, this, _1));
connector_->setErrorCallback([this]() { connector_->setErrorCallback([this]() {
if (connectionErrorCallback_) { if (connectionErrorCallback_) {
connectionErrorCallback_(); connectionErrorCallback_();
} }
}); });
LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector "; LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector ";
} }
TcpClient::~TcpClient() { TcpClient::~TcpClient() {
LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector "; LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector ";
TcpConnectionImplPtr conn; TcpConnectionImplPtr conn;
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
conn = std::dynamic_pointer_cast<TcpConnectionImpl>(connection_); conn = std::dynamic_pointer_cast<TcpConnectionImpl>(connection_);
} }
if (conn) { if (conn) {
assert(loop_ == conn->getLoop()); assert(loop_ == conn->getLoop());
// TODO: not 100% safe, if we are in different thread // TODO: not 100% safe, if we are in different thread
@ -104,6 +104,7 @@ TcpClient::~TcpClient() {
}); });
}); });
}); });
conn->forceClose(); conn->forceClose();
} else { } else {
/// TODO need test in this condition /// TODO need test in this condition
@ -142,39 +143,34 @@ void TcpClient::newConnection(int sockfd) {
// TODO poll with zero timeout to double confirm the new connection // TODO poll with zero timeout to double confirm the new connection
// TODO use make_shared if necessary // TODO use make_shared if necessary
std::shared_ptr<TcpConnectionImpl> conn; std::shared_ptr<TcpConnectionImpl> conn;
if (sslCtxPtr_) { if (sslCtxPtr_) {
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
conn = std::make_shared<TcpConnectionImpl>(loop_, conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr, sslCtxPtr_, false, validateCert_, SSLHostName_);
sockfd,
localAddr,
peerAddr,
sslCtxPtr_,
false,
validateCert_,
SSLHostName_);
#else #else
LOG_FATAL << "OpenSSL is not found in your system!"; LOG_FATAL << "OpenSSL is not found in your system!";
abort(); abort();
#endif #endif
} else { } else {
conn = std::make_shared<TcpConnectionImpl>(loop_, conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr);
sockfd,
localAddr,
peerAddr);
} }
conn->setConnectionCallback(connectionCallback_); conn->setConnectionCallback(connectionCallback_);
conn->setRecvMsgCallback(messageCallback_); conn->setRecvMsgCallback(messageCallback_);
conn->setWriteCompleteCallback(writeCompleteCallback_); conn->setWriteCompleteCallback(writeCompleteCallback_);
conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1)); conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1));
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
connection_ = conn; connection_ = conn;
} }
conn->setSSLErrorCallback([this](SSLError err) { conn->setSSLErrorCallback([this](SSLError err) {
if (sslErrorCallback_) { if (sslErrorCallback_) {
sslErrorCallback_(err); sslErrorCallback_(err);
} }
}); });
conn->connectEstablished(); conn->connectEstablished();
} }
@ -188,30 +184,23 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn) {
connection_.reset(); connection_.reset();
} }
loop_->queueInLoop( loop_->queueInLoop(std::bind(&TcpConnectionImpl::connectDestroyed, std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
std::bind(&TcpConnectionImpl::connectDestroyed,
std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
if (retry_ && connect_) { if (retry_ && connect_) {
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " << connector_->serverAddress().toIpPort();
<< connector_->serverAddress().toIpPort();
connector_->restart(); connector_->restart();
} }
} }
void TcpClient::enableSSL( void TcpClient::enableSSL(bool useOldTLS, bool validateCert, std::string hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
bool useOldTLS,
bool validateCert,
std::string hostname,
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
/* Create a new OpenSSL context */ /* Create a new OpenSSL context */
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds); sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds);
validateCert_ = validateCert; validateCert_ = validateCert;
if (!hostname.empty()) { if (!hostname.empty()) {
std::transform(hostname.begin(), std::transform(hostname.begin(), hostname.end(), hostname.begin(), tolower);
hostname.end(),
hostname.begin(),
tolower);
SSLHostName_ = std::move(hostname); SSLHostName_ = std::move(hostname);
} }

View File

@ -42,10 +42,7 @@
class Connector; class Connector;
using ConnectorPtr = std::shared_ptr<Connector>; using ConnectorPtr = std::shared_ptr<Connector>;
class SSLContext; class SSLContext;
/**
* @brief This class represents a TCP client.
*
*/
class TcpClient { class TcpClient {
protected: protected:
TcpClient(const TcpClient &) = delete; TcpClient(const TcpClient &) = delete;
@ -55,88 +52,35 @@ protected:
TcpClient &operator=(TcpClient &&) noexcept(true) = default; TcpClient &operator=(TcpClient &&) noexcept(true) = default;
public: public:
/** TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg);
* @brief Construct a new TCP client instance.
*
* @param loop The event loop in which the client runs.
* @param serverAddr The address of the server.
* @param nameArg The name of the client.
*/
TcpClient(EventLoop *loop,
const InetAddress &serverAddr,
const std::string &nameArg);
~TcpClient(); ~TcpClient();
/**
* @brief Connect to the server.
*
*/
void connect(); void connect();
/**
* @brief Disconnect from the server.
*
*/
void disconnect(); void disconnect();
/**
* @brief Stop connecting to the server.
*
*/
void stop(); void stop();
/**
* @brief Get the TCP connection to the server.
*
* @return TcpConnectionPtr
*/
TcpConnectionPtr connection() const { TcpConnectionPtr connection() const {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return connection_; return connection_;
} }
/**
* @brief Get the event loop.
*
* @return EventLoop*
*/
EventLoop *getLoop() const { EventLoop *getLoop() const {
return loop_; return loop_;
} }
/**
* @brief Check whether the client re-connect to the server.
*
* @return true
* @return false
*/
bool retry() const { bool retry() const {
return retry_; return retry_;
} }
/**
* @brief Enable retrying.
*
*/
void enableRetry() { void enableRetry() {
retry_ = true; retry_ = true;
} }
/**
* @brief Get the name of the client.
*
* @return const std::string&
*/
const std::string &name() const { const std::string &name() const {
return name_; return name_;
} }
/**
* @brief Set the connection callback.
*
* @param cb The callback is called when the connection to the server is
* established or closed.
*/
void setConnectionCallback(const ConnectionCallback &cb) { void setConnectionCallback(const ConnectionCallback &cb) {
connectionCallback_ = cb; connectionCallback_ = cb;
} }
@ -144,37 +88,18 @@ public:
connectionCallback_ = std::move(cb); connectionCallback_ = std::move(cb);
} }
/**
* @brief Set the connection error callback.
*
* @param cb The callback is called when an error occurs during connecting
* to the server.
*/
void setConnectionErrorCallback(const ConnectionErrorCallback &cb) { void setConnectionErrorCallback(const ConnectionErrorCallback &cb) {
connectionErrorCallback_ = cb; connectionErrorCallback_ = cb;
} }
/**
* @brief Set the message callback.
*
* @param cb The callback is called when some data is received from the
* server.
*/
void setMessageCallback(const RecvMessageCallback &cb) { void setMessageCallback(const RecvMessageCallback &cb) {
messageCallback_ = cb; messageCallback_ = cb;
} }
void setMessageCallback(RecvMessageCallback &&cb) { void setMessageCallback(RecvMessageCallback &&cb) {
messageCallback_ = std::move(cb); messageCallback_ = std::move(cb);
} }
/// Set write complete callback.
/// Not thread safe.
/** /// Not thread safe.
* @brief Set the write complete callback.
*
* @param cb The callback is called when data to send is written to the
* socket.
*/
void setWriteCompleteCallback(const WriteCompleteCallback &cb) { void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
writeCompleteCallback_ = cb; writeCompleteCallback_ = cb;
} }
@ -182,10 +107,6 @@ public:
writeCompleteCallback_ = std::move(cb); writeCompleteCallback_ = std::move(cb);
} }
/**
* @brief Set the callback for errors of SSL
* @param cb The callback is called when an SSL error occurs.
*/
void setSSLErrorCallback(const SSLErrorCallback &cb) { void setSSLErrorCallback(const SSLErrorCallback &cb) {
sslErrorCallback_ = cb; sslErrorCallback_ = cb;
} }
@ -193,24 +114,7 @@ public:
sslErrorCallback_ = std::move(cb); sslErrorCallback_ = std::move(cb);
} }
/** void enableSSL(bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
* @brief Enable SSL encryption.
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
* client.
* @param validateCert If true, we try to validate if the peer's SSL cert
* is valid.
* @param hostname The server hostname for SNI. If it is empty, the SNI is
* not used.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
* 2020. And it's a good practice to only use TLS 1.2 and above.
*/
void enableSSL(bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string> >
&sslConfCmds = {});
private: private:
/// Not thread safe, but in loop /// Not thread safe, but in loop
@ -234,6 +138,7 @@ private:
std::shared_ptr<SSLContext> sslCtxPtr_; std::shared_ptr<SSLContext> sslCtxPtr_;
bool validateCert_{ false }; bool validateCert_{ false };
std::string SSLHostName_; std::string SSLHostName_;
#ifndef _WIN32 #ifndef _WIN32
class IgnoreSigPipe { class IgnoreSigPipe {
public: public:

View File

@ -39,26 +39,13 @@
#include <string> #include <string>
class SSLContext; class SSLContext;
std::shared_ptr<SSLContext> newSSLServerContext( std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
const std::string &certPath,
const std::string &keyPath,
bool useOldTLS = false,
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
/**
* @brief This class represents a TCP connection.
*
*/
class TcpConnection { class TcpConnection {
public: public:
TcpConnection() = default; TcpConnection() = default;
virtual ~TcpConnection(){}; virtual ~TcpConnection(){};
/**
* @brief Send some data to the peer.
*
* @param msg
* @param len
*/
virtual void send(const char *msg, size_t len) = 0; virtual void send(const char *msg, size_t len) = 0;
virtual void send(const void *msg, size_t len) = 0; virtual void send(const void *msg, size_t len) = 0;
virtual void send(const std::string &msg) = 0; virtual void send(const std::string &msg) = 0;
@ -68,95 +55,27 @@ public:
virtual void send(const std::shared_ptr<std::string> &msgPtr) = 0; virtual void send(const std::shared_ptr<std::string> &msgPtr) = 0;
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) = 0; virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) = 0;
/** virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) = 0;
* @brief Send a file to the peer.
*
* @param fileName
* @param offset
* @param length
*/
virtual void sendFile(const char *fileName,
size_t offset = 0,
size_t length = 0) = 0;
/**
* @brief Get the local address of the connection.
*
* @return const InetAddress&
*/
virtual const InetAddress &localAddr() const = 0; virtual const InetAddress &localAddr() const = 0;
/**
* @brief Get the remote address of the connection.
*
* @return const InetAddress&
*/
virtual const InetAddress &peerAddr() const = 0; virtual const InetAddress &peerAddr() const = 0;
/**
* @brief Return true if the connection is established.
*
* @return true
* @return false
*/
virtual bool connected() const = 0; virtual bool connected() const = 0;
/**
* @brief Return false if the connection is established.
*
* @return true
* @return false
*/
virtual bool disconnected() const = 0; virtual bool disconnected() const = 0;
/**
* @brief Get the buffer in which the received data stored.
*
* @return MsgBuffer*
*/
virtual MsgBuffer *getRecvBuffer() = 0; virtual MsgBuffer *getRecvBuffer() = 0;
/**
* @brief Set the high water mark callback
*
* @param cb The callback is called when the data in sending buffer is
* larger than the water mark.
* @param markLen The water mark in bytes.
*/
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
size_t markLen) = 0;
/** virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, size_t markLen) = 0;
* @brief Set the TCP_NODELAY option to the socket.
*
* @param on
*/
virtual void setTcpNoDelay(bool on) = 0; virtual void setTcpNoDelay(bool on) = 0;
/**
* @brief Shutdown the connection.
* @note This method only closes the writing direction.
*/
virtual void shutdown() = 0; virtual void shutdown() = 0;
/**
* @brief Close the connection forcefully.
*
*/
virtual void forceClose() = 0; virtual void forceClose() = 0;
/**
* @brief Get the event loop in which the connection I/O is handled.
*
* @return EventLoop*
*/
virtual EventLoop *getLoop() = 0; virtual EventLoop *getLoop() = 0;
/**
* @brief Set the custom data on the connection.
*
* @param context
*/
void setContext(const std::shared_ptr<void> &context) { void setContext(const std::shared_ptr<void> &context) {
contextPtr_ = context; contextPtr_ = context;
} }
@ -164,98 +83,30 @@ public:
contextPtr_ = std::move(context); contextPtr_ = std::move(context);
} }
/**
* @brief Get the custom data from the connection.
*
* @tparam T
* @return std::shared_ptr<T>
*/
template <typename T> template <typename T>
std::shared_ptr<T> getContext() const { std::shared_ptr<T> getContext() const {
return std::static_pointer_cast<T>(contextPtr_); return std::static_pointer_cast<T>(contextPtr_);
} }
/**
* @brief Return true if the custom data is set by user.
*
* @return true
* @return false
*/
bool hasContext() const { bool hasContext() const {
return (bool)contextPtr_; return (bool)contextPtr_;
} }
/**
* @brief Clear the custom data.
*
*/
void clearContext() { void clearContext() {
contextPtr_.reset(); contextPtr_.reset();
} }
/**
* @brief Call this method to avoid being kicked off by TcpServer, refer to
* the kickoffIdleConnections method in the TcpServer class.
*
*/
virtual void keepAlive() = 0; virtual void keepAlive() = 0;
/**
* @brief Return true if the keepAlive() method is called.
*
* @return true
* @return false
*/
virtual bool isKeepAlive() = 0; virtual bool isKeepAlive() = 0;
/**
* @brief Return the number of bytes sent
*
* @return size_t
*/
virtual size_t bytesSent() const = 0; virtual size_t bytesSent() const = 0;
/**
* @brief Return the number of bytes received.
*
* @return size_t
*/
virtual size_t bytesReceived() const = 0; virtual size_t bytesReceived() const = 0;
/**
* @brief Check whether the connection is SSL encrypted.
*
* @return true
* @return false
*/
virtual bool isSSLConnection() const = 0; virtual bool isSSLConnection() const = 0;
/** virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
* @brief Start the SSL encryption on the connection (as a client).
*
* @param callback The callback is called when the SSL connection is
* established.
* @param hostname The server hostname for SNI. If it is empty, the SNI is
* not used.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
*/
virtual void startClientEncryption(
std::function<void()> callback,
bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
/** virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) = 0;
* @brief Start the SSL encryption on the connection (as a server).
*
* @param ctx The SSL context.
* @param callback The callback is called when the SSL connection is
* established.
*/
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> callback) = 0;
protected: protected:
bool validateCert_ = false; bool validateCert_ = false;

View File

@ -28,20 +28,16 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/loops/acceptor.h"
#include "core/net/connections/tcp_connection_impl.h"
#include "core/net/tcp_server.h" #include "core/net/tcp_server.h"
#include "core/log/logger.h" #include "core/log/logger.h"
#include "core/loops/acceptor.h"
#include "core/net/connections/tcp_connection_impl.h"
#include <functional> #include <functional>
#include <vector> #include <vector>
using namespace std::placeholders; using namespace std::placeholders;
TcpServer::TcpServer(EventLoop *loop, TcpServer::TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr, bool reUsePort) :
const InetAddress &address,
const std::string &name,
bool reUseAddr,
bool reUsePort) :
loop_(loop), loop_(loop),
acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)), acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)),
serverName_(name), serverName_(name),
@ -50,8 +46,8 @@ TcpServer::TcpServer(EventLoop *loop,
<< " bytes]"; << " bytes]";
buffer->retrieveAll(); buffer->retrieveAll();
}) { }) {
acceptorPtr_->setNewConnectionCallback(
std::bind(&TcpServer::newConnection, this, _1, _2)); acceptorPtr_->setNewConnectionCallback(std::bind(&TcpServer::newConnection, this, _1, _2));
} }
TcpServer::~TcpServer() { TcpServer::~TcpServer() {
@ -69,14 +65,20 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
// LOG_TRACE<<"vector size:"<<str.size(); // LOG_TRACE<<"vector size:"<<str.size();
// size_t n=write(sockfd,&str[0],str.size()); // size_t n=write(sockfd,&str[0],str.size());
// LOG_TRACE<<"write "<<n<<" bytes"; // LOG_TRACE<<"write "<<n<<" bytes";
loop_->assertInLoopThread(); loop_->assertInLoopThread();
EventLoop *ioLoop = NULL; EventLoop *ioLoop = NULL;
if (loopPoolPtr_ && loopPoolPtr_->size() > 0) { if (loopPoolPtr_ && loopPoolPtr_->size() > 0) {
ioLoop = loopPoolPtr_->getNextLoop(); ioLoop = loopPoolPtr_->getNextLoop();
} }
if (ioLoop == NULL)
if (ioLoop == NULL) {
ioLoop = loop_; ioLoop = loop_;
}
std::shared_ptr<TcpConnectionImpl> newPtr; std::shared_ptr<TcpConnectionImpl> newPtr;
if (sslCtxPtr_) { if (sslCtxPtr_) {
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
newPtr = std::make_shared<TcpConnectionImpl>( newPtr = std::make_shared<TcpConnectionImpl>(
@ -90,14 +92,14 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
abort(); abort();
#endif #endif
} else { } else {
newPtr = std::make_shared<TcpConnectionImpl>( newPtr = std::make_shared<TcpConnectionImpl>(ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
} }
if (idleTimeout_ > 0) { if (idleTimeout_ > 0) {
assert(timingWheelMap_[ioLoop]); assert(timingWheelMap_[ioLoop]);
newPtr->enableKickingOff(idleTimeout_, timingWheelMap_[ioLoop]); newPtr->enableKickingOff(idleTimeout_, timingWheelMap_[ioLoop]);
} }
newPtr->setRecvMsgCallback(recvMessageCallback_); newPtr->setRecvMsgCallback(recvMessageCallback_);
newPtr->setConnectionCallback( newPtr->setConnectionCallback(
@ -105,11 +107,13 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
if (connectionCallback_) if (connectionCallback_)
connectionCallback_(connectionPtr); connectionCallback_(connectionPtr);
}); });
newPtr->setWriteCompleteCallback( newPtr->setWriteCompleteCallback(
[this](const TcpConnectionPtr &connectionPtr) { [this](const TcpConnectionPtr &connectionPtr) {
if (writeCompleteCallback_) if (writeCompleteCallback_)
writeCompleteCallback_(connectionPtr); writeCompleteCallback_(connectionPtr);
}); });
newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1)); newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1));
connSet_.insert(newPtr); connSet_.insert(newPtr);
newPtr->connectEstablished(); newPtr->connectEstablished();
@ -143,6 +147,7 @@ void TcpServer::start() {
acceptorPtr_->listen(); acceptorPtr_->listen();
}); });
} }
void TcpServer::stop() { void TcpServer::stop() {
loop_->runInLoop([this]() { acceptorPtr_.reset(); }); loop_->runInLoop([this]() { acceptorPtr_.reset(); });
for (auto connection : connSet_) { for (auto connection : connSet_) {
@ -159,6 +164,7 @@ void TcpServer::stop() {
f.get(); f.get();
} }
} }
void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) { void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) {
LOG_TRACE << "connectionClosed"; LOG_TRACE << "connectionClosed";
// loop_->assertInLoopThread(); // loop_->assertInLoopThread();
@ -179,11 +185,7 @@ const InetAddress &TcpServer::address() const {
return acceptorPtr_->addr(); return acceptorPtr_->addr();
} }
void TcpServer::enableSSL( void TcpServer::enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
const std::string &certPath,
const std::string &keyPath,
bool useOldTLS,
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
#ifdef USE_OPENSSL #ifdef USE_OPENSSL
/* Create a new OpenSSL context */ /* Create a new OpenSSL context */
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds); sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds);

View File

@ -43,10 +43,7 @@
class Acceptor; class Acceptor;
class SSLContext; class SSLContext;
/**
* @brief This class represents a TCP server.
*
*/
class TcpServer { class TcpServer {
protected: protected:
TcpServer(const TcpServer &) = delete; TcpServer(const TcpServer &) = delete;
@ -56,53 +53,18 @@ protected:
TcpServer &operator=(TcpServer &&) noexcept(true) = default; TcpServer &operator=(TcpServer &&) noexcept(true) = default;
public: public:
/** TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr = true, bool reUsePort = true);
* @brief Construct a new TCP server instance.
*
* @param loop The event loop in which the acceptor of the server is
* handled.
* @param address The address of the server.
* @param name The name of the server.
* @param reUseAddr The SO_REUSEADDR option.
* @param reUsePort The SO_REUSEPORT option.
*/
TcpServer(EventLoop *loop,
const InetAddress &address,
const std::string &name,
bool reUseAddr = true,
bool reUsePort = true);
~TcpServer(); ~TcpServer();
/**
* @brief Start the server.
*
*/
void start(); void start();
/**
* @brief Stop the server.
*
*/
void stop(); void stop();
/**
* @brief Set the number of event loops in which the I/O of connections to
* the server is handled.
*
* @param num
*/
void setIoLoopNum(size_t num) { void setIoLoopNum(size_t num) {
assert(!started_); assert(!started_);
loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num); loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num);
loopPoolPtr_->start(); loopPoolPtr_->start();
} }
/**
* @brief Set the event loops pool in which the I/O of connections to
* the server is handled.
*
* @param pool
*/
void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) { void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) {
assert(pool->size() > 0); assert(pool->size() > 0);
assert(!started_); assert(!started_);
@ -110,12 +72,6 @@ public:
loopPoolPtr_->start(); loopPoolPtr_->start();
} }
/**
* @brief Set the message callback.
*
* @param cb The callback is called when some data is received on a
* connection to the server.
*/
void setRecvMessageCallback(const RecvMessageCallback &cb) { void setRecvMessageCallback(const RecvMessageCallback &cb) {
recvMessageCallback_ = cb; recvMessageCallback_ = cb;
} }
@ -123,12 +79,6 @@ public:
recvMessageCallback_ = std::move(cb); recvMessageCallback_ = std::move(cb);
} }
/**
* @brief Set the connection callback.
*
* @param cb The callback is called when a connection is established or
* closed.
*/
void setConnectionCallback(const ConnectionCallback &cb) { void setConnectionCallback(const ConnectionCallback &cb) {
connectionCallback_ = cb; connectionCallback_ = cb;
} }
@ -136,12 +86,6 @@ public:
connectionCallback_ = std::move(cb); connectionCallback_ = std::move(cb);
} }
/**
* @brief Set the write complete callback.
*
* @param cb The callback is called when data to send is written to the
* socket of a connection.
*/
void setWriteCompleteCallback(const WriteCompleteCallback &cb) { void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
writeCompleteCallback_ = cb; writeCompleteCallback_ = cb;
} }
@ -149,53 +93,21 @@ public:
writeCompleteCallback_ = std::move(cb); writeCompleteCallback_ = std::move(cb);
} }
/**
* @brief Get the name of the server.
*
* @return const std::string&
*/
const std::string &name() const { const std::string &name() const {
return serverName_; return serverName_;
} }
/**
* @brief Get the IP and port string of the server.
*
* @return const std::string
*/
const std::string ipPort() const; const std::string ipPort() const;
/**
* @brief Get the address of the server.
*
* @return const InetAddress&
*/
const InetAddress &address() const; const InetAddress &address() const;
/**
* @brief Get the event loop of the server.
*
* @return EventLoop*
*/
EventLoop *getLoop() const { EventLoop *getLoop() const {
return loop_; return loop_;
} }
/**
* @brief Get the I/O event loops of the server.
*
* @return std::vector<EventLoop *>
*/
std::vector<EventLoop *> getIoLoops() const { std::vector<EventLoop *> getIoLoops() const {
return loopPoolPtr_->getLoops(); return loopPoolPtr_->getLoops();
} }
/**
* @brief An idle connection is a connection that has no read or write, kick
* off it after timeout seconds.
*
* @param timeout
*/
void kickoffIdleConnections(size_t timeout) { void kickoffIdleConnections(size_t timeout) {
loop_->runInLoop([this, timeout]() { loop_->runInLoop([this, timeout]() {
assert(!started_); assert(!started_);
@ -203,23 +115,12 @@ public:
}); });
} }
/** // certPath The path of the certificate file.
* @brief Enable SSL encryption. // keyPath The path of the private key file.
* // useOldTLS If true, the TLS 1.0 and 1.1 are supported by the server.
* @param certPath The path of the certificate file. // sslConfCmds The commands used to call the SSL_CONF_cmd function in OpenSSL.
* @param keyPath The path of the private key file. // Note: It's well known that TLS 1.0 and 1.1 are not considered secure in 2020. And it's a good practice to only use TLS 1.2 and above.
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the void enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
* server.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
* 2020. And it's a good practice to only use TLS 1.2 and above.
*/
void enableSSL(const std::string &certPath,
const std::string &keyPath,
bool useOldTLS = false,
const std::vector<std::pair<std::string, std::string> >
&sslConfCmds = {});
private: private:
EventLoop *loop_; EventLoop *loop_;
@ -236,6 +137,7 @@ private:
std::map<EventLoop *, std::shared_ptr<TimingWheel> > timingWheelMap_; std::map<EventLoop *, std::shared_ptr<TimingWheel> > timingWheelMap_;
void connectionClosed(const TcpConnectionPtr &connectionPtr); void connectionClosed(const TcpConnectionPtr &connectionPtr);
std::shared_ptr<EventLoopThreadPool> loopPoolPtr_; std::shared_ptr<EventLoopThreadPool> loopPoolPtr_;
#ifndef _WIN32 #ifndef _WIN32
class IgnoreSigPipe { class IgnoreSigPipe {
public: public:
@ -247,6 +149,7 @@ private:
IgnoreSigPipe initObj; IgnoreSigPipe initObj;
#endif #endif
bool started_{ false }; bool started_{ false };
// OpenSSL SSL context Object; // OpenSSL SSL context Object;