initial codestyle cleanup in the net folder.

This commit is contained in:
Relintai 2022-02-10 13:41:02 +01:00
parent 8992cf49b3
commit 38a28ee9ce
18 changed files with 662 additions and 1128 deletions

File diff suppressed because it is too large Load Diff

View File

@ -30,54 +30,45 @@
#pragma once
#include "core/net/tcp_connection.h"
#include "core/loops/timing_wheel.h"
#include "core/net/tcp_connection.h"
#include <list>
#include <mutex>
#ifndef _WIN32
#include <unistd.h>
#endif
#include <thread>
#include <array>
#include <thread>
#ifdef USE_OPENSSL
enum class SSLStatus
{
Handshaking,
Connecting,
Connected,
DisConnecting,
DisConnected
enum class SSLStatus {
Handshaking,
Connecting,
Connected,
DisConnecting,
DisConnected
};
class SSLContext;
class SSLConn;
std::shared_ptr<SSLContext> newSSLContext(
bool useOldTLS,
bool validateCert,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
std::shared_ptr<SSLContext> newSSLServerContext(
const std::string &certPath,
const std::string &keyPath,
bool useOldTLS,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx,
// const std::string &certPath,
// const std::string &keyPath);
std::shared_ptr<SSLContext> newSSLContext(bool useOldTLS, bool validateCert, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
// void initServerSSLContext(const std::shared_ptr<SSLContext> &ctx, const std::string &certPath, const std::string &keyPath);
#endif
class Channel;
class Socket;
class TcpServer;
void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
class TcpConnectionImpl : public TcpConnection,
public std::enable_shared_from_this<TcpConnectionImpl>
{
friend class TcpServer;
friend class TcpClient;
friend void removeConnection(EventLoop *loop,
const TcpConnectionPtr &conn);
class TcpConnectionImpl : public TcpConnection, public std::enable_shared_from_this<TcpConnectionImpl> {
friend class TcpServer;
friend class TcpClient;
friend void removeConnection(EventLoop *loop, const TcpConnectionPtr &conn);
protected:
TcpConnectionImpl(const TcpConnectionImpl &) = delete;
@ -86,278 +77,247 @@ protected:
TcpConnectionImpl(TcpConnectionImpl &&) noexcept(true) = default;
TcpConnectionImpl &operator=(TcpConnectionImpl &&) noexcept(true) = default;
public:
class KickoffEntry
{
public:
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn)
: conn_(conn)
{
}
void reset()
{
conn_.reset();
}
~KickoffEntry()
{
auto conn = conn_.lock();
if (conn)
{
conn->forceClose();
}
}
public:
class KickoffEntry {
public:
explicit KickoffEntry(const std::weak_ptr<TcpConnection> &conn) :
conn_(conn) {
}
private:
std::weak_ptr<TcpConnection> conn_;
};
void reset() {
conn_.reset();
}
TcpConnectionImpl(EventLoop *loop,
int socketfd,
const InetAddress &localAddr,
const InetAddress &peerAddr);
~KickoffEntry() {
auto conn = conn_.lock();
if (conn) {
conn->forceClose();
}
}
private:
std::weak_ptr<TcpConnection> conn_;
};
TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr);
#ifdef USE_OPENSSL
TcpConnectionImpl(EventLoop *loop,
int socketfd,
const InetAddress &localAddr,
const InetAddress &peerAddr,
const std::shared_ptr<SSLContext> &ctxPtr,
bool isServer = true,
bool validateCert = true,
const std::string &hostname = "");
TcpConnectionImpl(EventLoop *loop, int socketfd, const InetAddress &localAddr, const InetAddress &peerAddr, const std::shared_ptr<SSLContext> &ctxPtr, bool isServer = true, bool validateCert = true, const std::string &hostname = "");
#endif
virtual ~TcpConnectionImpl();
virtual void send(const char *msg, size_t len) override;
virtual void send(const void *msg, size_t len) override;
virtual void send(const std::string &msg) override;
virtual void send(std::string &&msg) override;
virtual void send(const MsgBuffer &buffer) override;
virtual void send(MsgBuffer &&buffer) override;
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
virtual void sendFile(const char *fileName,
size_t offset = 0,
size_t length = 0) override;
virtual const InetAddress &localAddr() const override
{
return localAddr_;
}
virtual const InetAddress &peerAddr() const override
{
return peerAddr_;
}
virtual ~TcpConnectionImpl();
virtual bool connected() const override
{
return status_ == ConnStatus::Connected;
}
virtual bool disconnected() const override
{
return status_ == ConnStatus::Disconnected;
}
virtual void send(const char *msg, size_t len) override;
virtual void send(const void *msg, size_t len) override;
virtual void send(const std::string &msg) override;
virtual void send(std::string &&msg) override;
virtual void send(const MsgBuffer &buffer) override;
virtual void send(MsgBuffer &&buffer) override;
virtual void send(const std::shared_ptr<std::string> &msgPtr) override;
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) override;
virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) override;
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
virtual MsgBuffer *getRecvBuffer() override
{
return &readBuffer_;
}
// set callbacks
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
size_t markLen) override
{
highWaterMarkCallback_ = cb;
highWaterMarkLen_ = markLen;
}
virtual const InetAddress &localAddr() const override {
return localAddr_;
}
virtual void keepAlive() override
{
idleTimeout_ = 0;
auto entry = kickoffEntry_.lock();
if (entry)
{
entry->reset();
}
}
virtual bool isKeepAlive() override
{
return idleTimeout_ == 0;
}
virtual void setTcpNoDelay(bool on) override;
virtual void shutdown() override;
virtual void forceClose() override;
virtual EventLoop *getLoop() override
{
return loop_;
}
virtual const InetAddress &peerAddr() const override {
return peerAddr_;
}
virtual size_t bytesSent() const override
{
return bytesSent_;
}
virtual size_t bytesReceived() const override
{
return bytesReceived_;
}
virtual void startClientEncryption(
std::function<void()> callback,
bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string>> &sslConfCmds =
{}) override;
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> callback) override;
virtual bool isSSLConnection() const override
{
return isEncrypted_;
}
virtual bool connected() const override {
return status_ == ConnStatus::Connected;
}
private:
/// Internal use only.
virtual bool disconnected() const override {
return status_ == ConnStatus::Disconnected;
}
std::weak_ptr<KickoffEntry> kickoffEntry_;
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
size_t idleTimeout_{0};
Date lastTimingWheelUpdateTime_;
// virtual MsgBuffer* getSendBuffer() override{ return &writeBuffer_;}
virtual MsgBuffer *getRecvBuffer() override {
return &readBuffer_;
}
void enableKickingOff(size_t timeout,
const std::shared_ptr<TimingWheel> &timingWheel)
{
assert(timingWheel);
assert(timingWheel->getLoop() == loop_);
assert(timeout > 0);
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
kickoffEntry_ = entry;
timingWheelWeakPtr_ = timingWheel;
idleTimeout_ = timeout;
timingWheel->insertEntry(timeout, entry);
}
void extendLife();
// set callbacks
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
size_t markLen) override {
highWaterMarkCallback_ = cb;
highWaterMarkLen_ = markLen;
}
virtual void keepAlive() override {
idleTimeout_ = 0;
auto entry = kickoffEntry_.lock();
if (entry) {
entry->reset();
}
}
virtual bool isKeepAlive() override {
return idleTimeout_ == 0;
}
virtual void setTcpNoDelay(bool on) override;
virtual void shutdown() override;
virtual void forceClose() override;
virtual EventLoop *getLoop() override {
return loop_;
}
virtual size_t bytesSent() const override {
return bytesSent_;
}
virtual size_t bytesReceived() const override {
return bytesReceived_;
}
virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) override;
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) override;
virtual bool isSSLConnection() const override {
return isEncrypted_;
}
private:
/// Internal use only.
std::weak_ptr<KickoffEntry> kickoffEntry_;
std::weak_ptr<TimingWheel> timingWheelWeakPtr_;
size_t idleTimeout_{ 0 };
Date lastTimingWheelUpdateTime_;
void enableKickingOff(size_t timeout,
const std::shared_ptr<TimingWheel> &timingWheel) {
assert(timingWheel);
assert(timingWheel->getLoop() == loop_);
assert(timeout > 0);
auto entry = std::make_shared<KickoffEntry>(shared_from_this());
kickoffEntry_ = entry;
timingWheelWeakPtr_ = timingWheel;
idleTimeout_ = timeout;
timingWheel->insertEntry(timeout, entry);
}
void extendLife();
#ifndef _WIN32
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
void sendFile(int sfd, size_t offset = 0, size_t length = 0);
#else
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
void sendFile(FILE *fp, size_t offset = 0, size_t length = 0);
#endif
void setRecvMsgCallback(const RecvMessageCallback &cb)
{
recvMsgCallback_ = cb;
}
void setConnectionCallback(const ConnectionCallback &cb)
{
connectionCallback_ = cb;
}
void setWriteCompleteCallback(const WriteCompleteCallback &cb)
{
writeCompleteCallback_ = cb;
}
void setCloseCallback(const CloseCallback &cb)
{
closeCallback_ = cb;
}
void setSSLErrorCallback(const SSLErrorCallback &cb)
{
sslErrorCallback_ = cb;
}
void setRecvMsgCallback(const RecvMessageCallback &cb) {
recvMsgCallback_ = cb;
}
void setConnectionCallback(const ConnectionCallback &cb) {
connectionCallback_ = cb;
}
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
writeCompleteCallback_ = cb;
}
void setCloseCallback(const CloseCallback &cb) {
closeCallback_ = cb;
}
void setSSLErrorCallback(const SSLErrorCallback &cb) {
sslErrorCallback_ = cb;
}
void connectDestroyed();
virtual void connectEstablished();
void connectDestroyed();
virtual void connectEstablished();
protected:
struct BufferNode
{
protected:
struct BufferNode {
#ifndef _WIN32
int sendFd_{-1};
off_t offset_;
int sendFd_{ -1 };
off_t offset_;
#else
FILE *sendFp_{nullptr};
long long offset_;
FILE *sendFp_{ nullptr };
long long offset_;
#endif
ssize_t fileBytesToSend_;
std::shared_ptr<MsgBuffer> msgBuffer_;
~BufferNode()
{
ssize_t fileBytesToSend_;
std::shared_ptr<MsgBuffer> msgBuffer_;
~BufferNode() {
#ifndef _WIN32
if (sendFd_ >= 0)
close(sendFd_);
if (sendFd_ >= 0)
close(sendFd_);
#else
if (sendFp_)
fclose(sendFp_);
if (sendFp_)
fclose(sendFp_);
#endif
}
};
using BufferNodePtr = std::shared_ptr<BufferNode>;
enum class ConnStatus
{
Disconnected,
Connecting,
Connected,
Disconnecting
};
bool isEncrypted_{false};
EventLoop *loop_;
std::unique_ptr<Channel> ioChannelPtr_;
std::unique_ptr<Socket> socketPtr_;
MsgBuffer readBuffer_;
std::list<BufferNodePtr> writeBufferList_;
void readCallback();
void writeCallback();
InetAddress localAddr_, peerAddr_;
ConnStatus status_{ConnStatus::Connecting};
// callbacks
RecvMessageCallback recvMsgCallback_;
ConnectionCallback connectionCallback_;
CloseCallback closeCallback_;
WriteCompleteCallback writeCompleteCallback_;
HighWaterMarkCallback highWaterMarkCallback_;
SSLErrorCallback sslErrorCallback_;
void handleClose();
void handleError();
// virtual void sendInLoop(const std::string &msg);
}
};
void sendFileInLoop(const BufferNodePtr &file);
using BufferNodePtr = std::shared_ptr<BufferNode>;
enum class ConnStatus {
Disconnected,
Connecting,
Connected,
Disconnecting
};
bool isEncrypted_{ false };
EventLoop *loop_;
std::unique_ptr<Channel> ioChannelPtr_;
std::unique_ptr<Socket> socketPtr_;
MsgBuffer readBuffer_;
std::list<BufferNodePtr> writeBufferList_;
void readCallback();
void writeCallback();
InetAddress localAddr_, peerAddr_;
ConnStatus status_{ ConnStatus::Connecting };
// callbacks
RecvMessageCallback recvMsgCallback_;
ConnectionCallback connectionCallback_;
CloseCallback closeCallback_;
WriteCompleteCallback writeCompleteCallback_;
HighWaterMarkCallback highWaterMarkCallback_;
SSLErrorCallback sslErrorCallback_;
void handleClose();
void handleError();
// virtual void sendInLoop(const std::string &msg);
void sendFileInLoop(const BufferNodePtr &file);
#ifndef _WIN32
void sendInLoop(const void *buffer, size_t length);
ssize_t writeInLoop(const void *buffer, size_t length);
void sendInLoop(const void *buffer, size_t length);
ssize_t writeInLoop(const void *buffer, size_t length);
#else
void sendInLoop(const char *buffer, size_t length);
ssize_t writeInLoop(const char *buffer, size_t length);
void sendInLoop(const char *buffer, size_t length);
ssize_t writeInLoop(const char *buffer, size_t length);
#endif
size_t highWaterMarkLen_;
std::string name_;
size_t highWaterMarkLen_;
std::string name_;
uint64_t sendNum_{0};
std::mutex sendNumMutex_;
uint64_t sendNum_{ 0 };
std::mutex sendNumMutex_;
size_t bytesSent_{0};
size_t bytesReceived_{0};
size_t bytesSent_{ 0 };
size_t bytesReceived_{ 0 };
std::unique_ptr<std::vector<char>> fileBufferPtr_;
std::unique_ptr<std::vector<char> > fileBufferPtr_;
#ifdef USE_OPENSSL
private:
void doHandshaking();
bool validatePeerCertificate();
struct SSLEncryption
{
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
// OpenSSL
std::shared_ptr<SSLContext> sslCtxPtr_;
std::unique_ptr<SSLConn> sslPtr_;
std::unique_ptr<std::array<char, 8192>> sendBufferPtr_;
bool isServer_{false};
bool isUpgrade_{false};
std::function<void()> upgradeCallback_;
std::string hostname_;
};
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
void startClientEncryptionInLoop(
std::function<void()> &&callback,
bool useOldTLS,
bool validateCert,
const std::string &hostname,
const std::vector<std::pair<std::string, std::string>> &sslConfCmds);
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> &&callback);
private:
void doHandshaking();
bool validatePeerCertificate();
struct SSLEncryption {
SSLStatus statusOfSSL_ = SSLStatus::Handshaking;
// OpenSSL
std::shared_ptr<SSLContext> sslCtxPtr_;
std::unique_ptr<SSLConn> sslPtr_;
std::unique_ptr<std::array<char, 8192> > sendBufferPtr_;
bool isServer_{ false };
bool isUpgrade_{ false };
std::function<void()> upgradeCallback_;
std::string hostname_;
};
std::unique_ptr<SSLEncryption> sslEncryptionPtr_;
void startClientEncryptionInLoop(std::function<void()> &&callback, bool useOldTLS, bool validateCert, const std::string &hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds);
void startServerEncryptionInLoop(const std::shared_ptr<SSLContext> &ctx, std::function<void()> &&callback);
#endif
};

View File

@ -35,6 +35,7 @@
Connector::Connector(EventLoop *loop, const InetAddress &addr, bool retry) :
loop_(loop), serverAddr_(addr), retry_(retry) {
}
Connector::Connector(EventLoop *loop, InetAddress &&addr, bool retry) :
loop_(loop), serverAddr_(std::move(addr)), retry_(retry) {
}
@ -43,8 +44,10 @@ void Connector::start() {
connect_ = true;
loop_->runInLoop([this]() { startInLoop(); });
}
void Connector::restart() {
}
void Connector::stop() {
}
@ -57,6 +60,7 @@ void Connector::startInLoop() {
LOG_DEBUG << "do not connect";
}
}
void Connector::connect() {
int sockfd = Socket::createNonblockingSocketOrDie(serverAddr_.family());
errno = 0;

View File

@ -46,23 +46,30 @@ protected:
public:
using NewConnectionCallback = std::function<void(int sockfd)>;
using ConnectionErrorCallback = std::function<void()>;
Connector(EventLoop *loop, const InetAddress &addr, bool retry = true);
Connector(EventLoop *loop, InetAddress &&addr, bool retry = true);
void setNewConnectionCallback(const NewConnectionCallback &cb) {
newConnectionCallback_ = cb;
}
void setNewConnectionCallback(NewConnectionCallback &&cb) {
newConnectionCallback_ = std::move(cb);
}
void setErrorCallback(const ConnectionErrorCallback &cb) {
errorCallback_ = cb;
}
void setErrorCallback(ConnectionErrorCallback &&cb) {
errorCallback_ = std::move(cb);
}
const InetAddress &serverAddress() const {
return serverAddr_;
}
void start();
void restart();
void stop();
@ -70,11 +77,13 @@ public:
private:
NewConnectionCallback newConnectionCallback_;
ConnectionErrorCallback errorCallback_;
enum class Status {
Disconnected,
Connecting,
Connected
};
static constexpr int kMaxRetryDelayMs = 30 * 1000;
static constexpr int kInitRetryDelayMs = 500;
std::shared_ptr<Channel> channelPtr_;

View File

@ -98,8 +98,7 @@ InetAddress::InetAddress(uint16_t port, bool loopbackOnly, bool ipv6) :
isUnspecified_ = false;
}
InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) :
isIpV6_(ipv6) {
InetAddress::InetAddress(const std::string &ip, uint16_t port, bool ipv6) : isIpV6_(ipv6) {
if (ipv6) {
memset(&addr6_, 0, sizeof(addr6_));
addr6_.sin6_family = AF_INET6;

View File

@ -38,195 +38,76 @@ using sa_family_t = unsigned short;
using in_addr_t = uint32_t;
using uint16_t = unsigned short;
#else
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#endif
#include <mutex>
#include <string>
#include <unordered_map>
#include <mutex>
/**
* @brief Wrapper of sockaddr_in. This is an POD interface class.
*
*/
class InetAddress
{
public:
/**
* @brief Constructs an endpoint with given port number. Mostly used in
* TcpServer listening.
*
* @param port
* @param loopbackOnly
* @param ipv6
*/
InetAddress(uint16_t port = 0,
bool loopbackOnly = false,
bool ipv6 = false);
class InetAddress {
public:
InetAddress(uint16_t port = 0, bool loopbackOnly = false, bool ipv6 = false);
/**
* @brief Constructs an endpoint with given ip and port.
*
* @param ip A IPv4 or IPv6 address.
* @param port
* @param ipv6
*/
InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
InetAddress(const std::string &ip, uint16_t port, bool ipv6 = false);
/**
* @brief Constructs an endpoint with given struct `sockaddr_in`. Mostly
* used when accepting new connections
*
* @param addr
*/
explicit InetAddress(const struct sockaddr_in &addr)
: addr_(addr), isUnspecified_(false)
{
}
explicit InetAddress(const struct sockaddr_in &addr) :
addr_(addr), isUnspecified_(false) {
}
/**
* @brief Constructs an IPv6 endpoint with given struct `sockaddr_in6`.
* Mostly used when accepting new connections
*
* @param addr
*/
explicit InetAddress(const struct sockaddr_in6 &addr)
: addr6_(addr), isIpV6_(true), isUnspecified_(false)
{
}
explicit InetAddress(const struct sockaddr_in6 &addr) :
addr6_(addr), isIpV6_(true), isUnspecified_(false) {
}
/**
* @brief Return the sin_family of the endpoint.
*
* @return sa_family_t
*/
sa_family_t family() const
{
return addr_.sin_family;
}
sa_family_t family() const {
return addr_.sin_family;
}
/**
* @brief Return the IP string of the endpoint.
*
* @return std::string
*/
std::string toIp() const;
std::string toIp() const;
std::string toIpPort() const;
uint16_t toPort() const;
/**
* @brief Return the IP and port string of the endpoint.
*
* @return std::string
*/
std::string toIpPort() const;
bool isIpV6() const {
return isIpV6_;
}
/**
* @brief Return the port number of the endpoint.
*
* @return uint16_t
*/
uint16_t toPort() const;
bool isIntranetIp() const;
bool isLoopbackIp() const;
/**
* @brief Check if the endpoint is IPv4 or IPv6.
*
* @return true
* @return false
*/
bool isIpV6() const
{
return isIpV6_;
}
const struct sockaddr *getSockAddr() const {
return static_cast<const struct sockaddr *>((void *)(&addr6_));
}
/**
* @brief Return true if the endpoint is an intranet endpoint.
*
* @return true
* @return false
*/
bool isIntranetIp() const;
void setSockAddrInet6(const struct sockaddr_in6 &addr6) {
addr6_ = addr6;
isIpV6_ = (addr6_.sin6_family == AF_INET6);
isUnspecified_ = false;
}
/**
* @brief Return true if the endpoint is a loopback endpoint.
*
* @return true
* @return false
*/
bool isLoopbackIp() const;
uint32_t ipNetEndian() const;
const uint32_t *ip6NetEndian() const;
/**
* @brief Get the pointer to the sockaddr struct.
*
* @return const struct sockaddr*
*/
const struct sockaddr *getSockAddr() const
{
return static_cast<const struct sockaddr *>((void *)(&addr6_));
}
uint16_t portNetEndian() const {
return addr_.sin_port;
}
/**
* @brief Set the sockaddr_in6 struct in the endpoint.
*
* @param addr6
*/
void setSockAddrInet6(const struct sockaddr_in6 &addr6)
{
addr6_ = addr6;
isIpV6_ = (addr6_.sin6_family == AF_INET6);
isUnspecified_ = false;
}
void setPortNetEndian(uint16_t port) {
addr_.sin_port = port;
}
/**
* @brief Return the integer value of the IP(v4) in net endian byte order.
*
* @return uint32_t
*/
uint32_t ipNetEndian() const;
inline bool isUnspecified() const {
return isUnspecified_;
}
/**
* @brief Return the pointer to the integer value of the IP(v6) in net
* endian byte order.
*
* @return const uint32_t*
*/
const uint32_t *ip6NetEndian() const;
private:
union {
struct sockaddr_in addr_;
struct sockaddr_in6 addr6_;
};
/**
* @brief Return the port number in net endian byte order.
*
* @return uint16_t
*/
uint16_t portNetEndian() const
{
return addr_.sin_port;
}
/**
* @brief Set the port number in net endian byte order.
*
* @param port
*/
void setPortNetEndian(uint16_t port)
{
addr_.sin_port = port;
}
/**
* @brief Return true if the address is not initalized.
*/
inline bool isUnspecified() const
{
return isUnspecified_;
}
private:
union
{
struct sockaddr_in addr_;
struct sockaddr_in6 addr6_;
};
bool isIpV6_{false};
bool isUnspecified_{true};
bool isIpV6_{ false };
bool isUnspecified_{ true };
};
#endif // MUDUO_NET_INETADDRESS_H
#endif // MUDUO_NET_INETADDRESS_H

View File

@ -34,44 +34,20 @@
#include "core/loops/event_loop.h"
#include "core/net/inet_address.h"
/**
* @brief This class represents an asynchronous DNS resolver.
* @note Although the c-ares library is not essential, it is recommended to
* install it for higher performance
*/
//make it a reference
class Resolver
{
public:
using Callback = std::function<void(const InetAddress&)>;
/**
* @brief Create a new DNS resolver.
*
* @param loop The event loop in which the DNS resolver runs.
* @param timeout The timeout in seconds for DNS.
* @return std::shared_ptr<Resolver>
*/
static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr,
size_t timeout = 60);
static std::shared_ptr<Resolver> newResolver(EventLoop* loop = nullptr, size_t timeout = 60);
/**
* @brief Resolve an address asynchronously.
*
* @param hostname
* @param callback
*/
virtual void resolve(const std::string& hostname,
const Callback& callback) = 0;
virtual void resolve(const std::string& hostname, const Callback& callback) = 0;
virtual ~Resolver()
{
}
/**
* @brief Check whether the c-ares library is used.
*
* @return true
* @return false
*/
static bool isCAresUsed();
};

View File

@ -29,8 +29,8 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ares_resolver.h"
#include <ares.h>
#include "core/loops/channel.h"
#include <ares.h>
#ifdef _WIN32
#include <winsock2.h>
#else
@ -67,14 +67,14 @@ bool Resolver::isCAresUsed() {
AresResolver::LibraryInitializer::LibraryInitializer() {
ares_library_init(ARES_LIB_INIT_ALL);
}
AresResolver::LibraryInitializer::~LibraryInitializer() {
ares_library_cleanup();
}
AresResolver::LibraryInitializer AresResolver::libraryInitializer_;
std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop,
size_t timeout) {
std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *loop, size_t timeout) {
return std::make_shared<AresResolver>(loop, timeout);
}
@ -84,6 +84,7 @@ AresResolver::AresResolver(EventLoop *loop, size_t timeout) :
loop_ = getLoop();
}
}
void AresResolver::init() {
if (!ctx_) {
struct ares_options options;
@ -108,14 +109,16 @@ void AresResolver::init() {
this);
}
}
AresResolver::~AresResolver() {
if (ctx_)
ares_destroy(ctx_);
}
void AresResolver::resolveInLoop(const std::string &hostname,
const Callback &cb) {
void AresResolver::resolveInLoop(const std::string &hostname, const Callback &cb) {
loop_->assertInLoopThread();
#ifdef _WIN32
if (hostname == "localhost") {
const static InetAddress localhost_{ "127.0.0.1", 0 };
@ -123,20 +126,17 @@ void AresResolver::resolveInLoop(const std::string &hostname,
return;
}
#endif
init();
QueryData *queryData = new QueryData(this, cb, hostname);
ares_gethostbyname(ctx_,
hostname.c_str(),
AF_INET,
&AresResolver::ares_hostcallback_,
queryData);
ares_gethostbyname(ctx_, hostname.c_str(), AF_INET, &AresResolver::ares_hostcallback_, queryData);
struct timeval tv;
struct timeval *tvp = ares_timeout(ctx_, NULL, &tv);
double timeout = getSeconds(tvp);
// LOG_DEBUG << "timeout " << timeout << " active " << timerActive_;
if (!timerActive_ && timeout >= 0.0) {
loop_->runAfter(timeout,
std::bind(&AresResolver::onTimer, shared_from_this()));
loop_->runAfter(timeout, std::bind(&AresResolver::onTimer, shared_from_this()));
timerActive_ = true;
}
return;
@ -161,18 +161,18 @@ void AresResolver::onTimer() {
}
}
void AresResolver::onQueryResult(int status,
struct hostent *result,
const std::string &hostname,
const Callback &callback) {
void AresResolver::onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback) {
LOG_TRACE << "onQueryResult " << status;
struct sockaddr_in addr;
memset(&addr, 0, sizeof addr);
addr.sin_family = AF_INET;
addr.sin_port = 0;
if (result) {
addr.sin_addr = *reinterpret_cast<in_addr *>(result->h_addr);
}
InetAddress inet(addr);
{
std::lock_guard<std::mutex> lock(globalMutex());
@ -180,6 +180,7 @@ void AresResolver::onQueryResult(int status,
addrItem.first = addr.sin_addr;
addrItem.second = Date::date();
}
callback(inet);
}
@ -198,6 +199,7 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
loop_->assertInLoopThread();
ChannelList::iterator it = channels_.find(sockfd);
assert(it != channels_.end());
if (read) {
// update
// if (write) { } else { }
@ -209,17 +211,12 @@ void AresResolver::onSockStateChange(int sockfd, bool read, bool write) {
}
}
void AresResolver::ares_hostcallback_(void *data,
int status,
int timeouts,
struct hostent *hostent) {
void AresResolver::ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent) {
(void)timeouts;
QueryData *query = static_cast<QueryData *>(data);
query->owner_->onQueryResult(status,
hostent,
query->hostname_,
query->callback_);
query->owner_->onQueryResult(status, hostent, query->hostname_, query->callback_);
delete query;
}
@ -242,6 +239,7 @@ void AresResolver::ares_sock_statecallback_(void *data,
#endif
int read,
int write) {
LOG_TRACE << "sockfd=" << sockfd << " read=" << read << " write=" << write;
static_cast<AresResolver *>(data)->onSockStateChange(sockfd, read, write);
}

View File

@ -36,14 +36,15 @@
#include <map>
#include <memory>
// Resolver will be a ref
extern "C" {
struct hostent;
struct ares_channeldata;
using ares_channel = struct ares_channeldata *;
}
class AresResolver : public Resolver,
public std::enable_shared_from_this<AresResolver> {
class AresResolver : public Resolver, public std::enable_shared_from_this<AresResolver> {
protected:
AresResolver(const AresResolver &) = delete;
AresResolver &operator=(const AresResolver &) = delete;
@ -55,9 +56,9 @@ public:
AresResolver(EventLoop *loop, size_t timeout);
~AresResolver();
virtual void resolve(const std::string &hostname,
const Callback &cb) override {
virtual void resolve(const std::string &hostname, const Callback &cb) override {
bool cached = false;
InetAddress inet;
{
std::lock_guard<std::mutex> lock(globalMutex());
@ -76,10 +77,12 @@ public:
}
}
}
if (cached) {
cb(inet);
return;
}
if (loop_->isInLoopThread()) {
resolveInLoop(hostname, cb);
} else {
@ -94,14 +97,15 @@ private:
AresResolver *owner_;
Callback callback_;
std::string hostname_;
QueryData(AresResolver *o,
const Callback &cb,
const std::string &hostname) :
QueryData(AresResolver *o, const Callback &cb, const std::string &hostname) :
owner_(o), callback_(cb), hostname_(hostname) {
}
};
void resolveInLoop(const std::string &hostname, const Callback &cb);
void init();
EventLoop *loop_;
ares_channel ctx_{ nullptr };
bool timerActive_{ false };
@ -109,36 +113,33 @@ private:
ChannelList channels_;
static std::unordered_map<std::string,
std::pair<struct in_addr, Date> > &
globalCache() {
static std::unordered_map<std::string,
std::pair<struct in_addr, Date> >
dnsCache;
static std::unordered_map<std::string, std::pair<struct in_addr, Date> > dnsCache;
return dnsCache;
}
static std::mutex &globalMutex() {
static std::mutex mutex_;
return mutex_;
}
static EventLoop *getLoop() {
static EventLoopThread loopThread;
loopThread.run();
return loopThread.getLoop();
}
const size_t timeout_{ 60 };
void onRead(int sockfd);
void onTimer();
void onQueryResult(int status,
struct hostent *result,
const std::string &hostname,
const Callback &callback);
void onQueryResult(int status, struct hostent *result, const std::string &hostname, const Callback &callback);
void onSockCreate(int sockfd, int type);
void onSockStateChange(int sockfd, bool read, bool write);
static void ares_hostcallback_(void *data,
int status,
int timeouts,
struct hostent *hostent);
static void ares_hostcallback_(void *data, int status, int timeouts, struct hostent *hostent);
#ifdef _WIN32
static int ares_sock_createcallback_(SOCKET sockfd, int type, void *data);
#else
@ -152,9 +153,11 @@ private:
#endif
int read,
int write);
struct LibraryInitializer {
LibraryInitializer();
~LibraryInitializer();
};
static LibraryInitializer libraryInitializer_;
};

View File

@ -44,9 +44,11 @@ std::shared_ptr<Resolver> Resolver::newResolver(EventLoop *,
size_t timeout) {
return std::make_shared<NormalResolver>(timeout);
}
bool Resolver::isCAresUsed() {
return false;
}
void NormalResolver::resolve(const std::string &hostname,
const Callback &callback) {
{

View File

@ -35,10 +35,12 @@
#include <thread>
#include <vector>
//Resolver will be a ref
constexpr size_t kResolveBufferLength{ 16 * 1024 };
class NormalResolver : public Resolver,
public std::enable_shared_from_this<NormalResolver> {
class NormalResolver : public Resolver, public std::enable_shared_from_this<NormalResolver> {
protected:
NormalResolver(const NormalResolver &) = delete;
NormalResolver &operator=(const NormalResolver &) = delete;
@ -56,25 +58,23 @@ public:
}
private:
static std::unordered_map<std::string,
std::pair<InetAddress, Date> > &
globalCache() {
static std::unordered_map<
std::string,
std::pair<InetAddress, Date> >
dnsCache_;
static std::unordered_map<std::string, std::pair<InetAddress, Date> > &globalCache() {
static std::unordered_map<std::string, std::pair<InetAddress, Date> > dnsCache_;
return dnsCache_;
}
static std::mutex &globalMutex() {
static std::mutex mutex_;
return mutex_;
}
static ConcurrentTaskQueue &concurrentTaskQueue() {
static ConcurrentTaskQueue queue(
std::thread::hardware_concurrency() < 8 ? 8 : std::thread::hardware_concurrency(),
"Dns Queue");
return queue;
}
const size_t timeout_;
std::vector<char> resolveBuffer_;
};

View File

@ -29,9 +29,9 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "socket.h"
#include "core/log/logger.h"
#include <assert.h>
#include <sys/types.h>
#include "core/log/logger.h"
#ifdef _WIN32
#include <ws2tcpip.h>
#else
@ -39,7 +39,6 @@
#include <sys/socket.h>
#endif
bool Socket::isSelfConnect(int sockfd) {
struct sockaddr_in6 localaddr = getLocalAddr(sockfd);
struct sockaddr_in6 peeraddr = getPeerAddr(sockfd);
@ -75,6 +74,7 @@ void Socket::bindAddress(const InetAddress &localaddr) {
exit(1);
}
}
void Socket::listen() {
assert(sockFd_ > 0);
int ret = ::listen(sockFd_, SOMAXCONN);
@ -83,6 +83,7 @@ void Socket::listen() {
exit(1);
}
}
int Socket::accept(InetAddress *peeraddr) {
struct sockaddr_in6 addr6;
memset(&addr6, 0, sizeof(addr6));
@ -102,6 +103,7 @@ int Socket::accept(InetAddress *peeraddr) {
}
return connfd;
}
void Socket::closeWrite() {
#ifndef _WIN32
if (::shutdown(sockFd_, SHUT_WR) < 0)
@ -112,6 +114,7 @@ void Socket::closeWrite() {
LOG_SYSERR << "sockets::shutdownWrite";
}
}
int Socket::read(char *buffer, uint64_t len) {
#ifndef _WIN32
return ::read(sockFd_, buffer, len);

View File

@ -32,8 +32,8 @@
#pragma once
#include "core/net/inet_address.h"
#include "core/log/logger.h"
#include "core/net/inet_address.h"
#include <string>
#ifndef _WIN32
#include <unistd.h>
@ -41,7 +41,6 @@
#include <fcntl.h>
class Socket {
protected:
Socket(const Socket &) = delete;
Socket &operator=(const Socket &) = delete;
@ -63,6 +62,7 @@ public:
LOG_SYSERR << "sockets::createNonblockingOrDie";
exit(1);
}
LOG_TRACE << "sock=" << sock;
return sock;
}
@ -85,15 +85,9 @@ public:
static int connect(int sockfd, const InetAddress &addr) {
if (addr.isIpV6())
return ::connect(sockfd,
addr.getSockAddr(),
static_cast<socklen_t>(
sizeof(struct sockaddr_in6)));
return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in6)));
else
return ::connect(sockfd,
addr.getSockAddr(),
static_cast<socklen_t>(
sizeof(struct sockaddr_in)));
return ::connect(sockfd, addr.getSockAddr(), static_cast<socklen_t>(sizeof(struct sockaddr_in)));
}
static bool isSelfConnect(int sockfd);
@ -101,7 +95,9 @@ public:
explicit Socket(int sockfd) :
sockFd_(sockfd) {
}
~Socket();
/// abort if address in use
void bindAddress(const InetAddress &localaddr);
/// abort if address in use
@ -109,9 +105,11 @@ public:
int accept(InetAddress *peeraddr);
void closeWrite();
int read(char *buffer, uint64_t len);
int fd() {
return sockFd_;
}
static struct sockaddr_in6 getLocalAddr(int sockfd);
static struct sockaddr_in6 getPeerAddr(int sockfd);

View File

@ -54,9 +54,7 @@ TcpClient::IgnoreSigPipe TcpClient::initObj;
#endif
static void defaultConnectionCallback(const TcpConnectionPtr &conn) {
LOG_TRACE << conn->localAddr().toIpPort() << " -> "
<< conn->peerAddr().toIpPort() << " is "
<< (conn->connected() ? "UP" : "DOWN");
LOG_TRACE << conn->localAddr().toIpPort() << " -> " << conn->peerAddr().toIpPort() << " is " << (conn->connected() ? "UP" : "DOWN");
// do not call conn->forceClose(), because some users want to register
// message callback only.
}
@ -65,9 +63,7 @@ static void defaultMessageCallback(const TcpConnectionPtr &, MsgBuffer *buf) {
buf->retrieveAll();
}
TcpClient::TcpClient(EventLoop *loop,
const InetAddress &serverAddr,
const std::string &nameArg) :
TcpClient::TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg) :
loop_(loop),
connector_(new Connector(loop, serverAddr, false)),
name_(nameArg),
@ -75,23 +71,27 @@ TcpClient::TcpClient(EventLoop *loop,
messageCallback_(defaultMessageCallback),
retry_(false),
connect_(true) {
connector_->setNewConnectionCallback(
std::bind(&TcpClient::newConnection, this, _1));
connector_->setNewConnectionCallback(std::bind(&TcpClient::newConnection, this, _1));
connector_->setErrorCallback([this]() {
if (connectionErrorCallback_) {
connectionErrorCallback_();
}
});
LOG_TRACE << "TcpClient::TcpClient[" << name_ << "] - connector ";
}
TcpClient::~TcpClient() {
LOG_TRACE << "TcpClient::~TcpClient[" << name_ << "] - connector ";
TcpConnectionImplPtr conn;
{
std::lock_guard<std::mutex> lock(mutex_);
conn = std::dynamic_pointer_cast<TcpConnectionImpl>(connection_);
}
if (conn) {
assert(loop_ == conn->getLoop());
// TODO: not 100% safe, if we are in different thread
@ -104,6 +104,7 @@ TcpClient::~TcpClient() {
});
});
});
conn->forceClose();
} else {
/// TODO need test in this condition
@ -142,39 +143,34 @@ void TcpClient::newConnection(int sockfd) {
// TODO poll with zero timeout to double confirm the new connection
// TODO use make_shared if necessary
std::shared_ptr<TcpConnectionImpl> conn;
if (sslCtxPtr_) {
#ifdef USE_OPENSSL
conn = std::make_shared<TcpConnectionImpl>(loop_,
sockfd,
localAddr,
peerAddr,
sslCtxPtr_,
false,
validateCert_,
SSLHostName_);
conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr, sslCtxPtr_, false, validateCert_, SSLHostName_);
#else
LOG_FATAL << "OpenSSL is not found in your system!";
abort();
#endif
} else {
conn = std::make_shared<TcpConnectionImpl>(loop_,
sockfd,
localAddr,
peerAddr);
conn = std::make_shared<TcpConnectionImpl>(loop_, sockfd, localAddr, peerAddr);
}
conn->setConnectionCallback(connectionCallback_);
conn->setRecvMsgCallback(messageCallback_);
conn->setWriteCompleteCallback(writeCompleteCallback_);
conn->setCloseCallback(std::bind(&TcpClient::removeConnection, this, _1));
{
std::lock_guard<std::mutex> lock(mutex_);
connection_ = conn;
}
conn->setSSLErrorCallback([this](SSLError err) {
if (sslErrorCallback_) {
sslErrorCallback_(err);
}
});
conn->connectEstablished();
}
@ -188,30 +184,23 @@ void TcpClient::removeConnection(const TcpConnectionPtr &conn) {
connection_.reset();
}
loop_->queueInLoop(
std::bind(&TcpConnectionImpl::connectDestroyed,
std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
loop_->queueInLoop(std::bind(&TcpConnectionImpl::connectDestroyed, std::dynamic_pointer_cast<TcpConnectionImpl>(conn)));
if (retry_ && connect_) {
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to "
<< connector_->serverAddress().toIpPort();
LOG_TRACE << "TcpClient::connect[" << name_ << "] - Reconnecting to " << connector_->serverAddress().toIpPort();
connector_->restart();
}
}
void TcpClient::enableSSL(
bool useOldTLS,
bool validateCert,
std::string hostname,
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
void TcpClient::enableSSL(bool useOldTLS, bool validateCert, std::string hostname, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
#ifdef USE_OPENSSL
/* Create a new OpenSSL context */
sslCtxPtr_ = newSSLContext(useOldTLS, validateCert, sslConfCmds);
validateCert_ = validateCert;
if (!hostname.empty()) {
std::transform(hostname.begin(),
hostname.end(),
hostname.begin(),
tolower);
std::transform(hostname.begin(), hostname.end(), hostname.begin(), tolower);
SSLHostName_ = std::move(hostname);
}

View File

@ -42,10 +42,7 @@
class Connector;
using ConnectorPtr = std::shared_ptr<Connector>;
class SSLContext;
/**
* @brief This class represents a TCP client.
*
*/
class TcpClient {
protected:
TcpClient(const TcpClient &) = delete;
@ -55,88 +52,35 @@ protected:
TcpClient &operator=(TcpClient &&) noexcept(true) = default;
public:
/**
* @brief Construct a new TCP client instance.
*
* @param loop The event loop in which the client runs.
* @param serverAddr The address of the server.
* @param nameArg The name of the client.
*/
TcpClient(EventLoop *loop,
const InetAddress &serverAddr,
const std::string &nameArg);
TcpClient(EventLoop *loop, const InetAddress &serverAddr, const std::string &nameArg);
~TcpClient();
/**
* @brief Connect to the server.
*
*/
void connect();
/**
* @brief Disconnect from the server.
*
*/
void disconnect();
/**
* @brief Stop connecting to the server.
*
*/
void stop();
/**
* @brief Get the TCP connection to the server.
*
* @return TcpConnectionPtr
*/
TcpConnectionPtr connection() const {
std::lock_guard<std::mutex> lock(mutex_);
return connection_;
}
/**
* @brief Get the event loop.
*
* @return EventLoop*
*/
EventLoop *getLoop() const {
return loop_;
}
/**
* @brief Check whether the client re-connect to the server.
*
* @return true
* @return false
*/
bool retry() const {
return retry_;
}
/**
* @brief Enable retrying.
*
*/
void enableRetry() {
retry_ = true;
}
/**
* @brief Get the name of the client.
*
* @return const std::string&
*/
const std::string &name() const {
return name_;
}
/**
* @brief Set the connection callback.
*
* @param cb The callback is called when the connection to the server is
* established or closed.
*/
void setConnectionCallback(const ConnectionCallback &cb) {
connectionCallback_ = cb;
}
@ -144,37 +88,18 @@ public:
connectionCallback_ = std::move(cb);
}
/**
* @brief Set the connection error callback.
*
* @param cb The callback is called when an error occurs during connecting
* to the server.
*/
void setConnectionErrorCallback(const ConnectionErrorCallback &cb) {
connectionErrorCallback_ = cb;
}
/**
* @brief Set the message callback.
*
* @param cb The callback is called when some data is received from the
* server.
*/
void setMessageCallback(const RecvMessageCallback &cb) {
messageCallback_ = cb;
}
void setMessageCallback(RecvMessageCallback &&cb) {
messageCallback_ = std::move(cb);
}
/// Set write complete callback.
/// Not thread safe.
/**
* @brief Set the write complete callback.
*
* @param cb The callback is called when data to send is written to the
* socket.
*/
/// Not thread safe.
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
writeCompleteCallback_ = cb;
}
@ -182,10 +107,6 @@ public:
writeCompleteCallback_ = std::move(cb);
}
/**
* @brief Set the callback for errors of SSL
* @param cb The callback is called when an SSL error occurs.
*/
void setSSLErrorCallback(const SSLErrorCallback &cb) {
sslErrorCallback_ = cb;
}
@ -193,24 +114,7 @@ public:
sslErrorCallback_ = std::move(cb);
}
/**
* @brief Enable SSL encryption.
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
* client.
* @param validateCert If true, we try to validate if the peer's SSL cert
* is valid.
* @param hostname The server hostname for SNI. If it is empty, the SNI is
* not used.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
* 2020. And it's a good practice to only use TLS 1.2 and above.
*/
void enableSSL(bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string> >
&sslConfCmds = {});
void enableSSL(bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
private:
/// Not thread safe, but in loop
@ -234,6 +138,7 @@ private:
std::shared_ptr<SSLContext> sslCtxPtr_;
bool validateCert_{ false };
std::string SSLHostName_;
#ifndef _WIN32
class IgnoreSigPipe {
public:

View File

@ -39,26 +39,13 @@
#include <string>
class SSLContext;
std::shared_ptr<SSLContext> newSSLServerContext(
const std::string &certPath,
const std::string &keyPath,
bool useOldTLS = false,
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
/**
* @brief This class represents a TCP connection.
*
*/
std::shared_ptr<SSLContext> newSSLServerContext(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
class TcpConnection {
public:
TcpConnection() = default;
virtual ~TcpConnection(){};
/**
* @brief Send some data to the peer.
*
* @param msg
* @param len
*/
virtual void send(const char *msg, size_t len) = 0;
virtual void send(const void *msg, size_t len) = 0;
virtual void send(const std::string &msg) = 0;
@ -68,95 +55,27 @@ public:
virtual void send(const std::shared_ptr<std::string> &msgPtr) = 0;
virtual void send(const std::shared_ptr<MsgBuffer> &msgPtr) = 0;
/**
* @brief Send a file to the peer.
*
* @param fileName
* @param offset
* @param length
*/
virtual void sendFile(const char *fileName,
size_t offset = 0,
size_t length = 0) = 0;
virtual void sendFile(const char *fileName, size_t offset = 0, size_t length = 0) = 0;
/**
* @brief Get the local address of the connection.
*
* @return const InetAddress&
*/
virtual const InetAddress &localAddr() const = 0;
/**
* @brief Get the remote address of the connection.
*
* @return const InetAddress&
*/
virtual const InetAddress &peerAddr() const = 0;
/**
* @brief Return true if the connection is established.
*
* @return true
* @return false
*/
virtual bool connected() const = 0;
/**
* @brief Return false if the connection is established.
*
* @return true
* @return false
*/
virtual bool disconnected() const = 0;
/**
* @brief Get the buffer in which the received data stored.
*
* @return MsgBuffer*
*/
virtual MsgBuffer *getRecvBuffer() = 0;
/**
* @brief Set the high water mark callback
*
* @param cb The callback is called when the data in sending buffer is
* larger than the water mark.
* @param markLen The water mark in bytes.
*/
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb,
size_t markLen) = 0;
/**
* @brief Set the TCP_NODELAY option to the socket.
*
* @param on
*/
virtual void setHighWaterMarkCallback(const HighWaterMarkCallback &cb, size_t markLen) = 0;
virtual void setTcpNoDelay(bool on) = 0;
/**
* @brief Shutdown the connection.
* @note This method only closes the writing direction.
*/
virtual void shutdown() = 0;
/**
* @brief Close the connection forcefully.
*
*/
virtual void forceClose() = 0;
/**
* @brief Get the event loop in which the connection I/O is handled.
*
* @return EventLoop*
*/
virtual EventLoop *getLoop() = 0;
/**
* @brief Set the custom data on the connection.
*
* @param context
*/
void setContext(const std::shared_ptr<void> &context) {
contextPtr_ = context;
}
@ -164,98 +83,30 @@ public:
contextPtr_ = std::move(context);
}
/**
* @brief Get the custom data from the connection.
*
* @tparam T
* @return std::shared_ptr<T>
*/
template <typename T>
std::shared_ptr<T> getContext() const {
return std::static_pointer_cast<T>(contextPtr_);
}
/**
* @brief Return true if the custom data is set by user.
*
* @return true
* @return false
*/
bool hasContext() const {
return (bool)contextPtr_;
}
/**
* @brief Clear the custom data.
*
*/
void clearContext() {
contextPtr_.reset();
}
/**
* @brief Call this method to avoid being kicked off by TcpServer, refer to
* the kickoffIdleConnections method in the TcpServer class.
*
*/
virtual void keepAlive() = 0;
/**
* @brief Return true if the keepAlive() method is called.
*
* @return true
* @return false
*/
virtual bool isKeepAlive() = 0;
/**
* @brief Return the number of bytes sent
*
* @return size_t
*/
virtual size_t bytesSent() const = 0;
/**
* @brief Return the number of bytes received.
*
* @return size_t
*/
virtual size_t bytesReceived() const = 0;
/**
* @brief Check whether the connection is SSL encrypted.
*
* @return true
* @return false
*/
virtual bool isSSLConnection() const = 0;
/**
* @brief Start the SSL encryption on the connection (as a client).
*
* @param callback The callback is called when the SSL connection is
* established.
* @param hostname The server hostname for SNI. If it is empty, the SNI is
* not used.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
*/
virtual void startClientEncryption(
std::function<void()> callback,
bool useOldTLS = false,
bool validateCert = true,
std::string hostname = "",
const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
virtual void startClientEncryption(std::function<void()> callback, bool useOldTLS = false, bool validateCert = true, std::string hostname = "", const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {}) = 0;
/**
* @brief Start the SSL encryption on the connection (as a server).
*
* @param ctx The SSL context.
* @param callback The callback is called when the SSL connection is
* established.
*/
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx,
std::function<void()> callback) = 0;
virtual void startServerEncryption(const std::shared_ptr<SSLContext> &ctx, std::function<void()> callback) = 0;
protected:
bool validateCert_ = false;

View File

@ -28,20 +28,16 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/loops/acceptor.h"
#include "core/net/connections/tcp_connection_impl.h"
#include "core/net/tcp_server.h"
#include "core/log/logger.h"
#include "core/loops/acceptor.h"
#include "core/net/connections/tcp_connection_impl.h"
#include <functional>
#include <vector>
using namespace std::placeholders;
TcpServer::TcpServer(EventLoop *loop,
const InetAddress &address,
const std::string &name,
bool reUseAddr,
bool reUsePort) :
TcpServer::TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr, bool reUsePort) :
loop_(loop),
acceptorPtr_(new Acceptor(loop, address, reUseAddr, reUsePort)),
serverName_(name),
@ -50,8 +46,8 @@ TcpServer::TcpServer(EventLoop *loop,
<< " bytes]";
buffer->retrieveAll();
}) {
acceptorPtr_->setNewConnectionCallback(
std::bind(&TcpServer::newConnection, this, _1, _2));
acceptorPtr_->setNewConnectionCallback(std::bind(&TcpServer::newConnection, this, _1, _2));
}
TcpServer::~TcpServer() {
@ -69,14 +65,20 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
// LOG_TRACE<<"vector size:"<<str.size();
// size_t n=write(sockfd,&str[0],str.size());
// LOG_TRACE<<"write "<<n<<" bytes";
loop_->assertInLoopThread();
EventLoop *ioLoop = NULL;
if (loopPoolPtr_ && loopPoolPtr_->size() > 0) {
ioLoop = loopPoolPtr_->getNextLoop();
}
if (ioLoop == NULL)
if (ioLoop == NULL) {
ioLoop = loop_;
}
std::shared_ptr<TcpConnectionImpl> newPtr;
if (sslCtxPtr_) {
#ifdef USE_OPENSSL
newPtr = std::make_shared<TcpConnectionImpl>(
@ -90,14 +92,14 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
abort();
#endif
} else {
newPtr = std::make_shared<TcpConnectionImpl>(
ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
newPtr = std::make_shared<TcpConnectionImpl>(ioLoop, sockfd, InetAddress(Socket::getLocalAddr(sockfd)), peer);
}
if (idleTimeout_ > 0) {
assert(timingWheelMap_[ioLoop]);
newPtr->enableKickingOff(idleTimeout_, timingWheelMap_[ioLoop]);
}
newPtr->setRecvMsgCallback(recvMessageCallback_);
newPtr->setConnectionCallback(
@ -105,11 +107,13 @@ void TcpServer::newConnection(int sockfd, const InetAddress &peer) {
if (connectionCallback_)
connectionCallback_(connectionPtr);
});
newPtr->setWriteCompleteCallback(
[this](const TcpConnectionPtr &connectionPtr) {
if (writeCompleteCallback_)
writeCompleteCallback_(connectionPtr);
});
newPtr->setCloseCallback(std::bind(&TcpServer::connectionClosed, this, _1));
connSet_.insert(newPtr);
newPtr->connectEstablished();
@ -143,6 +147,7 @@ void TcpServer::start() {
acceptorPtr_->listen();
});
}
void TcpServer::stop() {
loop_->runInLoop([this]() { acceptorPtr_.reset(); });
for (auto connection : connSet_) {
@ -159,6 +164,7 @@ void TcpServer::stop() {
f.get();
}
}
void TcpServer::connectionClosed(const TcpConnectionPtr &connectionPtr) {
LOG_TRACE << "connectionClosed";
// loop_->assertInLoopThread();
@ -179,11 +185,7 @@ const InetAddress &TcpServer::address() const {
return acceptorPtr_->addr();
}
void TcpServer::enableSSL(
const std::string &certPath,
const std::string &keyPath,
bool useOldTLS,
const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
void TcpServer::enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS, const std::vector<std::pair<std::string, std::string> > &sslConfCmds) {
#ifdef USE_OPENSSL
/* Create a new OpenSSL context */
sslCtxPtr_ = newSSLServerContext(certPath, keyPath, useOldTLS, sslConfCmds);

View File

@ -43,10 +43,7 @@
class Acceptor;
class SSLContext;
/**
* @brief This class represents a TCP server.
*
*/
class TcpServer {
protected:
TcpServer(const TcpServer &) = delete;
@ -56,53 +53,18 @@ protected:
TcpServer &operator=(TcpServer &&) noexcept(true) = default;
public:
/**
* @brief Construct a new TCP server instance.
*
* @param loop The event loop in which the acceptor of the server is
* handled.
* @param address The address of the server.
* @param name The name of the server.
* @param reUseAddr The SO_REUSEADDR option.
* @param reUsePort The SO_REUSEPORT option.
*/
TcpServer(EventLoop *loop,
const InetAddress &address,
const std::string &name,
bool reUseAddr = true,
bool reUsePort = true);
TcpServer(EventLoop *loop, const InetAddress &address, const std::string &name, bool reUseAddr = true, bool reUsePort = true);
~TcpServer();
/**
* @brief Start the server.
*
*/
void start();
/**
* @brief Stop the server.
*
*/
void stop();
/**
* @brief Set the number of event loops in which the I/O of connections to
* the server is handled.
*
* @param num
*/
void setIoLoopNum(size_t num) {
assert(!started_);
loopPoolPtr_ = std::make_shared<EventLoopThreadPool>(num);
loopPoolPtr_->start();
}
/**
* @brief Set the event loops pool in which the I/O of connections to
* the server is handled.
*
* @param pool
*/
void setIoLoopThreadPool(const std::shared_ptr<EventLoopThreadPool> &pool) {
assert(pool->size() > 0);
assert(!started_);
@ -110,12 +72,6 @@ public:
loopPoolPtr_->start();
}
/**
* @brief Set the message callback.
*
* @param cb The callback is called when some data is received on a
* connection to the server.
*/
void setRecvMessageCallback(const RecvMessageCallback &cb) {
recvMessageCallback_ = cb;
}
@ -123,12 +79,6 @@ public:
recvMessageCallback_ = std::move(cb);
}
/**
* @brief Set the connection callback.
*
* @param cb The callback is called when a connection is established or
* closed.
*/
void setConnectionCallback(const ConnectionCallback &cb) {
connectionCallback_ = cb;
}
@ -136,12 +86,6 @@ public:
connectionCallback_ = std::move(cb);
}
/**
* @brief Set the write complete callback.
*
* @param cb The callback is called when data to send is written to the
* socket of a connection.
*/
void setWriteCompleteCallback(const WriteCompleteCallback &cb) {
writeCompleteCallback_ = cb;
}
@ -149,53 +93,21 @@ public:
writeCompleteCallback_ = std::move(cb);
}
/**
* @brief Get the name of the server.
*
* @return const std::string&
*/
const std::string &name() const {
return serverName_;
}
/**
* @brief Get the IP and port string of the server.
*
* @return const std::string
*/
const std::string ipPort() const;
/**
* @brief Get the address of the server.
*
* @return const InetAddress&
*/
const InetAddress &address() const;
/**
* @brief Get the event loop of the server.
*
* @return EventLoop*
*/
EventLoop *getLoop() const {
return loop_;
}
/**
* @brief Get the I/O event loops of the server.
*
* @return std::vector<EventLoop *>
*/
std::vector<EventLoop *> getIoLoops() const {
return loopPoolPtr_->getLoops();
}
/**
* @brief An idle connection is a connection that has no read or write, kick
* off it after timeout seconds.
*
* @param timeout
*/
void kickoffIdleConnections(size_t timeout) {
loop_->runInLoop([this, timeout]() {
assert(!started_);
@ -203,23 +115,12 @@ public:
});
}
/**
* @brief Enable SSL encryption.
*
* @param certPath The path of the certificate file.
* @param keyPath The path of the private key file.
* @param useOldTLS If true, the TLS 1.0 and 1.1 are supported by the
* server.
* @param sslConfCmds The commands used to call the SSL_CONF_cmd function in
* OpenSSL.
* @note It's well known that TLS 1.0 and 1.1 are not considered secure in
* 2020. And it's a good practice to only use TLS 1.2 and above.
*/
void enableSSL(const std::string &certPath,
const std::string &keyPath,
bool useOldTLS = false,
const std::vector<std::pair<std::string, std::string> >
&sslConfCmds = {});
// certPath The path of the certificate file.
// keyPath The path of the private key file.
// useOldTLS If true, the TLS 1.0 and 1.1 are supported by the server.
// sslConfCmds The commands used to call the SSL_CONF_cmd function in OpenSSL.
// Note: It's well known that TLS 1.0 and 1.1 are not considered secure in 2020. And it's a good practice to only use TLS 1.2 and above.
void enableSSL(const std::string &certPath, const std::string &keyPath, bool useOldTLS = false, const std::vector<std::pair<std::string, std::string> > &sslConfCmds = {});
private:
EventLoop *loop_;
@ -236,6 +137,7 @@ private:
std::map<EventLoop *, std::shared_ptr<TimingWheel> > timingWheelMap_;
void connectionClosed(const TcpConnectionPtr &connectionPtr);
std::shared_ptr<EventLoopThreadPool> loopPoolPtr_;
#ifndef _WIN32
class IgnoreSigPipe {
public:
@ -247,6 +149,7 @@ private:
IgnoreSigPipe initObj;
#endif
bool started_{ false };
// OpenSSL SSL context Object;